Repository: Dao-AILab/flash-attention
Branch: main
Commit: dd6f4a212a4f
Files: 994
Total size: 5.1 MB

Directory structure:
gitextract_wgwkfssl/
├── .github/
│   └── workflows/
│       ├── README.md
│       ├── _build.yml
│       ├── build.yml
│       ├── pre-commit.yaml
│       ├── publish-fa4.yml
│       └── publish.yml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── AI/
│   ├── DEBUG_2CTA.md
│   ├── RACECHECK_TMA_HAZARD.md
│   ├── SM90_BLOCK_SIZE_TUNING.md
│   ├── SM90_R2P_MASKING_SASS.md
│   ├── VARLEN_PREPROCESS_TILE_BUG.md
│   ├── racecheck_repro_1d_bulk.py
│   └── racecheck_repro_1d_tensor.py
├── AUTHORS
├── CLAUDE.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── benchmarks/
│   ├── bench_sm90.py
│   ├── benchmark_alibi.py
│   ├── benchmark_attn.py
│   ├── benchmark_causal.py
│   ├── benchmark_flash_attention.py
│   └── benchmark_gemm.py
├── csrc/
│   ├── flash_attn/
│   │   ├── flash_api.cpp
│   │   └── src/
│   │       ├── alibi.h
│   │       ├── block_info.h
│   │       ├── dropout.h
│   │       ├── flash.h
│   │       ├── flash_bwd_hdim128_bf16_causal_sm80.cu
│   │       ├── flash_bwd_hdim128_bf16_sm80.cu
│   │       ├── flash_bwd_hdim128_fp16_causal_sm80.cu
│   │       ├── flash_bwd_hdim128_fp16_sm80.cu
│   │       ├── flash_bwd_hdim192_bf16_causal_sm80.cu
│   │       ├── flash_bwd_hdim192_bf16_sm80.cu
│   │       ├── flash_bwd_hdim192_fp16_causal_sm80.cu
│   │       ├── flash_bwd_hdim192_fp16_sm80.cu
│   │       ├── flash_bwd_hdim256_bf16_causal_sm80.cu
│   │       ├── flash_bwd_hdim256_bf16_sm80.cu
│   │       ├── flash_bwd_hdim256_fp16_causal_sm80.cu
│   │       ├── flash_bwd_hdim256_fp16_sm80.cu
│   │       ├── flash_bwd_hdim32_bf16_causal_sm80.cu
│   │       ├── flash_bwd_hdim32_bf16_sm80.cu
│   │       ├── flash_bwd_hdim32_fp16_causal_sm80.cu
│   │       ├── flash_bwd_hdim32_fp16_sm80.cu
│   │       ├── flash_bwd_hdim64_bf16_causal_sm80.cu
│   │       ├── flash_bwd_hdim64_bf16_sm80.cu
│   │       ├── flash_bwd_hdim64_fp16_causal_sm80.cu
│   │       ├── flash_bwd_hdim64_fp16_sm80.cu
│   │       ├── flash_bwd_hdim96_bf16_causal_sm80.cu
│   │       ├── flash_bwd_hdim96_bf16_sm80.cu
│   │       ├── flash_bwd_hdim96_fp16_causal_sm80.cu
│   │       ├── flash_bwd_hdim96_fp16_sm80.cu
│   │       ├── flash_bwd_kernel.h
│   │       ├── flash_bwd_launch_template.h
│   │       ├── flash_bwd_preprocess_kernel.h
│   │       ├── flash_fwd_hdim128_bf16_causal_sm80.cu
│   │       ├── flash_fwd_hdim128_bf16_sm80.cu
│   │       ├── flash_fwd_hdim128_fp16_causal_sm80.cu
│   │       ├── flash_fwd_hdim128_fp16_sm80.cu
│   │       ├── flash_fwd_hdim192_bf16_causal_sm80.cu
│   │       ├── flash_fwd_hdim192_bf16_sm80.cu
│   │       ├── flash_fwd_hdim192_fp16_causal_sm80.cu
│   │       ├── flash_fwd_hdim192_fp16_sm80.cu
│   │       ├── flash_fwd_hdim256_bf16_causal_sm80.cu
│   │       ├── flash_fwd_hdim256_bf16_sm80.cu
│   │       ├── flash_fwd_hdim256_fp16_causal_sm80.cu
│   │       ├── flash_fwd_hdim256_fp16_sm80.cu
│   │       ├── flash_fwd_hdim32_bf16_causal_sm80.cu
│   │       ├── flash_fwd_hdim32_bf16_sm80.cu
│   │       ├── flash_fwd_hdim32_fp16_causal_sm80.cu
│   │       ├── flash_fwd_hdim32_fp16_sm80.cu
│   │       ├── flash_fwd_hdim64_bf16_causal_sm80.cu
│   │       ├── flash_fwd_hdim64_bf16_sm80.cu
│   │       ├── flash_fwd_hdim64_fp16_causal_sm80.cu
│   │       ├── flash_fwd_hdim64_fp16_sm80.cu
│   │       ├── flash_fwd_hdim96_bf16_causal_sm80.cu
│   │       ├── flash_fwd_hdim96_bf16_sm80.cu
│   │       ├── flash_fwd_hdim96_fp16_causal_sm80.cu
│   │       ├── flash_fwd_hdim96_fp16_sm80.cu
│   │       ├── flash_fwd_kernel.h
│   │       ├── flash_fwd_launch_template.h
│   │       ├── flash_fwd_split_hdim128_bf16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim128_bf16_sm80.cu
│   │       ├── flash_fwd_split_hdim128_fp16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim128_fp16_sm80.cu
│   │       ├── flash_fwd_split_hdim192_bf16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim192_bf16_sm80.cu
│   │       ├── flash_fwd_split_hdim192_fp16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim192_fp16_sm80.cu
│   │       ├── flash_fwd_split_hdim256_bf16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim256_bf16_sm80.cu
│   │       ├── flash_fwd_split_hdim256_fp16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim256_fp16_sm80.cu
│   │       ├── flash_fwd_split_hdim32_bf16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim32_bf16_sm80.cu
│   │       ├── flash_fwd_split_hdim32_fp16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim32_fp16_sm80.cu
│   │       ├── flash_fwd_split_hdim64_bf16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim64_bf16_sm80.cu
│   │       ├── flash_fwd_split_hdim64_fp16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim64_fp16_sm80.cu
│   │       ├── flash_fwd_split_hdim96_bf16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim96_bf16_sm80.cu
│   │       ├── flash_fwd_split_hdim96_fp16_causal_sm80.cu
│   │       ├── flash_fwd_split_hdim96_fp16_sm80.cu
│   │       ├── generate_kernels.py
│   │       ├── hardware_info.h
│   │       ├── kernel_traits.h
│   │       ├── mask.h
│   │       ├── namespace_config.h
│   │       ├── philox.cuh
│   │       ├── philox_unpack.cuh
│   │       ├── rotary.h
│   │       ├── softmax.h
│   │       ├── static_switch.h
│   │       └── utils.h
│   ├── flash_attn_ck/
│   │   ├── flash_api.cpp
│   │   ├── flash_common.cpp
│   │   ├── flash_common.hpp
│   │   ├── mha_bwd.cpp
│   │   ├── mha_fwd.cpp
│   │   ├── mha_fwd_kvcache.cpp
│   │   ├── mha_varlen_bwd.cpp
│   │   └── mha_varlen_fwd.cpp
│   ├── fused_dense_lib/
│   │   ├── README.md
│   │   ├── fused_dense.cpp
│   │   ├── fused_dense_cuda.cu
│   │   └── setup.py
│   └── layer_norm/
│       ├── README.md
│       ├── ln.h
│       ├── ln_api.cpp
│       ├── ln_bwd_1024.cu
│       ├── ln_bwd_1280.cu
│       ├── ln_bwd_1536.cu
│       ├── ln_bwd_2048.cu
│       ├── ln_bwd_256.cu
│       ├── ln_bwd_2560.cu
│       ├── ln_bwd_3072.cu
│       ├── ln_bwd_4096.cu
│       ├── ln_bwd_512.cu
│       ├── ln_bwd_5120.cu
│       ├── ln_bwd_6144.cu
│       ├── ln_bwd_7168.cu
│       ├── ln_bwd_768.cu
│       ├── ln_bwd_8192.cu
│       ├── ln_bwd_kernels.cuh
│       ├── ln_fwd_1024.cu
│       ├── ln_fwd_1280.cu
│       ├── ln_fwd_1536.cu
│       ├── ln_fwd_2048.cu
│       ├── ln_fwd_256.cu
│       ├── ln_fwd_2560.cu
│       ├── ln_fwd_3072.cu
│       ├── ln_fwd_4096.cu
│       ├── ln_fwd_512.cu
│       ├── ln_fwd_5120.cu
│       ├── ln_fwd_6144.cu
│       ├── ln_fwd_7168.cu
│       ├── ln_fwd_768.cu
│       ├── ln_fwd_8192.cu
│       ├── ln_fwd_kernels.cuh
│       ├── ln_kernel_traits.h
│       ├── ln_parallel_bwd_1024.cu
│       ├── ln_parallel_bwd_1280.cu
│       ├── ln_parallel_bwd_1536.cu
│       ├── ln_parallel_bwd_2048.cu
│       ├── ln_parallel_bwd_256.cu
│       ├── ln_parallel_bwd_2560.cu
│       ├── ln_parallel_bwd_3072.cu
│       ├── ln_parallel_bwd_4096.cu
│       ├── ln_parallel_bwd_512.cu
│       ├── ln_parallel_bwd_5120.cu
│       ├── ln_parallel_bwd_6144.cu
│       ├── ln_parallel_bwd_7168.cu
│       ├── ln_parallel_bwd_768.cu
│       ├── ln_parallel_bwd_8192.cu
│       ├── ln_parallel_fwd_1024.cu
│       ├── ln_parallel_fwd_1280.cu
│       ├── ln_parallel_fwd_1536.cu
│       ├── ln_parallel_fwd_2048.cu
│       ├── ln_parallel_fwd_256.cu
│       ├── ln_parallel_fwd_2560.cu
│       ├── ln_parallel_fwd_3072.cu
│       ├── ln_parallel_fwd_4096.cu
│       ├── ln_parallel_fwd_512.cu
│       ├── ln_parallel_fwd_5120.cu
│       ├── ln_parallel_fwd_6144.cu
│       ├── ln_parallel_fwd_7168.cu
│       ├── ln_parallel_fwd_768.cu
│       ├── ln_parallel_fwd_8192.cu
│       ├── ln_parallel_residual_bwd_kernels.cuh
│       ├── ln_parallel_residual_fwd_kernels.cuh
│       ├── ln_utils.cuh
│       ├── setup.py
│       └── static_switch.h
├── examples/
│   └── inference/
│       └── README.md
├── flash_attn/
│   ├── __init__.py
│   ├── bert_padding.py
│   ├── cute/
│   │   ├── .flake8
│   │   ├── AUTHORS
│   │   ├── LICENSE
│   │   ├── MANIFEST.in
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── ampere_helpers.py
│   │   ├── barrier.py
│   │   ├── bench_utils.py
│   │   ├── benchmark.py
│   │   ├── blackwell_helpers.py
│   │   ├── block_info.py
│   │   ├── block_sparse_utils.py
│   │   ├── block_sparsity.py
│   │   ├── cache_utils.py
│   │   ├── compute_block_sparsity.py
│   │   ├── copy_utils.py
│   │   ├── cute_dsl_ptxas.py
│   │   ├── cute_dsl_utils.py
│   │   ├── fa_logging.py
│   │   ├── fast_math.py
│   │   ├── flash_bwd.py
│   │   ├── flash_bwd_postprocess.py
│   │   ├── flash_bwd_preprocess.py
│   │   ├── flash_bwd_sm100.py
│   │   ├── flash_bwd_sm120.py
│   │   ├── flash_bwd_sm90.py
│   │   ├── flash_fwd.py
│   │   ├── flash_fwd_combine.py
│   │   ├── flash_fwd_sm100.py
│   │   ├── flash_fwd_sm120.py
│   │   ├── flash_fwd_sm90.py
│   │   ├── interface.py
│   │   ├── mask.py
│   │   ├── mma_sm100_desc.py
│   │   ├── named_barrier.py
│   │   ├── pack_gqa.py
│   │   ├── paged_kv.py
│   │   ├── pipeline.py
│   │   ├── pyproject.toml
│   │   ├── seqlen_info.py
│   │   ├── sm90_config_search.py
│   │   ├── softmax.py
│   │   ├── testing.py
│   │   ├── tile_scheduler.py
│   │   └── utils.py
│   ├── flash_attn_interface.py
│   ├── flash_attn_triton.py
│   ├── flash_attn_triton_og.py
│   ├── flash_blocksparse_attention.py
│   ├── flash_blocksparse_attn_interface.py
│   ├── layers/
│   │   ├── __init__.py
│   │   ├── patch_embed.py
│   │   └── rotary.py
│   ├── losses/
│   │   ├── __init__.py
│   │   └── cross_entropy.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── baichuan.py
│   │   ├── bert.py
│   │   ├── bigcode.py
│   │   ├── btlm.py
│   │   ├── falcon.py
│   │   ├── gpt.py
│   │   ├── gpt_neox.py
│   │   ├── gptj.py
│   │   ├── llama.py
│   │   ├── opt.py
│   │   └── vit.py
│   ├── modules/
│   │   ├── __init__.py
│   │   ├── block.py
│   │   ├── embedding.py
│   │   ├── mha.py
│   │   └── mlp.py
│   ├── ops/
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── fused_dense.py
│   │   ├── layer_norm.py
│   │   ├── rms_norm.py
│   │   └── triton/
│   │       ├── __init__.py
│   │       ├── cross_entropy.py
│   │       ├── k_activations.py
│   │       ├── layer_norm.py
│   │       ├── linear.py
│   │       ├── mlp.py
│   │       └── rotary.py
│   ├── pyproject.toml
│   └── utils/
│       ├── __init__.py
│       ├── benchmark.py
│       ├── distributed.py
│       ├── generation.py
│       ├── library.py
│       ├── pretrained.py
│       ├── testing.py
│       └── torch.py
├── hopper/
│   ├── __init__.py
│   ├── benchmark_attn.py
│   ├── benchmark_flash_attention_fp8.py
│   ├── benchmark_mla_decode.py
│   ├── benchmark_split_kv.py
│   ├── block.h
│   ├── copy_sm90_bulk_reduce.hpp
│   ├── cuda_check.h
│   ├── epilogue_bwd.hpp
│   ├── epilogue_fwd.hpp
│   ├── flash.h
│   ├── flash_api.cpp
│   ├── flash_api_stable.cpp
│   ├── flash_attn_interface.py
│   ├── flash_bwd_kernel_sm80.h
│   ├── flash_bwd_kernel_sm90.h
│   ├── flash_bwd_launch_template.h
│   ├── flash_bwd_postprocess_kernel.h
│   ├── flash_bwd_preprocess_kernel.h
│   ├── flash_fwd_combine.cu
│   ├── flash_fwd_combine_kernel.h
│   ├── flash_fwd_combine_launch_template.h
│   ├── flash_fwd_kernel_sm80.h
│   ├── flash_fwd_kernel_sm90.h
│   ├── flash_fwd_launch_template.h
│   ├── flash_prepare_scheduler.cu
│   ├── generate_kernels.py
│   ├── heuristics.h
│   ├── instantiations/
│   │   ├── flash_bwd_hdim128_bf16_sm80.cu
│   │   ├── flash_bwd_hdim128_bf16_sm90.cu
│   │   ├── flash_bwd_hdim128_bf16_softcap_sm80.cu
│   │   ├── flash_bwd_hdim128_bf16_softcap_sm90.cu
│   │   ├── flash_bwd_hdim128_bf16_softcapall_sm90.cu
│   │   ├── flash_bwd_hdim128_fp16_sm80.cu
│   │   ├── flash_bwd_hdim128_fp16_sm90.cu
│   │   ├── flash_bwd_hdim128_fp16_softcap_sm80.cu
│   │   ├── flash_bwd_hdim128_fp16_softcap_sm90.cu
│   │   ├── flash_bwd_hdim128_fp16_softcapall_sm90.cu
│   │   ├── flash_bwd_hdim192_bf16_sm80.cu
│   │   ├── flash_bwd_hdim192_bf16_sm90.cu
│   │   ├── flash_bwd_hdim192_bf16_softcap_sm80.cu
│   │   ├── flash_bwd_hdim192_bf16_softcap_sm90.cu
│   │   ├── flash_bwd_hdim192_bf16_softcapall_sm90.cu
│   │   ├── flash_bwd_hdim192_fp16_sm80.cu
│   │   ├── flash_bwd_hdim192_fp16_sm90.cu
│   │   ├── flash_bwd_hdim192_fp16_softcap_sm80.cu
│   │   ├── flash_bwd_hdim192_fp16_softcap_sm90.cu
│   │   ├── flash_bwd_hdim192_fp16_softcapall_sm90.cu
│   │   ├── flash_bwd_hdim256_bf16_sm80.cu
│   │   ├── flash_bwd_hdim256_bf16_sm90.cu
│   │   ├── flash_bwd_hdim256_bf16_softcap_sm80.cu
│   │   ├── flash_bwd_hdim256_bf16_softcap_sm90.cu
│   │   ├── flash_bwd_hdim256_bf16_softcapall_sm90.cu
│   │   ├── flash_bwd_hdim256_fp16_sm80.cu
│   │   ├── flash_bwd_hdim256_fp16_sm90.cu
│   │   ├── flash_bwd_hdim256_fp16_softcap_sm80.cu
│   │   ├── flash_bwd_hdim256_fp16_softcap_sm90.cu
│   │   ├── flash_bwd_hdim256_fp16_softcapall_sm90.cu
│   │   ├── flash_bwd_hdim64_bf16_sm80.cu
│   │   ├── flash_bwd_hdim64_bf16_sm90.cu
│   │   ├── flash_bwd_hdim64_bf16_softcap_sm80.cu
│   │   ├── flash_bwd_hdim64_bf16_softcap_sm90.cu
│   │   ├── flash_bwd_hdim64_bf16_softcapall_sm90.cu
│   │   ├── flash_bwd_hdim64_fp16_sm80.cu
│   │   ├── flash_bwd_hdim64_fp16_sm90.cu
│   │   ├── flash_bwd_hdim64_fp16_softcap_sm80.cu
│   │   ├── flash_bwd_hdim64_fp16_softcap_sm90.cu
│   │   ├── flash_bwd_hdim64_fp16_softcapall_sm90.cu
│   │   ├── flash_bwd_hdim96_bf16_sm80.cu
│   │   ├── flash_bwd_hdim96_bf16_sm90.cu
│   │   ├── flash_bwd_hdim96_bf16_softcap_sm80.cu
│   │   ├── flash_bwd_hdim96_bf16_softcap_sm90.cu
│   │   ├── flash_bwd_hdim96_bf16_softcapall_sm90.cu
│   │   ├── flash_bwd_hdim96_fp16_sm80.cu
│   │   ├── flash_bwd_hdim96_fp16_sm90.cu
│   │   ├── flash_bwd_hdim96_fp16_softcap_sm80.cu
│   │   ├── flash_bwd_hdim96_fp16_softcap_sm90.cu
│   │   ├── flash_bwd_hdim96_fp16_softcapall_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_paged_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_paged_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_paged_softcap_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_paged_split_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_sm100.cu
│   │   ├── flash_fwd_hdim128_bf16_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_softcap_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_split_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_split_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim128_bf16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_bf16_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim128_e4m3_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim128_e4m3_paged_sm90.cu
│   │   ├── flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_e4m3_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_e4m3_sm90.cu
│   │   ├── flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim128_e4m3_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_e4m3_split_sm90.cu
│   │   ├── flash_fwd_hdim128_e4m3_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_paged_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_paged_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_paged_softcap_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_paged_split_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_softcap_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_split_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_split_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim128_fp16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim128_fp16_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim192_128_bf16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_128_bf16_paged_sm90.cu
│   │   ├── flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_bf16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_bf16_sm90.cu
│   │   ├── flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_128_bf16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_bf16_split_sm90.cu
│   │   ├── flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_128_e4m3_paged_sm90.cu
│   │   ├── flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_e4m3_sm90.cu
│   │   ├── flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_128_e4m3_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_e4m3_split_sm90.cu
│   │   ├── flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_fp16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_128_fp16_paged_sm90.cu
│   │   ├── flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_fp16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_fp16_sm90.cu
│   │   ├── flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_128_fp16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_128_fp16_split_sm90.cu
│   │   ├── flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_paged_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_paged_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_paged_softcap_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_paged_split_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_softcap_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_split_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_split_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim192_bf16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_bf16_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim192_e4m3_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_e4m3_paged_sm90.cu
│   │   ├── flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_e4m3_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_e4m3_sm90.cu
│   │   ├── flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_e4m3_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_e4m3_split_sm90.cu
│   │   ├── flash_fwd_hdim192_e4m3_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_paged_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_paged_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_paged_softcap_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_paged_split_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_softcap_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_split_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_split_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim192_fp16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim192_fp16_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim256_bf16_paged_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_paged_sm90.cu
│   │   ├── flash_fwd_hdim256_bf16_paged_softcap_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_paged_split_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_sm90.cu
│   │   ├── flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim256_bf16_softcap_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_bf16_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_split_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_split_sm90.cu
│   │   ├── flash_fwd_hdim256_bf16_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim256_bf16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_bf16_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim256_e4m3_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim256_e4m3_paged_sm90.cu
│   │   ├── flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_e4m3_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_e4m3_sm90.cu
│   │   ├── flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim256_e4m3_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_e4m3_split_sm90.cu
│   │   ├── flash_fwd_hdim256_e4m3_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_paged_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_paged_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_paged_softcap_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_paged_split_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_softcap_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_split_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_split_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim256_fp16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim256_fp16_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim64_256_bf16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_256_bf16_paged_sm90.cu
│   │   ├── flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_256_bf16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_256_bf16_sm90.cu
│   │   ├── flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_256_bf16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_256_bf16_split_sm90.cu
│   │   ├── flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_256_fp16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_256_fp16_paged_sm90.cu
│   │   ├── flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_256_fp16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_256_fp16_sm90.cu
│   │   ├── flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_256_fp16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_256_fp16_split_sm90.cu
│   │   ├── flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_512_bf16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_512_bf16_paged_sm90.cu
│   │   ├── flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_512_bf16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_512_bf16_sm90.cu
│   │   ├── flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_512_bf16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_512_bf16_split_sm90.cu
│   │   ├── flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_512_fp16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_512_fp16_paged_sm90.cu
│   │   ├── flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_512_fp16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_512_fp16_sm90.cu
│   │   ├── flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_512_fp16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_512_fp16_split_sm90.cu
│   │   ├── flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_paged_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_paged_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_paged_softcap_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_paged_split_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_softcap_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_split_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_split_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim64_bf16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_bf16_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim64_e4m3_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_e4m3_paged_sm90.cu
│   │   ├── flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_e4m3_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_e4m3_sm90.cu
│   │   ├── flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_e4m3_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_e4m3_split_sm90.cu
│   │   ├── flash_fwd_hdim64_e4m3_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_paged_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_paged_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_paged_softcap_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_paged_split_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_softcap_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_split_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_split_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim64_fp16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim64_fp16_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim96_bf16_paged_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_paged_sm90.cu
│   │   ├── flash_fwd_hdim96_bf16_paged_softcap_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_paged_split_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_sm90.cu
│   │   ├── flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim96_bf16_softcap_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_bf16_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_split_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_split_sm90.cu
│   │   ├── flash_fwd_hdim96_bf16_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim96_bf16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_bf16_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim96_e4m3_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim96_e4m3_paged_sm90.cu
│   │   ├── flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_e4m3_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_e4m3_sm90.cu
│   │   ├── flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim96_e4m3_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_e4m3_split_sm90.cu
│   │   ├── flash_fwd_hdim96_e4m3_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_paged_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_paged_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_paged_softcap_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_paged_split_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_softcap_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_softcapall_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_split_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_split_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_split_softcap_sm80.cu
│   │   ├── flash_fwd_hdim96_fp16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdim96_fp16_split_softcapall_sm80.cu
│   │   ├── flash_fwd_hdimall_bf16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimall_bf16_paged_sm90.cu
│   │   ├── flash_fwd_hdimall_bf16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_bf16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_bf16_sm90.cu
│   │   ├── flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimall_bf16_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_bf16_split_sm90.cu
│   │   ├── flash_fwd_hdimall_bf16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_e4m3_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimall_e4m3_paged_sm90.cu
│   │   ├── flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_e4m3_paged_split_sm90.cu
│   │   ├── flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_e4m3_sm90.cu
│   │   ├── flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimall_e4m3_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_e4m3_split_sm90.cu
│   │   ├── flash_fwd_hdimall_e4m3_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_fp16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimall_fp16_paged_sm90.cu
│   │   ├── flash_fwd_hdimall_fp16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_fp16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_fp16_sm90.cu
│   │   ├── flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimall_fp16_softcap_sm90.cu
│   │   ├── flash_fwd_hdimall_fp16_split_sm90.cu
│   │   ├── flash_fwd_hdimall_fp16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_bf16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimdiff_bf16_paged_sm90.cu
│   │   ├── flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_bf16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_bf16_sm90.cu
│   │   ├── flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimdiff_bf16_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_bf16_split_sm90.cu
│   │   ├── flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimdiff_e4m3_paged_sm90.cu
│   │   ├── flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu
│   │   ├── flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_e4m3_sm90.cu
│   │   ├── flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimdiff_e4m3_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_e4m3_split_sm90.cu
│   │   ├── flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_fp16_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimdiff_fp16_paged_sm90.cu
│   │   ├── flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_fp16_paged_split_sm90.cu
│   │   ├── flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_fp16_sm90.cu
│   │   ├── flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu
│   │   ├── flash_fwd_hdimdiff_fp16_softcap_sm90.cu
│   │   ├── flash_fwd_hdimdiff_fp16_split_sm90.cu
│   │   └── flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu
│   ├── mainloop_bwd_sm80.hpp
│   ├── mainloop_bwd_sm90_tma_gmma_ws.hpp
│   ├── mainloop_fwd_sm80.hpp
│   ├── mainloop_fwd_sm90_tma_gmma_ws.hpp
│   ├── mask.h
│   ├── named_barrier.hpp
│   ├── pack_gqa.h
│   ├── padding.py
│   ├── paged_kv.h
│   ├── rotary.h
│   ├── seqlen.h
│   ├── setup.py
│   ├── sm90_pipeline_no_cluster.hpp
│   ├── softmax.h
│   ├── static_switch.h
│   ├── test_attn_kvcache.py
│   ├── test_flash_attn.py
│   ├── test_flash_attn_bwd_determinism.py
│   ├── test_flash_attn_triton_amd.py
│   ├── test_kvcache.py
│   ├── test_torch_compile_and_export.py
│   ├── test_util.py
│   ├── tile_scheduler.hpp
│   ├── tile_size.h
│   └── utils.h
├── setup.py
├── tests/
│   ├── cute/
│   │   ├── benchmark_block_sparsity.py
│   │   ├── benchmark_mask_mod.py
│   │   ├── conftest.py
│   │   ├── mask_mod_definitions.py
│   │   ├── score_mod_definitions.py
│   │   ├── test_block_sparsity.py
│   │   ├── test_flash_attn.py
│   │   ├── test_flash_attn_combine.py
│   │   ├── test_flash_attn_fast.py
│   │   ├── test_flash_attn_race_condition.py
│   │   ├── test_flash_attn_varlen.py
│   │   ├── test_mask_mod.py
│   │   ├── test_score_mod.py
│   │   ├── test_score_mod_varlen.py
│   │   └── test_utils.py
│   ├── layers/
│   │   └── test_rotary.py
│   ├── losses/
│   │   ├── test_cross_entropy.py
│   │   └── test_cross_entropy_parallel.py
│   ├── models/
│   │   ├── test_baichuan.py
│   │   ├── test_bert.py
│   │   ├── test_bigcode.py
│   │   ├── test_btlm.py
│   │   ├── test_falcon.py
│   │   ├── test_gpt.py
│   │   ├── test_gpt_generation_parallel.py
│   │   ├── test_gpt_neox.py
│   │   ├── test_gpt_parallel.py
│   │   ├── test_gptj.py
│   │   ├── test_llama.py
│   │   ├── test_opt.py
│   │   └── test_vit.py
│   ├── modules/
│   │   ├── test_block_parallel.py
│   │   ├── test_embedding_parallel.py
│   │   ├── test_mha_parallel.py
│   │   └── test_mlp_parallel.py
│   ├── ops/
│   │   ├── test_dropout_layer_norm.py
│   │   ├── test_fused_dense.py
│   │   ├── test_fused_dense_parallel.py
│   │   └── triton/
│   │       └── test_layer_norm.py
│   ├── pyproject.toml
│   ├── test_flash_attn.py
│   ├── test_flash_attn_ck.py
│   ├── test_flash_attn_triton_amd.py
│   ├── test_rotary.py
│   └── test_util.py
├── tools/
│   └── sass_diff.py
├── training/
│   ├── Dockerfile
│   ├── README.md
│   ├── configs/
│   │   ├── callbacks/
│   │   │   ├── causality-monitor.yaml
│   │   │   ├── default.yaml
│   │   │   ├── ema.yaml
│   │   │   ├── flop-count.yaml
│   │   │   ├── gpu-monitor.yaml
│   │   │   ├── model-summary.yaml
│   │   │   ├── none.yaml
│   │   │   ├── norm-monitor.yaml
│   │   │   ├── params-log.yaml
│   │   │   └── wandb.yaml
│   │   ├── config.yaml
│   │   ├── datamodule/
│   │   │   ├── openwebtext.yaml
│   │   │   └── thepile.yaml
│   │   ├── experiment/
│   │   │   ├── owt/
│   │   │   │   ├── base.yaml
│   │   │   │   ├── gpt2l-flash.yaml
│   │   │   │   ├── gpt2l-hf.yaml
│   │   │   │   ├── gpt2l.yaml
│   │   │   │   ├── gpt2m-flash.yaml
│   │   │   │   ├── gpt2m-hf.yaml
│   │   │   │   ├── gpt2m.yaml
│   │   │   │   ├── gpt2s-flash.yaml
│   │   │   │   ├── gpt2s-hf.yaml
│   │   │   │   ├── gpt2s.yaml
│   │   │   │   ├── gpt2xl-flash.yaml
│   │   │   │   ├── gpt2xl-hf.yaml
│   │   │   │   └── gpt2xl.yaml
│   │   │   └── pile/
│   │   │       ├── base.yaml
│   │   │       ├── gpt3-2.7B-flash-8k.yaml
│   │   │       ├── gpt3-2.7B-flash-hdim128-rotary-8k.yaml
│   │   │       ├── gpt3-2.7B-flash-hdim128-rotary.yaml
│   │   │       ├── gpt3-2.7B-flash-hdim128.yaml
│   │   │       ├── gpt3-2.7B-flash-rotary-8k.yaml
│   │   │       ├── gpt3-2.7B-flash-rotary.yaml
│   │   │       ├── gpt3-2.7B-flash.yaml
│   │   │       ├── gpt3-2.7B-hf-hdim128.yaml
│   │   │       ├── gpt3-2.7B-hf.yaml
│   │   │       ├── gpt3l-flash-8k.yaml
│   │   │       ├── gpt3l-flash-rotary-30B.yaml
│   │   │       ├── gpt3l-flash-rotary-8k.yaml
│   │   │       ├── gpt3l-flash-rotary.yaml
│   │   │       ├── gpt3l-flash.yaml
│   │   │       ├── gpt3l-hf.yaml
│   │   │       ├── gpt3m-flash-8k.yaml
│   │   │       ├── gpt3m-flash-rotary-30B.yaml
│   │   │       ├── gpt3m-flash-rotary-8k.yaml
│   │   │       ├── gpt3m-flash-rotary.yaml
│   │   │       ├── gpt3m-flash.yaml
│   │   │       ├── gpt3m-hf.yaml
│   │   │       ├── gpt3s-flash-8k.yaml
│   │   │       ├── gpt3s-flash-rotary-30B.yaml
│   │   │       ├── gpt3s-flash-rotary-8k.yaml
│   │   │       ├── gpt3s-flash-rotary.yaml
│   │   │       ├── gpt3s-flash.yaml
│   │   │       ├── gpt3s-hf.yaml
│   │   │       ├── gpt3xl-flash-8k.yaml
│   │   │       ├── gpt3xl-flash-rotary-60B.yaml
│   │   │       ├── gpt3xl-flash-rotary-8k.yaml
│   │   │       ├── gpt3xl-flash-rotary.yaml
│   │   │       ├── gpt3xl-flash.yaml
│   │   │       └── gpt3xl-hf.yaml
│   │   ├── logger/
│   │   │   ├── comet.yaml
│   │   │   ├── csv.yaml
│   │   │   ├── many_loggers.yaml
│   │   │   ├── mlflow.yaml
│   │   │   ├── neptune.yaml
│   │   │   ├── tensorboard.yaml
│   │   │   └── wandb.yaml
│   │   ├── metrics/
│   │   │   ├── acc.yaml
│   │   │   ├── acc_ignore_index.yaml
│   │   │   ├── acctop5.yaml
│   │   │   ├── mse.yaml
│   │   │   ├── num-tokens.yaml
│   │   │   └── perplexity.yaml
│   │   ├── mode/
│   │   │   ├── debug.yaml
│   │   │   ├── default.yaml
│   │   │   ├── exp.yaml
│   │   │   ├── profile.yaml
│   │   │   └── smoke.yaml
│   │   ├── model/
│   │   │   ├── gpt2-hf.yaml
│   │   │   ├── gpt2.yaml
│   │   │   └── gpt2model/
│   │   │       ├── gpt2-large.yaml
│   │   │       ├── gpt2-medium.yaml
│   │   │       ├── gpt2-small.yaml
│   │   │       └── gpt2-xlarge.yaml
│   │   ├── optimizer/
│   │   │   ├── adam.yaml
│   │   │   ├── adamw-apex-distributed.yaml
│   │   │   ├── adamw-apex-zero.yaml
│   │   │   ├── adamw-apex.yaml
│   │   │   ├── adamw-zero.yaml
│   │   │   ├── adamw.yaml
│   │   │   ├── fusedlamb-ds.yaml
│   │   │   ├── fusedlamb.yaml
│   │   │   └── sgd.yaml
│   │   ├── scheduler/
│   │   │   ├── cosine-warmup-timm.yaml
│   │   │   ├── cosine-warmup.yaml
│   │   │   ├── invsqrt.yaml
│   │   │   ├── linear-warmup.yaml
│   │   │   ├── multi-step.yaml
│   │   │   ├── plateau.yaml
│   │   │   ├── poly-warmup.yaml
│   │   │   └── step.yaml
│   │   ├── task/
│   │   │   └── sequence-model.yaml
│   │   └── trainer/
│   │       ├── all_params.yaml
│   │       ├── ddp.yaml
│   │       ├── debug.yaml
│   │       └── default.yaml
│   ├── run.py
│   ├── src/
│   │   ├── callbacks/
│   │   │   ├── __init__.py
│   │   │   ├── causality_monitor.py
│   │   │   ├── ema.py
│   │   │   ├── flop_count.py
│   │   │   ├── gpu_affinity.py
│   │   │   ├── loss_scale_monitor.py
│   │   │   ├── model_checkpoint.py
│   │   │   ├── norm_monitor.py
│   │   │   ├── params_log.py
│   │   │   ├── speed_monitor.py
│   │   │   └── wandb_callbacks.py
│   │   ├── datamodules/
│   │   │   ├── datasets/
│   │   │   │   ├──
detokenizer.py │ │ │ │ └── lm_dataset.py │ │ │ ├── fault_tolerant_sampler.py │ │ │ ├── imagenet.py │ │ │ ├── language_modeling_hf.py │ │ │ └── timm_mixup.py │ │ ├── distributed/ │ │ │ └── ddp_comm_hooks.py │ │ ├── eval.py │ │ ├── metrics/ │ │ │ ├── accuracy.py │ │ │ ├── num_tokens.py │ │ │ └── perplexity.py │ │ ├── models/ │ │ │ └── modules/ │ │ │ └── seq_common.py │ │ ├── optim/ │ │ │ ├── param_grouping.py │ │ │ └── timm_lr_scheduler.py │ │ ├── tasks/ │ │ │ └── seq.py │ │ ├── train.py │ │ └── utils/ │ │ ├── checkpoint.py │ │ ├── ddp_zero1.py │ │ ├── ddp_zero2.py │ │ ├── distributed.py │ │ ├── ema.py │ │ ├── flops.py │ │ ├── gpu_affinity.py │ │ └── utils.py │ └── tests/ │ └── datamodules/ │ └── test_language_modeling_hf.py └── usage.md

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/README.md
================================================
# GitHub Workflow Tagging Flow

This repository uses separate tag lanes so FA2 and FA4 publishing do not collide.

## Release lanes

| Tag pattern | Workflow | Package target | Version source |
| --- | --- | --- | --- |
| `v*` | `.github/workflows/publish.yml` | Root package (`flash-attn`) | Root package version metadata |
| `fa4-v*` | `.github/workflows/publish-fa4.yml` | `flash_attn/cute` package (`flash-attn-4`) | `setuptools-scm` with `fa4-v*` tags |

## How to publish

### FA2 / root package lane

1. Create a tag matching `v*` (example: `v2.9.0`).
2. Push that tag.
3. `publish.yml` creates a GitHub release, builds the wheel matrix and uploads the wheels to that release, then publishes the source distribution to PyPI.

### FA4 / CUTE package lane

1. Create a tag matching `fa4-v*` (example: `fa4-v0.1.0`).
2. Push that tag.
3. `publish-fa4.yml` builds from `flash_attn/cute`, creates a GitHub release, and uploads `flash-attn-4` to PyPI.

## Guardrails

- Do not use `v*` tags for FA4 releases.
- Do not use `fa4-v*` tags for FA2 releases.
- Keep `flash_attn/cute/pyproject.toml` tag parsing in sync with the FA4 tag prefix.
- The workflow filename (`publish-fa4.yml`) is part of the PyPI trusted publishing OIDC identity — do not rename without updating PyPI.

================================================
FILE: .github/workflows/_build.yml
================================================
name: ~Build wheel template

on:
  workflow_call:
    inputs:
      runs-on:
        description: "The runner to use for the build"
        required: true
        type: string
      python-version:
        description: "The Python version to use for the build"
        required: true
        type: string
      cuda-version:
        description: "The CUDA version to use for the build"
        required: true
        type: string
      torch-version:
        description: "The PyTorch version to use for the build"
        required: true
        type: string
      cxx11_abi:
        description: "The C++11 ABI to use for the build"
        required: true
        type: string
      upload-to-release:
        description: "Upload wheel to this release"
        required: false
        type: boolean
        default: false
      release-version:
        description: "Upload wheel to this release"
        required: false
        type: string

defaults:
  run:
    shell: bash -x -e -u -o pipefail {0}

jobs:
  build-wheel:
    runs-on: ${{ inputs.runs-on }}
    name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }})
    steps:
      - name: Checkout
        uses: actions/checkout@v5
        with:
          ref: ${{ inputs.release-version }}
          submodules: recursive

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ inputs.python-version }}

      - name: Set CUDA and PyTorch versions
        run: |
          echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
          echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
          echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
          echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV

      - name: Free up disk space
        if: ${{ runner.os == 'Linux' }}
        # https://github.com/easimon/maximize-build-space/blob/master/action.yml
        # https://github.com/easimon/maximize-build-space/tree/test-report
        run: |
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL

      - name: Set up swap space
        if: runner.os == 'Linux'
        uses: pierotofy/set-swap-space@v1.0
        with:
          swap-size-gb: 10

      - name: Install CUDA ${{ inputs.cuda-version }}
        if: ${{ inputs.cuda-version != 'cpu' }}
        uses: Jimver/cuda-toolkit@v0.2.30
        id: cuda-toolkit
        with:
          cuda: ${{ inputs.cuda-version }}
          linux-local-args: '["--toolkit"]'
          # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
          # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }}
          method: "network"
          sub-packages: '["nvcc"]'

      - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }}
        run: |
          pip install --upgrade pip
          # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
          # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
          pip install typing-extensions==4.12.2
          # Pick the highest available PyTorch wheel CUDA version that doesn't exceed system CUDA
          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
            available = { \
              '2.6': [118, 124, 126], \
              '2.7': [118, 126, 128], \
              '2.8': [126, 128, 129], \
              '2.9': [126, 128, 130], \
              '2.10': [126, 128, 130], \
            }[env['MATRIX_TORCH_VERSION']]; \
            sys_cuda = int(env['MATRIX_CUDA_VERSION']); \
            print(max(v for v in available if v <= sys_cuda))" \
          )
          # detect if we're on ARM
          if [ "$(uname -m)" = "aarch64" ] || [ "$(uname -m)" = "arm64" ]; then
            PLAT=linux_aarch64
          else
            PLAT=manylinux_2_27_x86_64.manylinux_2_28_x86_64
          fi
          echo "PLAT=$PLAT" >> $GITHUB_ENV
          if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then
            # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
            # Can't use --no-deps because we need cudnn etc.
            # Hard-coding this version of pytorch-triton for torch 2.9.0.dev20250904
            pip install jinja2
            TRITON_URL=https://download.pytorch.org/whl/nightly/pytorch_triton-3.4.0%2Bgitf7888497-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-${PLAT}.whl
            TORCH_URL=https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-manylinux_2_28_$(uname -m).whl
            pip install --no-cache-dir --pre "${TRITON_URL}"
            pip install --no-cache-dir --pre "${TORCH_URL}"
          else
            pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          fi
          nvcc --version
          python --version
          python -c "import torch; print('PyTorch:', torch.__version__)"
          python -c "import torch; print('CUDA:', torch.version.cuda)"
          python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"

      - name: Restore build cache
        uses: actions/cache/restore@v4
        with:
          path: build.tar
          key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}
          restore-keys: |
            build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-

      - name: Unpack build cache
        run: |
          echo ::group::Adjust timestamps
          sudo find / -exec touch -t 197001010000 {} + || true
          echo ::endgroup::

          if [ -f build.tar ]; then
            find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} +
            tar -xpvf build.tar -C .
          else
            echo "No build.tar found, skipping"
          fi

          ls -al ./
          ls -al build/ || true
          ls -al csrc/ || true

      - name: Build wheel
        id: build_wheel
        run: |
          # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
          # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
          # However this still fails so I'm using a newer version of setuptools
          pip install setuptools==75.8.0
          pip install ninja packaging wheel
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
          # Limit MAX_JOBS otherwise the github runner goes OOM
          # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM
          export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] || [ "$MATRIX_CUDA_VERSION" == "130" ] && echo 1 || echo 2)
          export NVCC_THREADS=2
          export FLASH_ATTENTION_FORCE_BUILD="TRUE"
          export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }}

          # 5h timeout since GH allows max 6h and we want some buffer
          EXIT_CODE=0
          timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?

          if [ $EXIT_CODE -eq 0 ]; then
            tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
            wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
            ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
            echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
          fi

          # Store exit code in GitHub env for later steps
          echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT"

          # Do not fail the job if timeout killed the build
          exit $EXIT_CODE

      - name: Log build logs after timeout
        if: always() && steps.build_wheel.outputs.build_exit_code == 124
        run: |
          ls -al ./
          tar -cvf build.tar . --atime-preserve=replace

      - name: Save build cache timeout
        if: always() && steps.build_wheel.outputs.build_exit_code == 124
        uses: actions/cache/save@v4
        with:
          key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}
          path: build.tar

      - name: Log Built Wheels
        run: |
          ls dist

      - name: Get Release with tag
        id: get_current_release
        uses: joutvhu/get-release@v1
        with:
          tag_name: ${{ inputs.release-version }}
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Upload Release Asset
        id: upload_release_asset
        if: inputs.upload-to-release
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ steps.get_current_release.outputs.upload_url }}
          asset_path: ./dist/${{env.wheel_name}}
          asset_name: ${{env.wheel_name}}
          asset_content_type: application/*

================================================
FILE: .github/workflows/build.yml
================================================
name: Build wheels

on:
  workflow_dispatch:
    inputs:
      runs-on:
        description: "The runner to use for the build"
        required: true
        type: string
        default: ubuntu-22.04
      python-version:
        description: "The Python version to use for the build"
        required: true
        type: string
      cuda-version:
        description: "The CUDA version to use for the build"
        required: true
        type: string
      torch-version:
        description: "The PyTorch version to use for the build"
        required: true
        type: string
      cxx11_abi:
        description: "Enable torch flag C++11 ABI (TRUE/FALSE)"
        required: true
        type: string
      upload-to-release:
        description: "Upload wheel to this release"
        required: false
        type: boolean
        default: false
      release-version:
        description: "Upload wheel to this release"
        required: false
        type: string

jobs:
  build-wheels:
    uses: ./.github/workflows/_build.yml
    with:
      runs-on: ${{ inputs.runs-on }}
      python-version: ${{ inputs.python-version }}
      cuda-version: ${{ inputs.cuda-version }}
      torch-version: ${{ inputs.torch-version }}
      cxx11_abi: ${{ inputs.cxx11_abi }}
      upload-to-release: ${{ inputs.upload-to-release }}
      release-version: ${{ inputs.release-version }}

================================================
FILE: .github/workflows/pre-commit.yaml
================================================
name: Lint

on:
  pull_request:
    paths:
      - 'flash_attn/cute/flash_bwd_sm90.py'
      - 'flash_attn/cute/flash_bwd_preprocess.py'
      - 'flash_attn/cute/flash_bwd_postprocess.py'
      - 'flash_attn/cute/softmax.py'
      - '.pre-commit-config.yaml'
  push:
    branches:
      - main
    paths:
      - 'flash_attn/cute/flash_bwd_sm90.py'
      - 'flash_attn/cute/flash_bwd_preprocess.py'
      - 'flash_attn/cute/flash_bwd_postprocess.py'
      - 'flash_attn/cute/softmax.py'
      - '.pre-commit-config.yaml'

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v5
      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: '3.11'
      - name: Run pre-commit
        uses: pre-commit/action@v3.0.1

================================================
FILE: .github/workflows/publish-fa4.yml
================================================
name: Publish flash-attn-4 to PyPI

on:
  push:
    tags:
      - 'fa4-v*'

permissions:
  contents: write

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install build dependencies
        run: pip install build twine

      - name: Build package
        run: python -m build
        working-directory: flash_attn/cute

      - name: Check package metadata
        run: twine check dist/*
        working-directory: flash_attn/cute

      - name: Store distribution packages
        uses: actions/upload-artifact@v4
        with:
          name: python-package-distributions
          path: flash_attn/cute/dist/

  github-release:
    needs: build
    runs-on: ubuntu-latest
    steps:
      - name: Download distribution packages
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/

      - name: Create GitHub Release
        uses: softprops/action-gh-release@v2
        with:
          files: dist/*
          generate_release_notes: true

  publish-to-pypi:
    needs: build
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/flash-attn-4
    permissions:
      id-token: write
    steps:
      - name: Download distribution packages
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/

      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1

================================================
FILE: .github/workflows/publish.yml
================================================
# This workflow will:
# - Create a new Github release
# - Build wheels for supported architectures
# - Deploy the wheels to the Github release
# - Release the static code to PyPi
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Build wheels and deploy

on:
  create:
    tags:
      - v*

jobs:

  setup_release:
    name: Create Release
    runs-on: ubuntu-latest
    outputs:
      release-version: ${{ steps.extract_branch.outputs.branch }}
    steps:
      - name: Get the tag version
        id: extract_branch
        run: echo "branch=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
        shell: bash
      - name: Create Release
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: gh release create ${{ steps.extract_branch.outputs.branch }} --repo $GITHUB_REPOSITORY --title ${{ steps.extract_branch.outputs.branch }} --generate-notes
        shell: bash

  build_wheels:
    name: Build Wheel
    needs: setup_release
    strategy:
      fail-fast: false
      matrix:
        # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the
        # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
        os: [ubuntu-22.04, ubuntu-22.04-arm]
        python-version: ["3.10", "3.11", "3.12", "3.13"]
        torch-version: ["2.6.0", "2.7.1", "2.8.0", "2.9.1", "2.10.0"]
        cuda-version: ["12.9.1", "13.0.1"]
        # We need separate wheels that either use C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
        # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
        # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
        # when building without C++11 ABI and using it on nvcr images.
        cxx11_abi: ["FALSE", "TRUE"]
        exclude:
          # CUDA 13.0 is only supported by PyTorch 2.9+
          - torch-version: "2.6.0"
            cuda-version: "13.0.1"
          - torch-version: "2.7.1"
            cuda-version: "13.0.1"
          - torch-version: "2.8.0"
            cuda-version: "13.0.1"
          # No aarch64 PyTorch wheels for 2.6.0
          - torch-version: "2.6.0"
            os: ubuntu-22.04-arm
          # PyTorch 2.7+ pip wheels use CXX11_ABI=1 by default, no need for FALSE
          - torch-version: "2.7.1"
            cxx11_abi: "FALSE"
          - torch-version: "2.8.0"
            cxx11_abi: "FALSE"
          - torch-version: "2.9.1"
            cxx11_abi: "FALSE"
          - torch-version: "2.10.0"
            cxx11_abi: "FALSE"
    uses: ./.github/workflows/_build.yml
    with:
      runs-on: ${{ matrix.os }}
      python-version: ${{ matrix.python-version }}
      cuda-version: ${{ matrix.cuda-version }}
      torch-version: ${{ matrix.torch-version }}
      cxx11_abi: ${{ matrix.cxx11_abi }}
      release-version: ${{ needs.setup_release.outputs.release-version }}
      upload-to-release: true

  publish_package:
    name: Publish package
    needs: [build_wheels]

    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v5

      - uses: actions/setup-python@v6
        with:
          python-version: "3.10"

      - name: Install dependencies
        run: |
          pip install ninja packaging wheel twine
          # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv)
          pip install setuptools==75.8.0
          # We don't want to download anything CUDA-related here
          pip install torch --index-url https://download.pytorch.org/whl/cpu

      - name: Build core package
        env:
          FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
        run: |
          python setup.py sdist --dist-dir=dist

      - name: Deploy
        env:
          TWINE_USERNAME: "__token__"
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          python -m twine upload dist/*

================================================
FILE: .gitignore
================================================
*.ncu-rep
*.sass
*.ptx
*.cubin
*.plk
.DS_store
.vscode

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
.eggs/

# IDE-related
.idea/
.vscode/

# Dev
venv

# compile-time generated file
flash_attn_config.py

================================================
FILE: .gitmodules
================================================
[submodule "csrc/cutlass"]
	path = csrc/cutlass
	url = https://github.com/NVIDIA/cutlass.git
[submodule "csrc/composable_kernel"]
	path = csrc/composable_kernel
	url = https://github.com/ROCm/composable_kernel.git
	branch = amd-master
[submodule "third_party/aiter"]
	path = third_party/aiter
	url = https://github.com/ROCm/aiter.git

================================================
FILE: .pre-commit-config.yaml
================================================
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.13
    hooks:
      - id: ruff-check
        args: [--fix, --exit-non-zero-on-fix]
        files: ^flash_attn/cute/.*\.py$
        exclude: &cute_exclude |
          (?x)^flash_attn/cute/(
            flash_bwd|
            flash_fwd|
            flash_fwd_sm100|
            interface|
          )\.py$
      - id: ruff-format
        files: ^flash_attn/cute/.*\.py$
        exclude: *cute_exclude

================================================
FILE: AI/DEBUG_2CTA.md
================================================
# Debugging GPU Kernel Hangs (Deadlocks) in CUTLASS DSL / 2CTA Kernels

## General Approach to Debugging Kernel Hangs

### Step 1: Build a minimal repro

Strip the test case down to the smallest input that triggers the hang:
- batch=1, nheads=1, smallest seqlen that hangs
- Single config, no loops, no benchmarking
- Add a timeout or run with `compute-sanitizer` so you can distinguish a hang from slow execution

### Step 2: Add printf to locate the hang

GPU `printf` (`cute.printf`) is the primary tool. The goal is binary search: narrow down which warp and which operation is blocked.
**Printf guards** — avoid print storms:

```python
import cutlass.cute as cute

# Inside a @cute.kernel function:
tidx = cute.arch.thread_idx()[0]

# One thread per warp:
if tidx % 32 == 0:
    cute.printf("...")

# One thread per CTA (elect_one is a context manager, not a bool):
with cute.arch.elect_one():
    cute.printf("...")

# One specific thread:
if tidx == 0:
    cute.printf("...")
```

**Strategy — coarse to fine:**

1. First, print at the entry/exit of each warp's main function (load, mma, softmax, correction). This tells you which warp is stuck.
2. Then add prints before/after each pipeline wait (`consumer_wait`, `producer_acquire`). This tells you which barrier is stuck.
3. Then print the barrier index, phase, and stage to understand the pipeline state.

**What to print:**

- CTA index (`cute.arch.block_idx()[0]`) — critical for multi-CTA debugging
- Pipeline stage index and phase
- Loop iteration count
- Whether a `try_wait` succeeds or fails (use the `try_wait_token` parameter)

### Step 3: Identify the deadlock chain

A hang is always a cycle. Typical chain in a pipelined kernel:

```
MMA waiting for K from load (pipeline_kv full barrier)
  -> Load finished but stuck in producer_tail (waiting for MMA to release empty barrier)
  -> MMA can't release because it's waiting for K
```

Once you see which barrier is stuck, trace backwards: who is supposed to signal it, and why hasn't it been signaled?

### Step 4: Vary the problem size systematically

Test with different sequence lengths / block counts to find the pattern:

| seqlen | n_blocks | Result |
|--------|----------|--------|
| 128 | 1 | ? |
| 256 | 2 | ? |
| 384 | 3 | ? |
| 512 | 4 | ? |

If the hang correlates with the number of visits to a pipeline stage (e.g., works for n_blocks <= kv_stages but fails when stages wrap around), the problem is likely in barrier tx_count or phase tracking.

### Step 5: Check barrier byte counts (tx_count)

For TMA-based pipelines, `arrive_and_expect_tx` sets the expected transaction byte count on an mbarrier.
If the expected count doesn't match the actual bytes arriving, the barrier either:
- Fires too early (expected < actual) — causes data races
- Never fires (expected > actual) — causes hangs

In **2CTA / cluster mode**, both CTAs' TMAs signal the **same** cluster-level mbarrier. If each CTA's TMA contributes N bytes, the barrier receives 2N bytes total. The tx_count must be `N * cta_group_size`, not just `N`.

**All TMA pipelines need doubling** — Q, K, and V. Even though each CTA loads a different M-tile for Q, both CTAs' TMA operations still signal the same cluster-level barrier, so the expected byte count must account for both.

### Step 6: Check phase / parity tracking

`mbarrier_try_wait_parity` uses a single parity bit (0 or 1). If your pipeline state tracks phase as a monotonically increasing counter (0, 1, 2, 3, ...), you need `phase % 2` before passing it to the barrier wait. Without this, phase=2 looks like phase=0 to the hardware, which can cause waits on already-completed barriers or misses on pending ones.

### Step 7: Beware compiler-as-bug-source

If the kernel works WITH printf but hangs WITHOUT it, the printf is acting as a **compiler barrier**. The MLIR/LLVM backend cannot optimize through an opaque function call like printf, which prevents harmful instruction reordering.
Signs this is happening:
- A single `cute.printf("\n")` in the right function fixes the hang
- PTX fences (`fence_view_async_shared`, `fence_acq_rel_cluster`, `sync_warp`, `fence_proxy`) do NOT fix it — these affect hardware memory ordering, not compiler scheduling
- The fix is location-sensitive (printf in one function fixes it, in another it doesn't)

Possible workarounds:
- `@dsl_user_op` decorator on pipeline methods to make them opaque to the compiler
- `asm volatile` barriers (if available in the DSL)
- Compare generated PTX/SASS with and without printf to identify what the compiler is reordering
- File a bug against the CUTLASS DSL / MLIR pipeline

---

## 2CTA-Specific Pitfalls

### tcgen05.commit with empty commit groups

`tcgen05.commit(mbar, mask, cta_group::2)` is supposed to signal an mbarrier after all pending MMA operations complete. But if there are **no pending operations** (empty commit group), the signal only reaches the local CTA's barrier, not the remote CTA's.

Fix: use an explicit `mbarrier_arrive(barrier, dst_cta_rank)` to arrive on the barrier in both CTAs.

### producer_tail deadlock

The default `producer_tail` (inherited from sm90 pipelines) drains the pipeline by calling `producer_acquire` in a loop. In 2CTA mode this deadlocks because the consumer (MMA warp) may have already exited without releasing all stages.

Fix: make `producer_tail` a no-op for 2CTA.

### Tile scheduler must account for cluster shape

Both CTAs in a cluster must get the **same** tile coordinate. Raw `blockIdx.x` assigns consecutive values to CTAs in the same cluster.

Fix: divide `blockIdx.x` by `cluster_shape_m`.

### Cross-CTA vs per-CTA pipelines

Pipelines where CTA 1's threads remotely arrive on CTA 0's barriers need cluster-sized cooperative group counts. Pipelines that are purely local to each CTA keep per-CTA counts.

### Softmax masking offset

Causal mask row positions must account for the CTA's position within the cluster. Multiply `m_block` by `cta_group_size` when computing mask coordinates.
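Three of the 2CTA rules above come down to integer arithmetic: the doubled TMA `tx_count` from Step 5, the cluster-aware tile coordinate, and the shifted causal-mask rows. That arithmetic can be sanity-checked on the host before it is buried in a kernel. The following is a minimal pure-Python sketch, not CuTe DSL code. Several things in it are assumptions for illustration:
- the helper names and the `tile_m` parameter;
- a 1D cluster along `x`, so that the CTA rank is `blockIdx.x % cluster_shape_m`;
- a layout in which CTA rank `r` owns M-tile `m_block * cta_group_size + r`.

```python
"""Host-side model of 2CTA index arithmetic (hypothetical helpers, not CuTe DSL)."""


def expected_tx_bytes(bytes_per_cta: int, cta_group_size: int) -> int:
    # Every CTA's TMA completes bytes on the same cluster-level mbarrier,
    # so the expected transaction count covers the whole CTA group.
    return bytes_per_cta * cta_group_size


def tile_coord(block_idx_x: int, cluster_shape_m: int) -> int:
    # CTAs of one cluster have consecutive blockIdx.x values; integer division
    # gives all of them the same tile coordinate.
    return block_idx_x // cluster_shape_m


def causal_row_start(m_block: int, cta_rank: int, cta_group_size: int, tile_m: int) -> int:
    # Assumed layout: CTA `cta_rank` owns M-tile m_block * cta_group_size + cta_rank.
    return (m_block * cta_group_size + cta_rank) * tile_m


def causal_masked(local_row: int, col: int, row_start: int) -> bool:
    # A key column is masked when it lies after the query row's global position.
    return col > row_start + local_row


if __name__ == "__main__":
    cluster_shape_m = cta_group_size = 2
    tile_m = 128
    # Example: a 128 x 128 tile of 2-byte (bf16) elements per CTA.
    print(expected_tx_bytes(128 * 128 * 2, cta_group_size))  # 65536
    for block_idx_x in range(4):
        m_block = tile_coord(block_idx_x, cluster_shape_m)
        cta_rank = block_idx_x % cluster_shape_m  # assumes a 1D cluster along x
        print(block_idx_x, m_block, causal_row_start(m_block, cta_rank, cta_group_size, tile_m))
```

Under the assumed layout, the loop prints `0 0 0`, `1 0 128`, `2 1 256`, `3 1 384`. Both CTAs of a cluster share `m_block`, and only the rank offset separates their mask rows. If the kernel's printed values disagree with a host model like this, the bug is more likely indexing than synchronization.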
================================================
FILE: AI/RACECHECK_TMA_HAZARD.md
================================================
# compute-sanitizer racecheck hazard with `cp.async.bulk`

## Summary

`compute-sanitizer --tool=racecheck` reports false-positive shared-memory race hazards when `cp.async.bulk` (raw-address TMA) is used in a cross-warp producer/consumer pipeline inside a dynamic loop. The same pattern with `cp.async.bulk.tensor` (descriptor-based TMA) reports **zero hazards**.

The fix for the flash backward kernel is to switch the LSE/dPsum copies from `CopyBulkG2SOp` (`cp.async.bulk`) to `CopyBulkTensorTileG2SOp` (`cp.async.bulk.tensor`) using `cpasync.make_tiled_tma_atom`.

## Affected code

`flash_attn/cute/flash_bwd_sm100.py` — the SM100 backward attention kernel.

Only **LSE** and **dPsum** buffers are affected because they are the only TMA-loaded buffers consumed by thread-level shared memory reads (`lds`). Q/K/V/dO are consumed by UMMA hardware instructions, which do not generate thread-level `lds` and therefore never trigger racecheck.

## Root cause

racecheck instruments every shared memory access and checks for conflicting accesses lacking a recognized happens-before relationship.

**`cp.async.bulk` (raw address):** the sanitizer attributes the smem write to the issuing thread (thread 0 of warp 0 via `elect_one`). When warp 1 issues `ld.shared.b32` from the same addresses, the sanitizer searches for a happens-before edge. The only sync is `mbarrier.try_wait.parity` on warp 1 paired with `mbarrier::complete_tx::bytes` completion from the hardware. The sanitizer does not model this as happens-before across warps in a dynamic loop.

**`cp.async.bulk.tensor` (TMA descriptor):** the TMA engine is a separate hardware unit. The sanitizer does not attribute the smem write to any thread. No writer thread means no hazard pair, so no race is reported.
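The attribution argument above can be illustrated with a toy model. This is a hypothetical sketch of the reasoning only, not how compute-sanitizer is implemented. Each shared-memory write records the warp it is attributed to, with `None` standing in for a write performed by the TMA unit with no issuing thread. A read conflicts with a write from a different warp only when no synchronization the tool recognizes separates them. Recognized synchronization is modeled as an `epoch` counter advanced by `bar.sync`. The mbarrier handshake is deliberately invisible to the model, mirroring the dynamic-loop behavior described above.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class Access:
    kind: str            # "W" for a shared-memory write, "R" for a read
    addr: int            # shared-memory byte address
    warp: Optional[int]  # None models a write with no issuing thread (TMA unit)
    epoch: int           # count of recognized syncs (e.g. bar.sync) seen so far


def hazards(trace: List[Access]) -> List[Tuple[Access, Access]]:
    """Toy race check: a read conflicts with an earlier write to the same address
    from another warp unless a recognized sync separates them."""
    found = []
    for i, read in enumerate(trace):
        if read.kind != "R":
            continue
        for write in trace[:i]:
            if write.kind != "W" or write.addr != read.addr:
                continue
            if write.warp is None:
                continue  # no attributed writer thread, so no hazard pair
            if write.warp != read.warp and write.epoch == read.epoch:
                found.append((write, read))
    return found


# cp.async.bulk: write attributed to warp 0, read by warp 1, only an mbarrier between them.
bulk = [Access("W", 0, 0, 0), Access("R", 0, 1, 0)]
# cp.async.bulk.tensor: same pattern, but the write has no attributed thread.
tensor = [Access("W", 0, None, 0), Access("R", 0, 1, 0)]
# bar.sync between producer and consumer advances the recognized-sync epoch.
named_barrier = [Access("W", 0, 0, 0), Access("R", 0, 1, 1)]

print(len(hazards(bulk)), len(hazards(tensor)), len(hazards(named_barrier)))  # 1 0 0
```

The three traces reproduce the pattern in the tables that follow:
- A thread-attributed write with only an unmodeled sync is flagged.
- A write with no attributed thread is never flagged.
- A recognized barrier clears the hazard.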
### Instruction comparison

| Variant | PTX | racecheck |
|---------|-----|-----------|
| Raw (cta scope) | `cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes` | **hazard** |
| Raw (cluster scope) | `cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes` | **hazard** |
| Descriptor 1D | `cp.async.bulk.tensor.1d.shared::cta.global.tile.mbarrier::complete_tx::bytes` | clean |
| Descriptor 2D | `cp.async.bulk.tensor.2d.shared::cta.global.tile.mbarrier::complete_tx::bytes` | clean |

### `--racecheck-memcpy-async=no` does not help

This flag controls the older `cp.async` (sm80) instruction family, not `cp.async.bulk`. The hazard persists with `--racecheck-memcpy-async=no`.

## Proof that it is a false positive

1. **Data correctness** — all variants produce bit-identical results.
2. **Single-warp test** — one warp does both TMA write and thread read in the same loop; racecheck reports zero hazards with the same mbarrier sync.
3. **Unrolled loop** — fully unrolling (`unroll_full=True`) reports zero hazards; racecheck tracks mbarrier within straight-line code but not across a dynamic branch back-edge between warps.
4. **Named barrier** — adding `bar.sync` per iteration between producer and consumer warps eliminates the hazard. The existing sync was already correct; racecheck just needs a primitive it recognizes.
5. **Descriptor TMA** — switching to `cp.async.bulk.tensor` with identical pipeline code eliminates the hazard; the mbarrier protocol is correct.

## Minimal reproducers

### `AI/` (preferred, cleaner)

| File | Copy instruction | Result |
|------|-----------------|--------|
| `racecheck_repro_1d_bulk.py` | `cp.async.bulk` (raw address) | **1 error** |
| `racecheck_repro_1d_tensor.py` | `cp.async.bulk.tensor.1d` (TMA descriptor) | **0 hazards** |

Both are ~75-line self-contained kernels: 2 warps, 4 blocks, 2-stage double buffering with `PipelineTmaAsync`. Identical pipeline protocol — only the copy instruction differs.
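In these reproducers, the four blocks cycle through two smem stages. Blocks 2 and 3 therefore overwrite buffers that the consumer warp read in blocks 0 and 1. That reuse across the loop back-edge is ordered only by the mbarrier handshake. The sketch below models just the `(stage, phase)` bookkeeping of such a two-stage pipeline. It assumes the common convention that the phase bit flips whenever the stage index wraps. `PipelineState` here is a hypothetical stand-in written for illustration, not the CUTLASS pipeline-state class.

```python
class PipelineState:
    """Hypothetical model of a circular-buffer pipeline index (stage, phase)."""

    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.index = 0  # smem stage currently in use
        self.phase = 0  # single parity bit, flipped on every wrap-around

    def advance(self) -> None:
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1  # stays a 0/1 bit, the form a parity wait expects


def schedule(num_blocks: int, num_stages: int) -> list:
    """Return (block, stage, consumer wait parity) for each loop iteration."""
    state = PipelineState(num_stages)
    out = []
    for block in range(num_blocks):
        out.append((block, state.index, state.phase))
        state.advance()
    return out


for block, stage, parity in schedule(num_blocks=4, num_stages=2):
    print(f"block {block}: stage {stage}, wait parity {parity}")
```

Blocks 2 and 3 land back on stages 0 and 1 with the parity flipped to 1. This is the revisit that only the mbarrier protects, and it is also where the `phase % 2` issue from DEBUG_2CTA.md (Step 6) would surface if the phase were kept as a raw counter.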
```bash
python AI/racecheck_repro_1d_bulk.py                                                         # correctness
CUTE_DSL_LINEINFO=1 compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_bulk.py  # 1 error
compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_tensor.py                    # 0 hazards
```

### `benchmarks/` (earlier, more variants)

| File | What it tests | Result |
|------|--------------|--------|
| `racecheck_false_positive_repro.py` | `cp.async.bulk` + mbarrier in cross-warp loop | 1 error |
| `racecheck_1d_raw_ptx.py` | Inline PTX `cp.async.bulk.shared::cta.global` | 1 error |
| `racecheck_tma2d_repro.py` | `cp.async.bulk.tensor.2d` via `make_tiled_tma_atom` | 0 hazards |
| `racecheck_tma1d_descriptor.py` | `cp.async.bulk.tensor.1d` via `make_tiled_tma_atom` | 0 hazards |

## PTX-level analysis

Dumped PTX for both `AI/` reproducers (`CUTE_DSL_KEEP_PTX=1`). Apart from incidental register numbering, the generated code is identical except for the single copy instruction:

```
# racecheck_repro_1d_bulk.py (HAZARD)
cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%r42], [%rd12], %r43, [%r6+-16];

# racecheck_repro_1d_tensor.py (CLEAN)
cp.async.bulk.tensor.1d.shared::cta.global.tile.mbarrier::complete_tx::bytes.L2::cache_hint [%r43], [%rd1, {%r71}], [%r6+-16], %rd8;
```

All mbarrier operations (init, `fence.mbarrier_init.release.cluster`, `arrive.expect_tx`, `try_wait.parity`, `arrive.release`, `fence.proxy.async.shared::cta`, `bar.warp.sync`) are identical.
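To see which stage and phase each block of the reproducers uses, the shared pipeline protocol (producer acquires a stage, the copy completes its bytes on that stage's mbarrier, the consumer waits, reads, and releases) can be modeled on the host. The sketch below is a hypothetical pure-Python model, not the CuTe DSL `PipelineTmaAsync` API: `PipeState` mimics an index/phase pair whose phase flips on wrap-around, and a boolean `full` flag per stage stands in for the mbarrier that `try_wait.parity` polls in hardware.

```python
from dataclasses import dataclass


@dataclass
class PipeState:
    """Hypothetical model of a circular-buffer pipeline state: a stage index
    plus a phase bit that flips every time the index wraps around."""
    num_stages: int
    index: int = 0
    phase: int = 0

    def advance(self) -> None:
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1


def schedule(num_blocks: int, num_stages: int):
    """(block, stage index, phase) seen by each loop iteration."""
    st = PipeState(num_stages)
    out = []
    for blk in range(num_blocks):
        out.append((blk, st.index, st.phase))
        st.advance()
    return out


def simulate(src, tile, num_stages):
    """Sequential model of the producer/consumer handshake.

    full[s] plays the role of stage s's 'data ready' barrier: the producer only
    refills a stage the consumer has released, and the consumer only reads a
    stage the producer has filled.
    """
    num_blocks = len(src) // tile
    smem = [[None] * tile for _ in range(num_stages)]
    full = [False] * num_stages
    dst = [None] * len(src)
    produced = consumed = 0
    prod, cons = PipeState(num_stages), PipeState(num_stages)
    while consumed < num_blocks:
        # Producer runs ahead while its next stage is free (producer_acquire).
        if produced < num_blocks and not full[prod.index]:
            smem[prod.index] = src[produced * tile:(produced + 1) * tile]
            full[prod.index] = True          # copy completed: stage readable
            produced += 1
            prod.advance()
            continue
        # Consumer waits on its stage, copies it out, then releases it.
        s = cons.index
        assert full[s], "consumer would read an unfilled stage"
        dst[consumed * tile:(consumed + 1) * tile] = smem[s]
        full[s] = False                      # consumer_release
        consumed += 1
        cons.advance()
    return dst
```

With `N_STG = 2` and `N_BLKS = 4`, `schedule(4, 2)` returns `[(0, 0, 0), (1, 1, 0), (2, 0, 1), (3, 1, 1)]`: blocks 0 and 2 reuse stage 0 with opposite phases, and blocks 1 and 3 reuse stage 1. That cross-iteration reuse of a stage by a different warp is the situation the racecheck report concerns.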
### racecheck error output

```
Error: Race reported between Write access at ...+0x430 in racecheck_repro_1d_bulk.py:46
    and Read access at ...+0x770 in racecheck_repro_1d_bulk.py:55 [248 hazards]
    and Read access at ...+0x7a0 in racecheck_repro_1d_bulk.py:55 [248 hazards]
    and Read access at ...+0x7d0 in racecheck_repro_1d_bulk.py:55 [248 hazards]
    and Read access at ...+0x800 in racecheck_repro_1d_bulk.py:55 [248 hazards]
```

- **Write** (0x430) = line 46: `cute.copy(atom, src, s, mbar_ptr=...)` — the `cp.async.bulk` instruction
- **Read** (0x770–0x800) = line 55: `dst[...] = s[...]` — four `ld.shared.b32` in the consumer warp

## Fix

Change `copy_stats` in the load function from:

```python
copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32)
copy_stats = partial(cute.copy, copy_atom_stats)
```

to a descriptor-based TMA using `cpasync.make_tiled_tma_atom` with `CopyBulkTensorTileG2SOp`. This generates `cp.async.bulk.tensor.1d` instead of `cp.async.bulk`, which racecheck does not instrument. The pipeline protocol (mbarrier init, arrive_expect_tx, try_wait_parity, consumer_release) remains identical.

## Backup

`flash_attn/cute/flash_bwd_sm100_gmem_fix.py` contains a working but slower fix where compute warps read LSE/dPsum directly from global memory, bypassing the TMA smem pipeline entirely.

## Investigation timeline

1. Observed 2 racecheck errors on LSE and dPsum in `flash_bwd_sm100.py`. Q/K/V/dO clean.
2. Noticed Q/K/V/dO use UMMA consumers (no thread `lds`) while LSE/dPsum use thread-level `autovec_copy` from smem — explains why only LSE/dPsum trigger.
3. Built minimal 2-warp pipeline kernel reproducing the hazard.
4. Single-warp version clean — same mbarrier, same addresses.
5. Fully-unrolled version clean — racecheck tracks mbarrier within straight-line code.
6. `bar.sync` per iteration fixes it — racecheck needs a sync it recognizes across the loop back-edge.
7. `cp.async.bulk.tensor.2d` clean — different instruction, same pipeline.
8. `cp.async.bulk.tensor.1d` clean — issue is raw vs descriptor, not dimensionality.
9. Raw inline PTX `cp.async.bulk.shared::cta.global` also triggers — not a CuTe DSL abstraction issue.
10. Dumped PTX for both `AI/` reproducers — confirmed byte-identical code except for the copy instruction. Sanitizer attributes smem write to issuing thread for `cp.async.bulk` but not for `cp.async.bulk.tensor`.
11. Confirmed `--racecheck-memcpy-async=no` does not suppress the hazard — flag targets older `cp.async`, not `cp.async.bulk`.

================================================
FILE: AI/SM90_BLOCK_SIZE_TUNING.md
================================================
# SM90 Block Size Tuning Guide

How to choose tile sizes and MMA configurations for FlashAttention on Hopper (SM90).

## Tool

Use `flash_attn/cute/sm90_config_search.py` to enumerate feasible configs:

```bash
# Both fwd and bwd
python flash_attn/cute/sm90_config_search.py --headdim 128

# Forward only
python flash_attn/cute/sm90_config_search.py --mode fwd --headdim 192-128

# Backward only, custom tile choices
python flash_attn/cute/sm90_config_search.py --mode bwd --headdim 192 --tile-m 64,80 --tile-n 64,96
```

## Hardware Constraints (H100)

- **SMEM**: 228 KB total. We reserve ~3 KB for LSE, dPsum, and mbarriers, leaving **224 KB** for tensor buffers.
- **Registers**: Controlled via `setmaxnreg`. Budget per MMA warp group:
  - 2 WG: 240 regs/thread, minus 24 overhead = **216 usable**
  - 3 WG: 160 regs/thread, minus 32 overhead = **128 usable**
- **GMMA atom**: Always M=64. The effective M dimension (after swap) must be divisible by 64. N dimension must be divisible by `atom_layout_n * 8`.

## Architecture: Warp Groups

Each SM90 backward kernel has `num_wg + 1` warp groups (128 threads each):
- **WG0** (producer): TMA loads for Q, K, V, dO, LSE, dPsum
- **WG1** (producer): dQaccum store (TMA reduce-add to gmem)
- **WG2..WG(num_wg)** (MMA consumers): All GEMMs

For forward: `num_wg` MMA WGs + 1 producer WG.
`tile_m = num_wg * 64` (no swap).

## Key Decisions

### 1. Number of Warp Groups (num_wg)

| num_wg | tile_m (fwd) | Threads | Reg budget | Best for |
|--------|-------------|---------|------------|----------|
| 2 | 128 | 384 | 216/thread | hdim <= 128 |
| 3 | 192 | 512 | 128/thread | hdim 129-192 |

More WGs = larger tile_m = better M-direction parallelism, but tighter register budget and higher smem usage.

### 2. swap_AB

Each MMA can optionally swap its A and B operands. This transposes the output tile, exchanging which dimension maps to M (must be divisible by 64) and which maps to N.

**When to swap:**
- If the natural M dimension isn't divisible by 64 but N is (e.g., tile_m=80 for SdP)
- To change which operand is in registers vs shared memory

**Forward**: No swap needed since tile_m = num_wg * 64 is always divisible by 64.

**Backward** (5 MMAs):
- **SdP** (S=Q@K^T, dP=dO@V^T): output (tile_m, tile_n). Swap if tile_m % 64 != 0.
- **dKV** (dK=dS^T@Q, dV=P^T@dO): output (tile_n, hdim/hdimv). Swap if tile_n % 64 != 0 but hdim % 64 == 0.
- **dQ** (dQ=dS@K): output (tile_m, hdim). Swap if tile_m % 64 != 0 but hdim % 64 == 0.

### 3. AtomLayout

The `atom_layout` distributes WGs across the M and N dimensions of an MMA output. With `num_wg` MMA WGs and `atom_layout_m = A`:
- M direction: A warp groups, each handling M/A rows
- N direction: num_wg/A warp groups, each handling N/(num_wg/A) columns

After swap, the atom layout is also swapped.

**Impact on smem traffic**: More WGs in the N direction (`wg_n` larger) means each instruction reads a smaller B slice, but more instructions in total read overlapping A slices. Fewer WGs in N (`wg_n` smaller) means fewer instructions, but each reads a larger B slice. Typically **smaller wg_n = less total smem traffic**.

### 4. mma_dkv_is_rs (Register-Source for dKV)

When `AtomLayoutMSdP == 1 && AtomLayoutNdKV == num_wg && SdP_swapAB && !dKV_swapAB`, the P and dS matrices can be kept in registers and fed directly as the A operand of the dV and dK GEMMs. This:
- **Eliminates sP from smem** (saves tile_m * tile_n * 2 bytes)
- **Eliminates P R2S store** from smem traffic
- **Eliminates A operand reads** for dK and dV GEMMs

This is a significant optimization — always preferred when the conditions are met.

### 5. Pipeline Staging

**Forward**:
- Q: 1 stage (loaded once per m_block tile)
- K, V: 2 stages (double-buffered, pipelined with TMA)
- O: overlaps with Q in smem (reuses same buffer at epilogue)

**Backward**:
- Q: always 2 stages (double-buffered)
- dO: 2 stages if smem allows (matches Q pipeline), else 1 stage
- PdS: 1 stage
- K, V: persistent in smem (loaded once per n_block)

## Register Accounting

Accumulator (f32) registers per thread = `M * N / (num_wg * 128)`, where M x N is the output tile, split across the `num_wg` MMA WGs of 128 threads each.

**Forward peak registers**:
- With WG overlap: `regs_S + regs_P + regs_O` (S, P in bf16, and O all live)
- Without overlap: `regs_S + regs_O` (S and O alternate, P reuses S regs)

Where `regs_P = regs_S / 2` (bf16 vs f32).
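As a sanity check, the forward formulas can be evaluated for the hdim=128 forward config listed under Example Configs (tile_m=128, tile_n=192, 2 WG). This is a minimal Python sketch; the helper names are illustrative and simply transcribe the formulas above plus the register budgets from Hardware Constraints.

```python
REG_BUDGET = {2: 240, 3: 160}      # setmaxnreg per MMA thread, by num_wg
REG_OVERHEAD = {2: 24, 3: 32}
USABLE_REGS = {wg: REG_BUDGET[wg] - REG_OVERHEAD[wg] for wg in REG_BUDGET}


def acc_regs(m, n, num_wg):
    """f32 accumulator registers per thread for an m x n output tile
    split across num_wg MMA warp groups of 128 threads each."""
    threads = num_wg * 128
    assert (m * n) % threads == 0, "tile does not divide evenly across threads"
    return m * n // threads


def fwd_peak_regs(tile_m, tile_n, hdimv, num_wg, wg_overlap=True):
    regs_s = acc_regs(tile_m, tile_n, num_wg)
    regs_p = regs_s // 2               # P held as bf16: half the f32 footprint
    regs_o = acc_regs(tile_m, hdimv, num_wg)
    if wg_overlap:
        return regs_s + regs_p + regs_o   # S, P and O all live
    return regs_s + regs_o                # P reuses the S registers
```

For that config, `regs_S = 96`, `regs_P = 48` and `regs_O = 64`, so the overlapped peak is 208 registers, which fits under the 216 usable with 2 WG.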
**Backward peak registers**:
- `max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV`
- S and dP accumulators are both live (S needed for softmax while dP computes)
- dQ reuses S+dP register space after they're consumed
- dK and dV accumulate across m_block iterations

## SMEM Accounting

Sum of tensor buffers in bytes (ignoring alignment padding, which is small). Buffer sizes below are per stage; the formulas multiply by the stage count.

**Forward**: `max(sQ, sO) + sK*2 + sV*2 + sP`
- sQ = tile_m * hdim * 2
- sK = tile_n * hdim * 2 (per stage, 2 stages)
- sV = tile_n * hdimv * 2 (per stage, 2 stages)
- sO = tile_m * hdimv * 2 (overlaps with sQ)
- sP = tile_m * tile_n * 2 (0 if RS)

**Backward**: `sQ*2 + sK + sV + sdO*dO_stage + sP + sdS + sdQaccum`
- sQ = tile_m * hdim * 2 (per stage, 2 stages)
- sK = tile_n * hdim * 2
- sV = tile_n * hdimv * 2
- sdO = tile_m * hdimv * 2 (per stage, dO_stage stages)
- sP = tile_m * tile_n * 2 (0 if mma_dkv_is_rs)
- sdS = tile_m * tile_n * 2
- sdQaccum = tile_m * hdim * 4 (f32)

## SMEM Traffic

Per-iteration smem bandwidth consumed. Each GMMA instruction reads:
- **A operand**: 64 * K_red * 2 bytes (0 if register-source)
- **B operand**: (N_eff / wg_n) * K_red * 2 bytes

Total instructions = (M_eff / 64) * wg_n. Each instruction independently reads A and B from smem.

Additional traffic: R2S stores for P, dS (bf16), dQ smem store + TMA load (f32).

**Traffic per block** (traffic / (tile_m * tile_n)) normalizes across tile sizes for comparison. Lower is better.

## Example Configs

### hdim=128 (Forward)

Best: tile_m=128, tile_n=192, RS, 2 WG. 224K smem, 9.3 tr/blk.

### hdim=128 (Backward, non-causal)

C++ FA3 config: tile_m=80, tile_n=128, SdP_swap=T, dKV_swap=F, dQ_swap=T, aSdP=1, adKV=2. mma_dkv_is_rs=True. 204K smem, 208 regs, 39.6 tr/blk.

### hdim=192 (Backward)

3 WG, tile_m=64, tile_n=96, SdP_swap=F, dKV_swap=T, adKV=1 or 3. 216K smem, 128 regs. This is the only feasible tile_n > 64 for hdim=192 due to register pressure.

### hdim=192, hdimv=128 (DeepSeek shape)

With 3 WG: need AtomLayoutNdKV=3 (since hdimv=128 not divisible by 3). tile_n=96, 212K smem.
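The SMEM and backward-register formulas can be checked against the example configs the same way. This sketch transcribes them into Python (bytes, 2 per fp16/bf16 element, with `sQ`, `sK`, `sV`, `sdO` computed per stage). A few settings are assumptions not stated in the configs, chosen because they reproduce the quoted totals: the hdim=192 config uses `dO_stage=1` with non-RS dKV (it has `dKV_swap=T`), and the DeepSeek 3 WG shape uses tile_m=64, `dO_stage=2` and non-RS dKV.

```python
KB = 1024


def fwd_smem(tile_m, tile_n, hdim, hdimv, rs=True, kv_stages=2):
    """Forward: max(sQ, sO) + sK*2 + sV*2 + sP, in bytes."""
    sQ = tile_m * hdim * 2
    sO = tile_m * hdimv * 2              # overlaps with sQ
    sK = tile_n * hdim * 2               # per stage
    sV = tile_n * hdimv * 2              # per stage
    sP = 0 if rs else tile_m * tile_n * 2
    return max(sQ, sO) + sK * kv_stages + sV * kv_stages + sP


def bwd_smem(tile_m, tile_n, hdim, hdimv, dkv_is_rs, dO_stage=2):
    """Backward: sQ*2 + sK + sV + sdO*dO_stage + sP + sdS + sdQaccum, in bytes."""
    sQ = tile_m * hdim * 2               # per stage, always 2 stages
    sK = tile_n * hdim * 2
    sV = tile_n * hdimv * 2
    sdO = tile_m * hdimv * 2             # per stage
    sP = 0 if dkv_is_rs else tile_m * tile_n * 2
    sdS = tile_m * tile_n * 2
    sdQaccum = tile_m * hdim * 4         # f32 accumulator
    return sQ * 2 + sK + sV + sdO * dO_stage + sP + sdS + sdQaccum


def bwd_peak_regs(tile_m, tile_n, hdim, hdimv, num_wg):
    """max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV, per thread."""
    threads = num_wg * 128

    def regs(m, n):
        return m * n // threads

    regs_sdp = regs(tile_m, tile_n)
    regs_dq = regs(tile_m, hdim)
    regs_dk = regs(tile_n, hdim)
    regs_dv = regs(tile_n, hdimv)
    return max(2 * regs_sdp, regs_dq) + regs_dk + regs_dv
```

For example, `bwd_smem(80, 128, 128, 128, dkv_is_rs=True) // KB` gives 204 and `bwd_peak_regs(80, 128, 128, 128, num_wg=2)` gives 208, matching the hdim=128 backward config.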
With 2 WG: tile_n=112 feasible at 210K smem, or tile_n=64 at 168K smem.

================================================
FILE: AI/SM90_R2P_MASKING_SASS.md
================================================
# SM90 FWD R2P Masking — SASS Investigation

## SASS Instruction Counts (hdim=128, seqlen=113, tile_n=128)

With tile_n=128, SM90 has 32 accumulator elements per row (1 chunk of 32).

### Non-causal (seqlen-only masking)

| Metric | Old (no R2P) | New (R2P) | Delta |
|--------|-------------|-----------|-------|
| **Total instructions** | 3104 | 3072 | **-32 (-1%)** |
| R2P | 0 | 4 | +4 |
| FSEL | 70 | 70 | 0 |
| ISETP | 55 | 22 | **-33** |
| SHF | 69 | 73 | +4 |
| LOP3 | 51 | 56 | +5 |

R2P replaces 33 ISETP (integer set-predicate) instructions with 4 R2P + a few LOP3/SHF. Net savings: 32 instructions. The 4 R2P instructions each convert the low 7 bits of one byte of a 32-bit bitmask into predicates P0-P6, covering 28 of the 32 elements (4 × 7 bits); the top bit of each byte is tested separately (see "Handling the leftover bits" below).

### Causal

| Metric | Old (no R2P) | New (R2P) | Delta |
|--------|-------------|-----------|-------|
| **Total instructions** | 5008 | 4857 | **-151 (-3%)** |
| R2P | 0 | 24 | +24 |
| FSEL | 200 | 200 | 0 |
| ISETP | 225 | 22 | **-203** |
| SHF | 104 | 105 | +1 |
| LOP3 | 81 | 105 | +24 |

Much larger savings. The causal kernel applies masking per row (each row has a different col_limit), so it has many more masking operations. 24 R2P instructions replace 203 ISETP instructions, saving 151 total.

### Local (sliding window, wl=64 wr=0)

| Metric | Old (no R2P) | New (R2P) | Delta |
|--------|-------------|-----------|-------|
| **Total instructions** | 7296 | 6217 | **-1079 (-15%)** |
| R2P | 0 | 32 | +32 |
| FSEL | 522 | 266 | **-256** |
| ISETP | 554 | 22 | **-532** |
| SHF | 115 | 73 | -42 |
| LOP3 | 96 | 56 | -40 |

Dramatic savings. Local masking has two bounds (left + right) per row, doubling the masking work. R2P eliminates 532 ISETP and 256 FSEL instructions, saving 1079 total (15% of kernel).
## How R2P Works in SASS

The compiler generates this pattern:

```
SHF.R.U32.HI R9, RZ, R9, R16     ; shift to create bitmask
R2P PR, R9, 0x7f                 ; byte 0 → predicates P0-P6
FSEL R15, R36, -INF, P6          ; apply P6: keep or mask to -inf
R2P PR, R9.B1, 0x7f              ; byte 1 → predicates P0-P6
FSEL R52, R52, -INF, P6          ; apply P6
R2P PR, R9.B2, 0x7f              ; byte 2
...
R2P PR, R9.B3, 0x7f              ; byte 3
```

Each `R2P` converts 7 bits of a register byte into 7 predicate registers simultaneously (1 instruction instead of 7 `ISETP`). The subsequent `FSEL` instructions use these predicates for conditional masking.

### Handling the leftover bits (32 is not divisible by 7)

The `0x7f` immediate tells R2P to map bits 0-6 of each byte to P0-P6, but bit 7 (the MSB of each byte) is not covered. For 32 elements across 4 bytes, that's 4 leftover elements (bits 7, 15, 23, 31). The compiler handles these with separate `LOP3.LUT` or `ISETP` instructions:

```
R2P PR, R12, 0x7f                    ; bits 0-6 → P0-P6 (7 elements)
14× FSEL using P0-P6                 ; apply to 7 cols × 2 rows
LOP3.LUT P0, RZ, R12, 0x80, ...      ; test bit 7 (1 element)
2× FSEL using P0
R2P PR, R12.B1, 0x7f                 ; bits 8-14 → P0-P6 (7 elements)
14× FSEL using P0-P6
LOP3.LUT P1, RZ, R12, 0x8000, ..     ; test bit 15 (1 element)
2× FSEL using P1
R2P PR, R12.B2, 0x7f                 ; bits 16-22 → P0-P6 (7 elements)
14× FSEL using P0-P6
LOP3.LUT P0, RZ, R12, 0x800000,..    ; test bit 23 (1 element)
2× FSEL using P0
R2P PR, R12.B3, 0x7f                 ; bits 24-30 → P0-P6 (7 elements)
14× FSEL using P0-P6
ISETP.GT P0, R12, -1                 ; test bit 31 (sign bit) (1 element)
2× FSEL using P0
```

Total: 4×7 = 28 elements via R2P + 4 elements via LOP3/ISETP = 32. Each R2P replaces 7 ISETP with 1 instruction, so net savings is `(7-1) × 4 = 24` predicate-generation instructions per mask application. Additionally, ptxas can overlap R2P with FSEL since they write to separate predicate registers.
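The bit bookkeeping can be reproduced with a small host-side model. This sketch assumes a simplified mapping in which bit `j` of the 32-bit mask corresponds to accumulator element `j` and a set bit means keep (as the `FSEL ..., -INF, P6` comment "keep or mask to -inf" suggests); the real SM90 register-to-column mapping differs. `make_mask` is a hypothetical helper that builds a prefix mask for a column limit. The decoder reads bits 0-6 of each byte as `R2P ..., 0x7f` does and tests bit 7 separately, confirming that 4 R2P plus 4 extra tests cover all 32 elements.

```python
import math


def make_mask(col_limit: int, n: int = 32) -> int:
    """Hypothetical prefix mask: bit j set <=> element j is kept (j < col_limit)."""
    col_limit = max(0, min(col_limit, n))
    return (1 << col_limit) - 1


def r2p_decode(mask: int):
    """Model 4x 'R2P PR, Rx.Bk, 0x7f' plus one leftover-bit test per byte.

    Returns the 32 keep-predicates and the number of R2P / extra tests used.
    """
    preds = [False] * 32
    n_r2p = n_extra = 0
    for k in range(4):
        byte = (mask >> (8 * k)) & 0xFF
        low7 = byte & 0x7F                       # R2P with immediate 0x7f: bits 0-6
        n_r2p += 1
        for b in range(7):
            preds[8 * k + b] = bool((low7 >> b) & 1)
        preds[8 * k + 7] = bool(byte & 0x80)     # LOP3/ISETP on the byte's MSB
        n_extra += 1
    return preds, n_r2p, n_extra


def apply_mask(row, mask):
    """FSEL-style select: keep the score where the predicate is set, else -inf."""
    preds, _, _ = r2p_decode(mask)
    return [x if keep else -math.inf for x, keep in zip(row, preds)]
```

`apply_mask([1.0] * 32, make_mask(3))` keeps the first three scores and sets the remaining 29 to `-inf`.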
## Performance Impact

| Case | Old (ms) | New (ms) | Speedup |
|------|----------|----------|---------|
| Causal hdim=64 s=8192 | 2.463 | 2.473 | ~0% |
| Causal hdim=128 s=8192 | 1.937 | 1.944 | ~0% |
| Local hdim=64 s=8192 | 0.394 | 0.346 | **+14%** |
| Local hdim=128 s=8192 | 0.237 | 0.222 | **+7%** |
| Non-causal hdim=128 s=4096 | 1.742 | 1.728 | ~1% |

Causal sees no perf gain despite fewer instructions because masking is a tiny fraction of total work (dominated by WGMMA). Local sees significant gains because the sliding window has many partially-masked blocks where masking overhead matters more.

================================================
FILE: AI/VARLEN_PREPROCESS_TILE_BUG.md
================================================
# Varlen Preprocess Tile Mismatch Bug

## Summary

`SeqlenInfo.create` in `flash_bwd_preprocess.py` defaulted to `tile=128`, but the backward kernel uses `tile_m=m_block_size` (e.g. 64 for causal SM90). This caused the preprocess to zero dq_accum and write lse_log2/dpsum at wrong padded offsets for all batches after batch 0.

## How padded_offset works

For varlen, buffers like dq_accum are laid out with tile-aligned gaps between sequences:

```
padded_offset_q = ((offset_q + batch_idx * tile_m) // tile_m) * tile_m
```

The gap size depends on `tile_m`. With `tile_m=64` vs `tile_m=128`, batch 1 at `offset_q=128` gets:
- tile=64: padded_offset = ((128 + 64) // 64) * 64 = **192**
- tile=128: padded_offset = ((128 + 128) // 128) * 128 = **256**

The preprocess was zeroing at 256, while the backward kernel was writing at 192.
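The offset arithmetic is easy to reproduce on the host. This minimal Python sketch transcribes the `padded_offset_q` formula above and evaluates it for a hypothetical `cu_seqlens` with both tile sizes; `padded_starts` is an illustrative helper, not a function from the repository.

```python
def padded_offset(offset_q: int, batch_idx: int, tile_m: int) -> int:
    """Tile-aligned start of sequence `batch_idx` in a padded varlen buffer."""
    return ((offset_q + batch_idx * tile_m) // tile_m) * tile_m


def padded_starts(cu_seqlens, tile_m):
    """Padded start offset of every sequence described by cu_seqlens."""
    return [padded_offset(cu_seqlens[b], b, tile_m) for b in range(len(cu_seqlens) - 1)]


cu_seqlens = [0, 128, 200, 330]          # hypothetical: 3 sequences of 128, 72, 130
kernel_view = padded_starts(cu_seqlens, 64)       # backward kernel, tile_m=64
preprocess_view = padded_starts(cu_seqlens, 128)  # old preprocess default tile=128
mismatched = [b for b, (a, c) in enumerate(zip(kernel_view, preprocess_view)) if a != c]
```

Here `kernel_view` is `[0, 192, 320]` and `preprocess_view` is `[0, 256, 384]`, so `mismatched` is `[1, 2]`: only batch 0 agrees, which matches the symptom that batches after batch 0 were affected.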
## Symptoms

- Tests pass in isolation (torch.empty gets clean memory)
- Tests fail when run in sequence (CUDA memory caching reuses NaN-polluted memory)
- dq_accum valid positions contain NaN after backward kernel
- `torch.zeros` for dq_accum masks the bug (zeroes everywhere, including the "right" offsets)
- compute-sanitizer shows 0 errors (addresses are valid, just wrong offsets within the buffer)

## Fix

```python
# flash_bwd_preprocess.py line 216
# Before:
seqlen = SeqlenInfo.create(batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ)
# After:
seqlen = SeqlenInfo.create(batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ, tile=self.tile_m)
```

## Lesson

Any code computing `padded_offset` for varlen buffers must use the same tile size as the kernel that allocated and accesses those buffers. The `SeqlenInfo.create` default `tile=128` is a trap when `m_block_size != 128`.

================================================
FILE: AI/racecheck_repro_1d_bulk.py
================================================
"""Minimal reproducer: cp.async.bulk (raw address) triggers racecheck hazard.

Warp 0 loads via cp.async.bulk, warp 1 reads from smem after mbarrier wait.
Pipeline is correctly synchronized but racecheck reports 1 error.
    python AI/racecheck_repro_1d_bulk.py  # correctness
    CUTE_DSL_LINEINFO=1 compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_bulk.py  # 1 error
"""
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync
from cutlass.cute.runtime import from_dlpack
from cutlass import Float32, Int32
import cutlass.pipeline
from cutlass.pipeline.sm90 import PipelineTmaAsync, make_pipeline_state
import cuda.bindings.driver as cuda
import torch

N_BLKS, TILE = 4, 128
N_STG = 2


@cute.kernel
def kernel(g_src: cute.Tensor, g_dst: cute.Tensor):
    smem = cutlass.utils.SmemAllocator()
    s = smem.allocate_tensor(Float32, cute.make_layout((TILE, N_STG)), byte_alignment=128)
    s_mbar = smem.allocate_tensor(cutlass.Int64, cute.make_layout(2 * N_STG), byte_alignment=8)
    tidx, _, _ = cute.arch.thread_idx()
    warp, lane = tidx // 32, tidx % 32
    pipe = PipelineTmaAsync.create(
        barrier_storage=s_mbar.iterator,
        num_stages=N_STG,
        producer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),
        consumer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),
        tx_count=TILE * 4,
        defer_sync=False,
    )
    src = cute.local_tile(g_src, (TILE,), (None,))
    dst = cute.local_tile(g_dst, (TILE,), (None,))
    if warp == 0:
        ps = make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, N_STG)
        for blk in cutlass.range(N_BLKS, unroll=1):
            pipe.producer_acquire(ps)
            atom = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32)
            with cute.arch.elect_one():
                cute.copy(atom, src[None, blk], s[None, ps.index], mbar_ptr=pipe.producer_get_barrier(ps))
            ps.advance()
        pipe.producer_tail(ps)

    if warp == 1:
        cs = make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, N_STG)
        for blk in cutlass.range(N_BLKS, unroll=1):
            pipe.consumer_wait(cs)
            for i in cutlass.range_constexpr(TILE // 32):
                dst[lane + i * 32, blk] = s[lane + i * 32, cs.index]
            cute.arch.fence_view_async_shared()
            cute.arch.sync_warp()  # Need sync_warp as only 1 thread will signal in consumer_release
            pipe.consumer_release(cs)
            cs.advance()


@cute.jit
def go(g_src, g_dst, stream):
    kernel(g_src, g_dst).launch(grid=[1, 1, 1], block=[64, 1, 1], smem=4096, stream=stream)


if __name__ == "__main__":
    src = torch.arange(TILE * N_BLKS, device="cuda", dtype=torch.float32)
    dst = torch.zeros_like(src)
    go(
        from_dlpack(src, assumed_align=16),
        from_dlpack(dst, assumed_align=16),
        cuda.CUstream(torch.cuda.current_stream().cuda_stream),
    )
    torch.cuda.synchronize()
    assert torch.equal(src, dst), f"FAIL: max diff={torch.abs(src - dst).max().item()}"
    print("PASS")

================================================
FILE: AI/racecheck_repro_1d_tensor.py
================================================
"""Minimal reproducer: cp.async.bulk.tensor.1d (descriptor TMA) passes racecheck.

Same pipeline as racecheck_repro_1d_bulk.py but uses make_tiled_tma_atom to create
a TMA descriptor, which generates cp.async.bulk.tensor.1d PTX.
    python AI/racecheck_repro_1d_tensor.py  # correctness
    CUTE_DSL_LINEINFO=1 compute-sanitizer --tool=racecheck python AI/racecheck_repro_1d_tensor.py  # 0 hazards
"""
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync
from cutlass.cute.runtime import from_dlpack
from cutlass import Float32, Int32
import cutlass.pipeline
from cutlass.pipeline.sm90 import PipelineTmaAsync, make_pipeline_state
import cuda.bindings.driver as cuda
import torch

N_BLKS, TILE = 4, 128
N_STG = 2


@cute.kernel
def kernel(g_dst: cute.Tensor, tma_atom: cute.CopyAtom, tma_tensor: cute.Tensor):
    smem = cutlass.utils.SmemAllocator()
    s = smem.allocate_tensor(Float32, cute.make_layout((TILE, N_STG)), byte_alignment=128)
    s_mbar = smem.allocate_tensor(cutlass.Int64, cute.make_layout(2 * N_STG), byte_alignment=8)
    tidx, _, _ = cute.arch.thread_idx()
    warp, lane = tidx // 32, tidx % 32
    pipe = PipelineTmaAsync.create(
        barrier_storage=s_mbar.iterator,
        num_stages=N_STG,
        producer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),
        consumer_group=cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread, 1),
        tx_count=TILE * 4,
        defer_sync=False,
    )
    tma_s, tma_g = cpasync.tma_partition(
        tma_atom,
        Int32(0),
        cute.make_layout(1),
        cute.group_modes(s, 0, 1),
        cute.group_modes(cute.local_tile(tma_tensor, (TILE,), (None,)), 0, 1),
    )
    dst = cute.local_tile(g_dst, (TILE,), (None,))
    if warp == 0:
        with cute.arch.elect_one():
            cpasync.prefetch_descriptor(tma_atom)
    if warp == 0:
        ps = make_pipeline_state(cutlass.pipeline.PipelineUserType.Producer, N_STG)
        for blk in cutlass.range(N_BLKS, unroll=1):
            pipe.producer_acquire(ps)
            cute.copy(tma_atom, tma_g[None, blk], tma_s[None, ps.index], tma_bar_ptr=pipe.producer_get_barrier(ps))
            ps.advance()
        pipe.producer_tail(ps)

    if warp == 1:
        cs = make_pipeline_state(cutlass.pipeline.PipelineUserType.Consumer, N_STG)
        for blk in cutlass.range(N_BLKS, unroll=1):
            pipe.consumer_wait(cs)
            for i in cutlass.range_constexpr(TILE // 32):
                dst[lane + i * 32, blk] = s[lane + i * 32, cs.index]
            cute.arch.fence_view_async_shared()
            cute.arch.sync_warp()  # Need sync_warp as only 1 thread will signal in consumer_release
            pipe.consumer_release(cs)
            cs.advance()


@cute.jit
def go(g_src, g_dst, stream):
    tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileG2SOp(),
        g_src,
        cute.make_layout(TILE),
        (TILE,),
    )
    kernel(g_dst, tma_atom, tma_tensor).launch(
        grid=[1, 1, 1],
        block=[64, 1, 1],
        smem=4096,
        stream=stream,
    )


if __name__ == "__main__":
    src = torch.arange(TILE * N_BLKS, device="cuda", dtype=torch.float32)
    dst = torch.zeros_like(src)
    go(
        from_dlpack(src, assumed_align=16),
        from_dlpack(dst, assumed_align=16),
        cuda.CUstream(torch.cuda.current_stream().cuda_stream),
    )
    torch.cuda.synchronize()
    assert torch.equal(src, dst), f"FAIL: max diff={torch.abs(src - dst).max().item()}"
    print("PASS")

================================================
FILE: AUTHORS
================================================
Tri Dao, trid@cs.stanford.edu

================================================
FILE: CLAUDE.md
================================================
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

FlashAttention-4 (FA4) — fast, memory-efficient exact attention kernels written in Python using CuTeDSL (NVIDIA CUTLASS DSL). Kernels are compiled to PTX/CUBIN at runtime. Targets Hopper (SM90) and Blackwell (SM100/SM110) GPUs. Package name: `flash-attn-4`.

The repository also contains older generations (FA2 in top-level `csrc/`, FA3 in `hopper/`) but active development is on FA4 in `flash_attn/cute/`.

## Build & Install

```bash
pip install flash-attn-4
# or dev install:
pip install -e "flash_attn/cute[dev]"
```

Dependencies: `nvidia-cutlass-dsl>=4.4.1`, `torch`, `einops`, `apache-tvm-ffi`, `quack-kernels>=0.2.10`.

## Running Tests

```bash
pytest tests/cute/test_flash_attn.py
pytest tests/cute/test_flash_attn.py -k "test_flash_attn_output" -x  # single test
pytest tests/cute/test_flash_attn_varlen.py
pytest tests/cute/test_mask_mod.py
pytest tests/cute/test_score_mod.py
pytest tests/cute/test_block_sparsity.py
```

### Fast two-pass testing

Compilation dominates test time. The fast workflow separates compilation (parallel, no GPU needed) from execution (uses cached binaries):

```bash
# Pass 1: compile all kernels in parallel using FakeTensorMode (no GPU memory allocation)
FLASH_ATTENTION_FAKE_TENSOR=1 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -n 64 -x tests/cute/test_flash_attn.py

# Pass 2: run tests using cached compiled kernels
FLASH_ATTENTION_FAKE_TENSOR=0 FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1 pytest -x tests/cute/test_flash_attn.py
```

- `FLASH_ATTENTION_FAKE_TENSOR=1` — uses PyTorch FakeTensorMode to compile kernels without allocating GPU memory or running them.
- `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` — enables persistent disk cache at `/tmp/${USER}/flash_attention_cute_dsl_cache/`.
- `-n 64` — pytest-xdist parallel workers (only useful in the compilation pass).

Tests are parametrized over dtype (fp16/bf16), head dimension (64, 96, 128), sequence length, causal/non-causal, and MHA/GQA/MQA.

If you get OOM errors running tests or benchmarks, use `nvidia-smi` to find a free GPU and select it with `CUDA_VISIBLE_DEVICES=`.

## Linting

Pre-commit uses ruff on `flash_attn/cute/` files. Large kernel files (`flash_bwd.py`, `flash_fwd.py`, `flash_fwd_sm100.py`, `interface.py`) are excluded from auto-formatting.

```bash
ruff check flash_attn/cute/ --fix
ruff format flash_attn/cute/
```

## Code Architecture

### Public API (`flash_attn/cute/interface.py`)

Two entry points exported from `flash_attn/cute/__init__.py`:
- `flash_attn_func(q, k, v, ...)` — standard attention
- `flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, ...)` — variable-length

Key parameters: `causal`, `window_size_left/right`, `softmax_scale`, `softcap`, `score_mod`, `mask_mod`, `block_sparse_tensors`, `num_splits`, `pack_gqa`, `m_block_size`, `n_block_size`, `num_threads`.

Tensor layout: `(batch, seqlen, num_heads, head_dim)`, last dim contiguous, 16-byte aligned.

### Forward Kernels

- `flash_fwd.py` — `FlashAttentionForwardSm90`: Hopper forward. No SplitKV or paged KV.
- `flash_fwd_sm100.py` — `FlashAttentionForwardSm100`: Blackwell forward. Full features including SplitKV, paged KV cache, persistent kernels, 2CTA instructions.
- `flash_fwd_combine.py` — `FlashAttentionForwardCombine`: merges SplitKV partial results.

### Backward Kernels

- `flash_bwd.py` — `FlashAttentionBackwardSm80`: Ampere backward (base).
- `flash_bwd_sm90.py` — `FlashAttentionBackwardSm90`: Hopper backward.
- `flash_bwd_sm100.py` — `FlashAttentionBackwardSm100`: Blackwell backward with 2CTA and block sparse support.
- `flash_bwd_preprocess.py` / `flash_bwd_postprocess.py` — auxiliary backward kernels.
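Both forward kernels follow the loop summarized under Key Patterns (load a Q tile, iterate over K/V blocks with online softmax, store O and LSE), with `softmax.py` tracking a running row max and row sum. A minimal pure-Python sketch of that recurrence for a single head, non-causal and without refinements such as exp2 or score modifiers, shows why visiting K/V block by block yields the same O and LSE as a direct softmax. The function names are illustrative and unrelated to the actual `softmax.py` code.

```python
import math


def attention_reference(q, k, v, scale):
    """Direct softmax(scale * q k^T) v for one head. q: [M][D], k: [N][D], v: [N][Dv]."""
    out, lse = [], []
    for qi in q:
        s = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        denom = sum(p)
        out.append([sum(pj * vj[d] for pj, vj in zip(p, v)) / denom for d in range(len(v[0]))])
        lse.append(m + math.log(denom))
    return out, lse


def attention_online(q, k, v, scale, block_n):
    """Same result, visiting K/V in blocks of block_n with a running row max/sum."""
    dv = len(v[0])
    out, lse = [], []
    for qi in q:
        row_max, row_sum = -math.inf, 0.0
        acc = [0.0] * dv
        for start in range(0, len(k), block_n):
            kb, vb = k[start:start + block_n], v[start:start + block_n]
            s = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in kb]
            new_max = max(row_max, max(s))
            corr = math.exp(row_max - new_max)   # rescales earlier partial sums (0.0 on first block)
            p = [math.exp(x - new_max) for x in s]
            row_sum = row_sum * corr + sum(p)
            acc = [a * corr + sum(pj * vj[d] for pj, vj in zip(p, vb)) for d, a in enumerate(acc)]
            row_max = new_max
        out.append([a / row_sum for a in acc])    # final normalization before storing O
        lse.append(row_max + math.log(row_sum))  # LSE stored for the backward pass
    return out, lse
```

Any `block_n` gives the same result up to rounding, which is what lets the kernels pipeline K/V tiles without ever materializing a full score row.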
### Core Abstractions

- `softmax.py` — Online softmax with row_max/row_sum tracking, score modifier support.
- `mask.py` — `AttentionMask`: causal, local/sliding window, block sparse, mask_mod application.
- `block_info.py` — `BlockInfo`: tile dimensions, n/m block range computation for causal/local masking.
- `seqlen_info.py` — `SeqlenInfoQK`: sequence length and offset tracking for varlen.
- `pipeline.py` — `PipelineStateSimple`: circular buffer index/phase management for pipelined loads.
- `tile_scheduler.py` — Tile scheduling strategies (single tile, varlen-aware, persistent).
- `copy_utils.py` — Type-converting copies, shared-to-register loads, TMA copy atoms.
- `named_barrier.py` — Named barrier enums for warp synchronization.

### Architecture-Specific Helpers

- `hopper_helpers.py` — SM90 warp-group GEMM, shared memory layout creation, fence/commit/wait.
- `blackwell_helpers.py` — SM100 UMMA-based GEMM, PTX-optimized paths, 2CTA support.
- `mma_sm100_desc.py` — Hardware MMA descriptor enums (formats, saturation, scaling).

### Other Components

- `pack_gqa.py` — Packs multiple Q heads per KV head for efficient GQA.
- `paged_kv.py` — `PagedKVManager`: paged KV cache with TMA support.
- `fast_math.py` — exp2 polynomial coefficients, softcap score_mod creation.
- `utils.py` — Hash functions for compile cache keys, warp reductions, predicates.
- `cache_utils.py` — JIT compilation cache management.
- `cute_dsl_utils.py` — Patched `cute.compile` that optionally dumps SASS.

### Compilation & Caching

Kernels are JIT-compiled. Cache key includes dtype, head_dim, causal, mask/score_mod hashes, architecture, block sizes. Caching levels: in-memory LRU + optional disk cache via `get_jit_cache()`.

Env vars: `CUTE_CUBIN_PATH` (dump CUBIN/SASS), `CUTE_DSL_KEEP_PTX=1` (inspect PTX), `CUTE_DSL_PTXAS_PATH` (custom ptxas).

## Key Patterns

- Compile-time constants use `cutlass.Constexpr[type]` for kernel specialization.
- Score/mask modifiers are user-defined `@cute.jit` callables injected into the kernel at compile time.
- Forward execution: load Q tile → loop over K/V blocks (pipelined) → online softmax accumulation → store O and LSE.
- 2CTA instructions (SM100, hdim=128): both CTAs in a cluster coordinate via shared mbarriers; tx_count must be multiplied by `cta_group_size`.

## Debugging GPU Kernels

See `AI/DEBUG_2CTA.md` for kernel hang/deadlock debugging (printf bisection, pipeline barrier analysis, 2CTA pitfalls). See `AI/RACECHECK_TMA_HAZARD.md` for `compute-sanitizer` false positives with `cp.async.bulk`.

Key tools:
- `cute.printf` with thread guards (`tidx % 32 == 0`, `elect_one()`) for targeted output
- `compute-sanitizer --tool=racecheck` (beware false positives with raw TMA)
- `CUTE_DSL_KEEP_PTX=1` and `CUTE_DSL_LINEINFO=1` for PTX inspection and sanitizer source mapping

================================================
FILE: LICENSE
================================================
BSD 3-Clause License

Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

================================================
FILE: MANIFEST.in
================================================
recursive-include csrc *.cu
recursive-include csrc *.h
recursive-include csrc *.cuh
recursive-include csrc *.cpp
recursive-include csrc *.hpp
recursive-include csrc *.py

recursive-include flash_attn *.cu
recursive-include flash_attn *.h
recursive-include flash_attn *.cuh
recursive-include flash_attn *.cpp
recursive-include flash_attn *.hpp

================================================
FILE: Makefile
================================================
clean_dist:
	rm -rf dist/*

create_dist: clean_dist
	python setup.py sdist

upload_package: create_dist
	twine upload dist/*

================================================
FILE: README.md
================================================
# FlashAttention
This repository provides the official implementation of FlashAttention and FlashAttention-2 from the following papers.

**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
![FlashAttention](assets/flashattn_banner.jpg)

**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**
Tri Dao

Paper: https://tridao.me/publications/flash2/flash2.pdf

![FlashAttention-2](assets/flashattention_logo.png)

## Usage

We've been very happy to see FlashAttention being widely adopted in such a short time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md) contains a partial list of places where FlashAttention is being used.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). Please cite and credit FlashAttention if you use it.

## FlashAttention-3 beta release

FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).

Blogpost: https://tridao.me/blog/2024/flash3/

Paper: https://tridao.me/publications/flash3/flash3.pdf

![FlashAttention-3 speedup on H100 80GB SXM5 with FP16](assets/flash3_fp16_fwd.png)

This is a beta release for testing / benchmarking before we integrate that with the rest of the repo.

Currently released:
- FP16 / BF16 forward and backward, FP8 forward

Requirements: H100 / H800 GPU, CUDA >= 12.3.

We highly recommend CUDA 12.8 for best performance.

To install:
```sh
cd hopper
python setup.py install
```
To run the test:
```sh
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py
```
Once the package is installed, you can import it as follows:
```python
import flash_attn_interface
flash_attn_interface.flash_attn_func()
```

## FlashAttention-4 (CuTeDSL)

FlashAttention-4 is written in CuTeDSL and optimized for Hopper and Blackwell GPUs (e.g. H100, B200).
To install:
```sh
pip install flash-attn-4
```

Once installed, you can use it as follows:
```python
from flash_attn.cute import flash_attn_func
out = flash_attn_func(q, k, v, causal=True)
```

## Installation and features
**Requirements:**
- CUDA toolkit or ROCm toolkit
- PyTorch 2.2 and above.
- `packaging` Python package (`pip install packaging`)
- `psutil` Python package (`pip install psutil`)
- `ninja` Python package (`pip install ninja`) *
- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via a GitHub issue.

\* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja --version` then `echo $?` should return exit code 0). If not (sometimes `ninja --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall `ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`, compiling can take a very long time (2h) since it does not use multiple CPU cores. With `ninja`, compiling takes 3-5 minutes on a 64-core machine using the CUDA toolkit.

**To install:**
```sh
pip install flash-attn --no-build-isolation
```
Alternatively, you can compile from source:
```sh
python setup.py install
```

If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might run too many parallel compilation jobs and exhaust the available RAM. To limit the number of parallel compilation jobs, you can set the environment variable `MAX_JOBS`:
```sh
MAX_JOBS=4 pip install flash-attn --no-build-isolation
```

**Interface:** `flash_attn/flash_attn_interface.py`

### NVIDIA CUDA Support
**Requirements:**
- CUDA 12.0 and above.

We recommend the [PyTorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) container from NVIDIA, which has all the required tools to install FlashAttention.
FlashAttention-2 with CUDA currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing GPUs for now.
2. Datatypes fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

### AMD ROCm Support
The ROCm version has two backends: [composable_kernel](https://github.com/ROCm/composable_kernel) (ck), which is the default backend, and a [Triton](https://github.com/triton-lang/triton) backend. Both provide an implementation of FlashAttention-2.

**Requirements:**
- ROCm 6.0 and above.

We recommend the [PyTorch](https://hub.docker.com/r/rocm/pytorch) container from ROCm, which has all the required tools to install FlashAttention.

#### Composable Kernel Backend
FlashAttention-2 ROCm CK backend currently supports:
1. MI200x, MI250x, MI300x, and MI355x GPUs.
2. Datatypes fp16 and bf16.
3. Head dimensions up to 256 for both the forward and backward passes.

#### Triton Backend
The Triton implementation of [Flash Attention](https://tridao.me/publications/flash2/flash2.pdf) supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16, and fp32 datatypes. It provides forward and backward passes with causal masking, variable sequence lengths, arbitrary Q/KV sequence lengths and head sizes, MQA/GQA, dropout, rotary embeddings, ALiBi, paged attention, and FP8 (via the Flash Attention v3 interface). Sliding window attention is currently a work in progress.

The Triton backend kernels are provided by the [aiter](https://github.com/ROCm/aiter) package, included as a git submodule at `third_party/aiter` and automatically installed during setup.
To install, first get PyTorch for ROCm from https://pytorch.org/get-started/locally/, then install Flash Attention:
```sh
cd flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
```

To use a specific aiter commit (e.g., for testing or development):
```sh
cd flash-attention
cd third_party/aiter && git fetch origin && git checkout && cd ../..
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .
```

To run the tests (note: the full suite takes hours):
```sh
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
```

The Triton backend uses a default kernel configuration optimized for determinism and reasonable performance across workloads. For peak throughput, enable `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` to search for optimal settings, which incurs a one-time warmup cost.

Alternatively, if _not_ autotuning, `FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON` may be used to set a single Triton config that overrides the hardcoded defaults for `attn_fwd`. For example:
```sh
FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON='{"BLOCK_M":128,"BLOCK_N":64,"waves_per_eu":1,"PRE_LOAD_V":false,"num_stages":1,"num_warps":8}'
```

For a quick start with Docker:
```dockerfile
FROM rocm/pytorch:latest

WORKDIR /workspace

# build flash attention with triton backend
RUN git clone https://github.com/Dao-AILab/flash-attention &&\
    cd flash-attention &&\
    FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pip install --no-build-isolation .

# set working dir
WORKDIR /workspace/flash-attention

# set env variable to use triton backend
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
```

Build and run:
```sh
docker build -t flash-attn-triton .
docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri flash-attn-triton
```

## How to use FlashAttention

The main functions implement scaled dot product attention (softmax(Q @ K^T * softmax_scale) @ V):
```python
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```

```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
                          window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
    qkv: (batch_size, seqlen, 3, nheads, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Defaults to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
        is added to the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
        which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```

```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

Arguments:
    q: (batch_size, seqlen, nheads, headdim)
    k: (batch_size, seqlen, nheads_k, headdim)
    v: (batch_size, seqlen, nheads_k, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Defaults to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
        is added to the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
        which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```

```python
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    rotary_interleaved=True,
    alibi_slopes=None,
):
    """
    If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
    k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
    the previous step, and update them with the new keys/values from the current step, and do
    attention with the updated cache, all in 1 kernel.

    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zeros, the output for that row will be zero.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Note: Does not support backward pass.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache).
            page_block_size must be a multiple of 256.
        v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
            or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache).
        k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
            k with k_cache, starting at the indices specified by cache_seqlens.
        v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
        rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
            to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the KV cache.
        block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
            If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
            rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
            (i.e. GPT-NeoX style).
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.

    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
```

To see how these functions are used in a multi-head attention layer (which includes the QKV and output projections), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).

### Using with 🤗 Kernels

If your hardware is among those listed above, you can also use the [`kernels` library](https://github.com/huggingface/kernels) to use Flash Attention 2 and 3 right away.
```py
# pip install kernels
from kernels import get_kernel

# FA2
fa_module = get_kernel("kernels-community/flash-attn2", version=1)
flash_attn_func = fa_module.flash_attn_func

# FA3
fa3_module = get_kernel("kernels-community/flash-attn3", version=1)
flash_attn_func = fa3_module.flash_attn_func
```

## Changelog

### 2.0: Complete rewrite, 2x faster
Upgrading from FlashAttention (1.x) to FlashAttention-2

These functions have been renamed:
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`

If the inputs have the same sequence lengths in the same batch, it is simpler and faster to use these functions:
```python
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
```
```python
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```

### 2.1: Change behavior of causal flag

If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the bottom right corner of the attention matrix, instead of the top-left corner.

For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:

v2.0:

    1 0 0 0 0
    1 1 0 0 0

v2.1:

    1 1 1 1 0
    1 1 1 1 1

If seqlen_q = 5 and seqlen_k = 2, the causal mask is:

v2.0:

    1 0
    1 1
    1 1
    1 1
    1 1

v2.1:

    0 0
    0 0
    0 0
    1 0
    1 1

If a row of the mask is all zeros, the output for that row will be zero.

### 2.2: Optimize for inference

Optimize for inference (iterative decoding) when the query has a very small sequence length (e.g., query sequence length = 1). The bottleneck here is loading the KV cache as fast as possible, so we split the loading across different thread blocks and use a separate kernel to combine the results.

See the function `flash_attn_with_kvcache` for more inference features (performing rotary embedding, updating the KV cache in place).

Thanks to the xformers team, and in particular Daniel Haziza, for this collaboration.
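
The bottom-right alignment introduced in 2.1 follows from the index rule in the `flash_attn_func` docstring, and can be checked with a few lines of plain Python. The sketch below is illustrative only: the helper `attention_mask` is hypothetical and not part of `flash_attn`, and it builds the keep/mask matrix rather than computing attention. It assumes that with `causal=True` query `i` may see key `j` only if `j <= i + seqlen_k - seqlen_q`, and that sliding window attention keeps keys in `[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]]`, with -1 meaning no limit on that side.

```python
def attention_mask(seqlen_q, seqlen_k, causal=False, window_size=(-1, -1)):
    """Return seqlen_q rows of seqlen_k ints (1 = keep, 0 = masked out).

    Hypothetical helper that mirrors the bottom-right-aligned rule described
    in this README; it is not the library's kernel implementation.
    """
    offset = seqlen_k - seqlen_q  # query i is aligned with key i + offset
    left, right = window_size     # -1 means unbounded on that side
    mask = []
    for i in range(seqlen_q):
        row = []
        for j in range(seqlen_k):
            keep = True
            if causal and j > i + offset:
                keep = False  # key lies after the aligned diagonal
            if left >= 0 and j < i + offset - left:
                keep = False  # key is too far in the past for the window
            if right >= 0 and j > i + offset + right:
                keep = False  # key is too far in the future for the window
            row.append(int(keep))
        mask.append(row)
    return mask


for row in attention_mask(2, 5, causal=True):
    print(*row)
# 1 1 1 1 0
# 1 1 1 1 1

for row in attention_mask(5, 2, causal=True):
    print(*row)
# 0 0
# 0 0
# 0 0
# 1 0
# 1 1
```

Both outputs match the v2.1 masks in the 2.1 entry. The first three rows of the second mask are all zeros, which are exactly the query rows whose output is defined to be zero.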
### 2.3: Local (i.e., sliding window) attention
Implement sliding window attention (i.e., local attention). Thanks to [Mistral AI](https://mistral.ai/) and in particular Timothée Lacroix for this contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.

### 2.4: ALiBi (attention with linear bias), deterministic backward pass.
Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.

Implement deterministic backward pass. Thanks to engineers from [Meituan](https://www.meituan.com) for this contribution.

### 2.5: Paged KV cache.
Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)). Thanks to @beginlner for this contribution.

### 2.6: Softcapping.
Support attention with softcapping, as used in Gemma-2 and Grok models. Thanks to @Narsil and @lucidrains for this contribution.

### 2.7: Compatibility with torch compile
Thanks to @ani300 for this contribution.

## Performance

We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth; we see more speedup on slower GPU memory).

We currently have benchmarks for these GPUs:
* [A100](#a100)
* [H100](#h100)

### A100

We display FlashAttention speedup using these parameters:
* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
* Sequence length 512, 1k, 2k, 4k, 8k, 16k.
* Batch size set to 16k / seqlen.

#### Speedup

![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png)

#### Memory

![FlashAttention memory](assets/flashattn_memory.jpg)

We show memory savings in this graph (note that the memory footprint is the same whether or not you use dropout or masking).
Memory savings are proportional to sequence length, since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.

### H100

![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png)

## Full model code and training script

We have released the full GPT model [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py). We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x compared to the baseline implementation from Hugging Face, reaching up to 225 TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need any activation checkpointing).

We also include a training [script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to train GPT2 on OpenWebText and GPT3 on The Pile.

## Triton implementation of FlashAttention

Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

As Triton is a higher-level language than CUDA, it might be easier to understand and experiment with. The notations in the Triton implementation are also closer to what's used in our paper.

We also have an experimental implementation in Triton that supports attention bias (e.g. ALiBi): https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py

## Tests
We test that FlashAttention produces the same output and gradient as a reference implementation, up to some numerical tolerance.
In particular, we check that the maximum numerical error of FlashAttention is at most twice the numerical error of a baseline implementation in PyTorch (for different head dimensions, input dtypes, sequence lengths, and causal / non-causal settings).

To run the tests:
```sh
pytest -q -s tests/test_flash_attn.py
```
## When you encounter issues

This new release of FlashAttention-2 has been tested on several GPT-style models, mostly on A100 GPUs.

If you encounter bugs, please open a GitHub Issue!

## Tests
To run the tests for the ROCm Composable Kernel (CK) backend:
```sh
pytest tests/test_flash_attn_ck.py
```
## Citation
If you use this codebase, or otherwise find our work valuable, please cite:
```
@inproceedings{dao2022flashattention,
  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
  author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2022}
}
@inproceedings{dao2023flashattention2,
  title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
  author={Dao, Tri},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2024}
}
```

================================================
FILE: benchmarks/bench_sm90.py
================================================
#!/usr/bin/env python
"""Unified SM90 benchmark for forward and backward passes.

Usage:
    # Default: bench fwd+bwd for hdim 64,96,128 at seqlen 8192
    python benchmarks/bench_sm90.py

    # Forward only, specific hdims
    python benchmarks/bench_sm90.py --direction fwd --hdim 64,96

    # Backward only
    python benchmarks/bench_sm90.py --direction bwd --hdim 128

    # Custom seqlens and batch size
    python benchmarks/bench_sm90.py --seqlen 1024,2048,4096,8192 --batch 0

    # Sweep tile sizes for fwd
    python benchmarks/bench_sm90.py --sweep-tiles --hdim 96

    # Sweep tile sizes for fwd (all hdims including 192, 256)
    python benchmarks/bench_sm90.py --sweep-tiles --hdim 64,96,128,192,256

    # Sweep RS/overlap variants
    python benchmarks/bench_sm90.py --sweep-rs-overlap --hdim 64,96

    # Compare old vs new configs
    python benchmarks/bench_sm90.py --compare-configs

    # Sweep backward optimizations (V_in_regs, mma_dkv_is_rs, pipeline sharing)
    python benchmarks/bench_sm90.py --sweep-bwd-opts --hdim 64,128

    # Causal only, more reps and warmup
    python benchmarks/bench_sm90.py --causal-only --rep 50 --warmup 10
"""

import argparse
import time

import torch
import torch.nn.functional as F

from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd


# ── Helpers ────────────────────────────────────────────────────────────────

def parse_int_k(s):
    """Parse an integer with optional k/K suffix, e.g. '8k' -> 8192."""
    s = s.strip().lower()
    if s.endswith("k"):
        return int(s[:-1]) * 1024
    return int(s)


def csv_ints(s):
    """Parse comma-separated integers with optional k suffix, e.g. '512,1k,2k'."""
    return [parse_int_k(x) for x in s.split(",")]


def parse_headdims(s):
    """Parse comma-separated headdim specs. Each entry is hdim or hdim-hdim_v.

    Examples:
        '128' -> [(128, 128)]
        '192-128' -> [(192, 128)]
        '64,128,192' -> [(64, 64), (128, 128), (192, 192)]
        '64,128,192-128,192' -> [(64, 64), (128, 128), (192, 128), (192, 192)]
    """
    result = []
    for item in s.split(","):
        if "-" in item:
            parts = item.split("-")
            result.append((int(parts[0]), int(parts[1])))
        else:
            hdim = int(item)
            result.append((hdim, hdim))
    return result


def nheads_for_hdim(h):
    return 32 if h <= 64 else (16 if h <= 192 else 8)


def fwd_flops(batch, nheads, seqlen, hdim, hdim_v=None, causal=False):
    if hdim_v is None:
        hdim_v = hdim
    # Causal masking skips roughly half of the seqlen x seqlen score matrix.
    avg_seqlen = seqlen / 2 if causal else seqlen
    # Two matmuls (Q @ K^T over hdim, P @ V over hdim_v), 2 FLOPs per multiply-add.
    return batch * nheads * 2 * seqlen * avg_seqlen * (hdim + hdim_v)


def bwd_flops(batch, nheads, seqlen, hdim, causal=False, hdim_v=None):
    # Backward is counted as 2.5x forward (5 matmuls vs. 2).
    return 2.5 * fwd_flops(batch, nheads, seqlen, hdim, hdim_v=hdim_v, causal=causal)


def get_causals(args):
    if args.causal_only:
        return [True]
    if args.non_causal_only:
        return [False]
    return [False, True]


def auto_batch(seqlen, batch_arg, total_tokens=32768):
    # batch_arg <= 0 means: pick the batch so that batch * seqlen is about total_tokens.
    return batch_arg if batch_arg > 0 else max(1, total_tokens // seqlen)


# ── Core bench functions ──────────────────────────────────────────────────

def bench_fwd(batch, seqlen, nheads, hdim, causal, tile_m=None, tile_n=None,
              mma_pv_is_rs=None, intra_wg_overlap=None, check_correctness=True,
              warmup=5, rep=30, hdim_v=None):
    """Benchmark forward pass.

    Returns (ms, tflops, max_diff_or_error)."""
    if hdim_v is None:
        hdim_v = hdim
    q = torch.randn(batch, seqlen, nheads, hdim, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(batch, seqlen, nheads, hdim, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(batch, seqlen, nheads, hdim_v, dtype=torch.bfloat16, device="cuda")
    kwargs = dict(softmax_scale=hdim ** -0.5, causal=causal)
    if tile_m is not None and tile_n is not None:
        kwargs["tile_mn"] = (tile_m, tile_n)
    if mma_pv_is_rs is not None:
        kwargs["mma_pv_is_rs"] = mma_pv_is_rs
    if intra_wg_overlap is not None:
        kwargs["intra_wg_overlap"] = intra_wg_overlap

    try:
        out, _lse = _flash_attn_fwd(q, k, v, **kwargs)
    except Exception as e:
        return None, None, str(e)[:80]

    max_diff = None
    if check_correctness:
        # Reference: fp32 SDPA in (batch, nheads, seqlen, hdim) layout.
        q_ref = q.transpose(1, 2).float()
        k_ref = k.transpose(1, 2).float()
        v_ref = v.transpose(1, 2).float()
        out_ref = F.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=causal)
        out_ref = out_ref.transpose(1, 2).to(torch.bfloat16)
        max_diff = (out.float() - out_ref.float()).abs().max().item()

    for _ in range(warmup):
        _flash_attn_fwd(q, k, v, **kwargs)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(rep):
        _flash_attn_fwd(q, k, v, **kwargs)
    end.record()
    torch.cuda.synchronize()

    ms = start.elapsed_time(end) / rep
    tflops = fwd_flops(batch, nheads, seqlen, hdim, hdim_v=hdim_v, causal=causal) / ms / 1e9
    return ms, tflops, max_diff


def bench_bwd(batch, seqlen, nheads, hdim, causal, warmup=5, rep=30, hdim_v=None, **bwd_kwargs):
    """Benchmark backward pass.

    Returns (ms, tflops, None_or_error)."""
    if hdim_v is None:
        hdim_v = hdim
    q = torch.randn(batch, seqlen, nheads, hdim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(batch, seqlen, nheads, hdim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(batch, seqlen, nheads, hdim_v, device="cuda", dtype=torch.bfloat16)
    softmax_scale = hdim ** -0.5

    try:
        out, lse = _flash_attn_fwd(q, k, v, softmax_scale=softmax_scale, causal=causal, return_lse=True)
    except Exception as e:
        return None, None, str(e)[:80]
    dout = torch.randn_like(out)

    def fn():
        _flash_attn_bwd(q, k, v, out, dout, lse, softmax_scale=softmax_scale, causal=causal, **bwd_kwargs)

    try:
        fn()  # compile
    except Exception as e:
        return None, None, str(e)[:80]

    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(rep):
        fn()
    end.record()
    torch.cuda.synchronize()

    ms = start.elapsed_time(end) / rep
    tflops = bwd_flops(batch, nheads, seqlen, hdim, causal, hdim_v=hdim_v) / ms / 1e9
    return ms, tflops, None


# ── Preset configs ────────────────────────────────────────────────────────

# (tile_m, tile_n, mma_pv_is_rs, intra_wg_overlap)
TILE_SWEEP_CONFIGS = {
    64: [
        (192, 192, False, True), (192, 192, True, True),
        (192, 128, True, True), (192, 128, False, True),
        (128, 128, True, True), (128, 192, True, True),
        (192, 96, True, True), (192, 96, False, True),
    ],
    96: [
        (192, 144, False, True), (192, 144, True, True),
        (192, 128, False, True), (192, 128, True, True),
        (192, 96, False, True), (192, 96, True, True),
        (128, 128, True, True), (128, 128, False, True),
    ],
    128: [
        (128, 128, True, True), (128, 128, False, True),
        (128, 96, True, True), (128, 96, False, True),
        (128, 160, True, True), (128, 176, True, True),
        (128, 192, True, True),
    ],
    192: [
        (128, 64, True, True), (128, 80, True, True),
        (128, 96, True, True), (128, 112, True, True),
        (128, 128, True, True),
    ],
    256: [
        (128, 48, True, True), (128, 64, True, True),
        (128, 80, True, True), (128, 96, True, True),
    ],
}

RS_OVERLAP_COMBOS = [
    (True, True, "RS+OL"),
    (True, False, "RS+noOL"),
    (False, True, "noRS+OL"),
    (False, False, "noRS+noOL"),
]

COMPARE_CONFIGS = [
    # (hdim, causal, (old_tile_m, old_tile_n, old_rs, old_ol), (new...))
    (64, False, (192, 128, True, True), (192, 128, True, True)),
    (64, True, (192, 128, True, True), (192, 128, True, True)),
    (96, False, (192, 96, True, True), (192, 144, False, True)),
    (96, True, (192, 96, True, True), (192, 128, False, True)),
]


def _get_default_bwd_config(headdim, causal=False):
    """Default SM90 backward config for a given headdim."""
    if headdim <= 128:
        return dict(
            m_block_size=64 if causal else 80, n_block_size=128,
            num_stages_Q=2, num_stages_dO=2,
            SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=not causal,
            AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1,
            num_threads=384,
        )
    elif headdim <= 192:
        return dict(
            m_block_size=64, n_block_size=96,
            num_stages_Q=1, num_stages_dO=1,
            SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=True,
            AtomLayoutMSdP=1, AtomLayoutNdKV=1, AtomLayoutMdQ=1,
            num_threads=512,
        )
    else:
        return dict(
            m_block_size=64, n_block_size=64,
            num_stages_Q=1, num_stages_dO=1,
            SdP_swapAB=False, dKV_swapAB=False, dQ_swapAB=False,
            AtomLayoutMSdP=1, AtomLayoutNdKV=1, AtomLayoutMdQ=1,
            num_threads=384,
        )


# Maps optimization name -> function(headdim, causal) -> dict[label, kwargs] or None
BWD_OPT_CONFIGS = {
    "V_in_regs": lambda hdim, causal: (
        None if hdim > 128 else {
            "baseline (V_in_regs=False)": {**_get_default_bwd_config(hdim, causal), "V_in_regs": False},
            "optimized (V_in_regs=True)": {**_get_default_bwd_config(hdim, causal), "V_in_regs": True},
        }
    ),
    "mma_dkv_is_rs": lambda hdim, causal: (
        None if hdim > 128 else {
            "baseline (AtomLayoutNdKV=1)": {**_get_default_bwd_config(hdim, causal), "AtomLayoutNdKV": 1},
            "optimized (AtomLayoutNdKV=2)": {**_get_default_bwd_config(hdim, causal), "AtomLayoutNdKV": 2},
        }
    ),
    "Q_dO_pipeline_sharing": lambda hdim, causal: (
        None if hdim > 128 else {
            "baseline (dO_stage=1, separate)": {**_get_default_bwd_config(hdim, causal), "num_stages_dO": 1},
            "optimized (dO_stage=2, shared)": {**_get_default_bwd_config(hdim, causal), "num_stages_dO": 2},
        }
    ),
    "tile_m": lambda hdim, causal: (
        None if hdim > 128 or causal else {
            "tile_m=64": {**_get_default_bwd_config(hdim, causal), "m_block_size": 64},
            "tile_m=80": {**_get_default_bwd_config(hdim, causal), "m_block_size": 80},
        }
    ),
}


# ── Run modes ─────────────────────────────────────────────────────────────

def run_default(args):
    """Standard fwd/bwd benchmark across hdims."""
    directions = [args.direction] if args.direction != "both" else ["fwd", "bwd"]
    for direction in directions:
        print(f"\n{'=' * 80}")
        print(f" SM90 {direction.upper()} (rep={args.rep})")
        print(f"{'=' * 80}")
        cols = f"{'hdim':>5} {'hdim_v':>6} {'causal':>6} {'batch':>5} {'seqlen':>6} {'ms':>8} {'TFLOPS':>8}"
        if direction == "fwd":
            cols += f" {'max_diff':>10}"
        print(cols)
        print("-" * 80)
        for hdim, hdim_v in args.hdim:
            nheads = nheads_for_hdim(hdim)
            for seqlen in args.seqlen:
                batch = auto_batch(seqlen, args.batch)
                for causal in get_causals(args):
                    if direction == "fwd":
                        ms, tflops, diff = bench_fwd(batch, seqlen, nheads, hdim, causal,
                                                     warmup=args.warmup, rep=args.rep, hdim_v=hdim_v)
                    else:
                        ms, tflops, diff = bench_bwd(batch, seqlen, nheads, hdim, causal,
                                                     warmup=args.warmup, rep=args.rep, hdim_v=hdim_v)
                    if ms is not None:
                        line = f"{hdim:>5} {hdim_v:>6} {str(causal):>6} {batch:>5} {seqlen:>6} {ms:>8.3f} {tflops:>8.1f}"
                        if diff is not None:
                            line += f" {diff:>10.6f}"
                        print(line)
                    else:
                        print(f"{hdim:>5} {hdim_v:>6} {str(causal):>6} {batch:>5} {seqlen:>6} {'FAIL':>8} {'':>8} {diff}")


def run_sweep_tiles(args):
    """Sweep tile sizes for fwd across seqlens."""
    seqlens = args.seqlen
    for hdim, hdim_v in args.hdim:
        nheads = nheads_for_hdim(hdim)
        configs = TILE_SWEEP_CONFIGS.get(hdim, [])
        if not configs:
            print(f"No tile sweep configs for hdim={hdim}, skipping")
            continue
        for causal in get_causals(args):
            header = f"{'hdim':>5} {'causal':>6} {'tile_m':>6} {'tile_n':>6} {'pv_rs':>5} {'ol':>5}"
            for sl in seqlens:
                header += f" {'s=' + str(sl):>8}"
            print(header)
            print("=" * len(header))
            for tile_m, tile_n, rs, ol in configs:
                row = f"{hdim:>5} {str(causal):>6} {tile_m:>6} {tile_n:>6} {str(rs):>5} {str(ol):>5}"
                for sl in seqlens:
                    batch = auto_batch(sl, args.batch)
                    ms, tflops, diff = bench_fwd(batch, sl, nheads, hdim, causal, tile_m, tile_n, rs, ol,
                                                 check_correctness=False, warmup=args.warmup,
                                                 rep=args.rep, hdim_v=hdim_v)
                    row += f" {tflops:>8.1f}" if tflops else f" {'FAIL':>8}"
                print(row)
            print()


def run_sweep_rs_overlap(args):
    """Sweep RS and intra-WG-overlap combinations for fwd."""
    seqlens = args.seqlen
    tile_for_hdim = {64: (192, 128), 96: (192, 128), 128: (128, 128)}
    for hdim, hdim_v in args.hdim:
        nheads = nheads_for_hdim(hdim)
        tile_m, tile_n = tile_for_hdim.get(hdim, (128, 128))
        for causal in get_causals(args):
            c_str = "causal" if causal else "non-causal"
            header = f"{'Config':<30} {'RS/OL':<12}"
            for sl in seqlens:
                header += f" {'s=' + str(sl):>8}"
            print(header)
            print("=" * len(header))
            for rs, ol, rs_label in RS_OVERLAP_COMBOS:
                label = f"hdim{hdim} {c_str} {tile_m}x{tile_n}"
                row = f"{label:<30} {rs_label:<12}"
                for sl in seqlens:
                    batch = auto_batch(sl, args.batch)
                    ms, tflops, diff = bench_fwd(batch, sl, nheads, hdim, causal, tile_m, tile_n, rs, ol,
                                                 check_correctness=False, warmup=args.warmup,
                                                 rep=args.rep, hdim_v=hdim_v)
                    row += f" {tflops:>8.1f}" if tflops else f" {'FAIL':>8}"
                print(row)
            print()


def run_compare_configs(args):
    """Compare old vs new tile configs for fwd."""
    seqlens = args.seqlen
    header = f"{'Config':<50}"
    for sl in seqlens:
        header += f" {'s=' + str(sl):>8}"
    print(header)
    print("=" * len(header))
    for hdim, causal, old, new in COMPARE_CONFIGS:
        nheads = nheads_for_hdim(hdim)
        c_str = "causal" if causal else "non-causal"
        for label_prefix, cfg in [("OLD", old), ("NEW", new)]:
            label = f"hdim{hdim} {c_str:<11} {label_prefix} {cfg[0]}x{cfg[1]} RS={cfg[2]} OL={cfg[3]}"
            row = f"{label:<50}"
            for sl in seqlens:
                batch = auto_batch(sl, args.batch)
                # cfg unpacks to (tile_m, tile_n, mma_pv_is_rs, intra_wg_overlap)
                ms, tflops, diff = bench_fwd(batch, sl, nheads, hdim, causal, *cfg,
                                             check_correctness=False, warmup=args.warmup, rep=args.rep)
                row += f" {tflops:>8.1f}" if tflops else f" {'FAIL':>8}"
            print(row)
        print("-" * len(header))


def run_sweep_bwd_opts(args):
    """Sweep backward kernel optimizations (V_in_regs, mma_dkv_is_rs, etc.)."""
    seqlens = args.seqlen
    for opt_name, get_configs_fn in BWD_OPT_CONFIGS.items():
        for causal in get_causals(args):
            c_str = "causal" if causal else "non-causal"
            has_any = False
            for hdim, hdim_v in args.hdim:
                configs = get_configs_fn(hdim, causal)
                if configs is None:
                    continue
                if not has_any:
                    print(f"\n{'=' * 70}")
                    print(f"BWD Optimization: {opt_name} ({c_str})")
                    print(f"{'=' * 70}")
                    has_any = True
                nheads = nheads_for_hdim(hdim)
                print(f"\n hdim={hdim}:")
                for sl in seqlens:
                    batch = auto_batch(sl, args.batch)
                    f = bwd_flops(batch, nheads, sl, hdim, causal, hdim_v=hdim_v)
                    if len(seqlens) > 1:
                        print(f" seqlen={sl}, batch={batch}:")
                    for label, kwargs in configs.items():
                        ms, tflops, err = bench_bwd(batch, sl, nheads, hdim, causal,
                                                    warmup=args.warmup, rep=args.rep,
                                                    hdim_v=hdim_v, **kwargs)
                        if ms is not None:
                            print(f" {label:40s}: {ms:6.2f} ms ({tflops:6.1f} TFLOPS)")
                        else:
                            print(f" {label:40s}: FAIL {err}")


# ── Main ──────────────────────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser(
        description="Unified SM90 attention benchmark",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument("--direction", choices=["fwd", "bwd", "both"], default="both",
                        help="Benchmark direction (default: both)")
    parser.add_argument("--hdim", type=parse_headdims, default=[(64, 64), (96, 96), (128, 128)],
                        help="Head dims, comma-separated. Each is hdim or hdim-hdim_v. E.g.
64,128,192-128") parser.add_argument("--seqlen", type=csv_ints, default=[8192], help="Sequence lengths, comma-separated (default: 8192)") parser.add_argument("--batch", type=int, default=0, help="Batch size (0 = auto ~32k tokens)") parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations (default: 5)") parser.add_argument("--rep", type=int, default=30, help="Repetitions per benchmark (default: 30)") parser.add_argument("--causal-only", action="store_true") parser.add_argument("--non-causal-only", action="store_true") mode = parser.add_mutually_exclusive_group() mode.add_argument("--sweep-tiles", action="store_true", help="Sweep fwd tile sizes") mode.add_argument("--sweep-rs-overlap", action="store_true", help="Sweep fwd RS/overlap combos") mode.add_argument("--compare-configs", action="store_true", help="Compare old vs new fwd tile configs") mode.add_argument("--sweep-bwd-opts", action="store_true", help="Sweep bwd optimizations (V_in_regs, mma_dkv_is_rs, etc.)") args = parser.parse_args() torch.manual_seed(0) if args.sweep_tiles: run_sweep_tiles(args) elif args.sweep_rs_overlap: run_sweep_rs_overlap(args) elif args.compare_configs: run_compare_configs(args) elif args.sweep_bwd_opts: run_sweep_bwd_opts(args) else: run_default(args) if __name__ == "__main__": main() ================================================ FILE: benchmarks/benchmark_alibi.py ================================================ # Copyright (c) 2024, Sanghun Cho, Tri Dao. 
import pickle import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from flash_attn.layers.rotary import apply_rotary_emb from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined from flash_attn import flash_attn_qkvpacked_func, flash_attn_func try: import xformers.ops as xops except ImportError: xops = None def generate_cos_sin(seqlen, rotary_dim, device, dtype): assert rotary_dim % 2 == 0 angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi cos = torch.cos(angle).to(dtype=dtype) sin = torch.sin(angle).to(dtype=dtype) return cos, sin def flash_rotary(q, k, v, cos, sin, causal=False): # corrected by @tridao comments q = apply_rotary_emb( q, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True ) k = apply_rotary_emb( k, cos, sin, seqlen_offsets=0, interleaved=False, inplace=True ) return flash_attn_func(q, k, v, causal=causal) def attn_bias_from_alibi_slopes( slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False ): batch, nheads = slopes.shape device = slopes.device slopes = rearrange(slopes, "b h -> b h 1 1") if causal: return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes else: row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) sk = ( seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") ) sq = ( seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") ) relative_pos = torch.abs(row_idx + sk - sq - col_idx) return -slopes * relative_pos.to(dtype=slopes.dtype) def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): assert mode in ["fwd", "bwd", "fwd_bwd"] f = 4 * batch * seqlen**2 * nheads * headdim // (2 if 
causal else 1) return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) def efficiency(flop, time): return (flop / time / 10**12) if not math.isnan(time) else 0.0 def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None): """ Arguments: q, k, v: (batch_size, seqlen, nheads, head_dim) dropout_p: float attn_bias: (batch_size, nheads, seqlen, seqlen) or (batch_size, nheads, 1, seqlen), broadcast over the query dimension Output: output: (batch_size, seqlen, nheads, head_dim) """ batch_size, seqlen, nheads, d = q.shape q = rearrange(q, 'b t h d -> (b h) t d') k = rearrange(k, 'b s h d -> (b h) d s') softmax_scale = 1.0 / math.sqrt(d) # Preallocate attn_weights for `baddbmm` if attn_bias is not None: scores = rearrange(attn_bias, 'b h t s -> (b h) t s') beta = 1.0 else: scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=q.dtype, device=q.device) # With beta=0, `baddbmm` ignores the uninitialized buffer (NaN/Inf in it are not propagated) beta = 0.0 scores = rearrange(torch.baddbmm(scores, q, k, beta=beta, alpha=softmax_scale), '(b h) t s -> b h t s', h=nheads) if causal: # "triu_tril_cuda_template" not implemented for 'BFloat16' # So we have to construct the mask in float causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) scores = scores + causal_mask.to(dtype=scores.dtype) attention = torch.softmax(scores, dim=-1) attention_drop = F.dropout(attention, dropout_p) output = torch.einsum('bhts,bshd->bthd', attention_drop , v) return output.to(dtype=q.dtype) def time_fwd_bwd(func, *args, **kwargs): time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs) return time_f[1].mean, time_b[1].mean repeats = 30 device = 'cuda' dtype = torch.float16 bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] causal_vals = [False, True] headdim_vals = [64, 128] dim = 2048 dropout_p = 0.0 methods = (["fa2_alibi", "torch"] + (["xformers"] if xops is not None else []) + ["sdpa"] + ["fa2_baseline"] + ["fa2_rotary"]) time_f = {} time_b = {} 
time_f_b = {} speed_f = {} speed_b = {} speed_f_b = {} for causal in causal_vals: for headdim in headdim_vals: for batch_size, seqlen in bs_seqlen_vals: config = (causal, headdim, batch_size, seqlen) nheads = dim // headdim q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)] # alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 alibi_slopes = torch.rand(1, nheads, device=device, dtype=torch.float32) * 0.3 attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal).to(dtype) attn_bias = repeat(attn_bias, "1 ... -> b ...", b=batch_size) f, b = time_fwd_bwd( flash_attn_func, q, k, v, dropout_p, causal=causal, # alibi_slopes=alibi_slopes, alibi_slopes=None, repeats=repeats, verbose=False ) time_f[config, "fa2_baseline"] = f time_b[config, "fa2_baseline"] = b q = q.detach().requires_grad_(True) k = k.detach().requires_grad_(True) v = v.detach().requires_grad_(True) f, b = time_fwd_bwd( flash_attn_func, q, k, v, dropout_p, causal=causal, alibi_slopes=rearrange(alibi_slopes, "1 h -> h"), # alibi_slopes=None, repeats=repeats, verbose=False ) time_f[config, "fa2_alibi"] = f time_b[config, "fa2_alibi"] = b try: q = q.detach().requires_grad_(True) k = k.detach().requires_grad_(True) v = v.detach().requires_grad_(True) f, b = time_fwd_bwd( attention_pytorch, q, k, v, dropout_p, causal=causal, attn_bias=attn_bias, repeats=repeats, verbose=False ) except: # Skip if OOM f, b = float('nan'), float('nan') time_f[config, "torch"] = f time_b[config, "torch"] = b # F.sdpa doesn't currently (torch 2.1) dispatch to flash-attn but just to be safe with torch.backends.cuda.sdp_kernel(enable_flash=False): q_pt = q.detach().requires_grad_(True).transpose(1, 2) k_pt = k.detach().requires_grad_(True).transpose(1, 2) v_pt = v.detach().requires_grad_(True).transpose(1, 2) f, b = time_fwd_bwd( F.scaled_dot_product_attention, q_pt, k_pt, v_pt, attn_mask=attn_bias, 
dropout_p=dropout_p, is_causal=causal, repeats=repeats, verbose=False ) time_f[config, "sdpa"] = f time_b[config, "sdpa"] = b if xops is not None: q = q.detach().requires_grad_(True) k = k.detach().requires_grad_(True) v = v.detach().requires_grad_(True) if causal: attn_bias_xops = xops.LowerTriangularMask().add_bias(attn_bias.expand(-1, -1, seqlen, -1).to(dtype=q.dtype)) # NotImplementedError: No operator found for `memory_efficient_attention_backward` with inputs: # `flshattB@v2.3.6` is not supported because: # attn_bias type is # `cutlassB` is not supported because: # attn_bias type is attn_bias_xops = attn_bias_xops.materialize((batch_size, nheads, seqlen, seqlen), dtype=q.dtype, device=device) else: attn_bias_xops = attn_bias.to(dtype=q.dtype) f, b = time_fwd_bwd( xops.memory_efficient_attention, q, k, v, attn_bias_xops, dropout_p, repeats=repeats, verbose=False ) time_f[config, "xformers"] = f time_b[config, "xformers"] = b q = q.detach().requires_grad_(True) k = k.detach().requires_grad_(True) v = v.detach().requires_grad_(True) cos, sin = generate_cos_sin(seqlen, headdim, device, dtype) f, b = time_fwd_bwd( flash_rotary, q, k, v, cos, sin, causal, repeats=repeats, verbose=False ) time_f[config, "fa2_rotary"] = f time_b[config, "fa2_rotary"] = b print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") csv_output = "" csv_output += f"{causal},{headdim},{batch_size},{seqlen}," for method in methods: time_f_b[config, method] = time_f[config, method] + time_b[config, method] speed_f[config, method] = efficiency( flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), time_f[config, method] ) speed_b[config, method] = efficiency( flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"), time_b[config, method] ) speed_f_b[config, method] = efficiency( flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"), time_f_b[config, method] ) print( f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, " 
f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, " f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" ) csv_output += f"{speed_f[config, method]:.2f},{speed_b[config, method]:.2f},{speed_f_b[config, method]:.2f}," print(csv_output) ================================================ FILE: benchmarks/benchmark_attn.py ================================================ import argparse import time import torch try: import cudnn except ImportError: cudnn = None from einops import rearrange from flash_attn.cute.bench_utils import ( flops, attention_ref, cudnn_fwd_setup, cudnn_bwd_setup, ) try: from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func except ImportError: flash_attn_func = None flash_attn_varlen_func = None try: from flash_attn.cute.interface import flash_attn_func as flash_attn_func_python from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_python except ImportError: flash_attn_func_python = None flash_attn_varlen_func_python = None try: from flash_attn_interface import flash_attn_func as flash_attn_func_v3 from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 except ImportError: flash_attn_func_v3 = None flash_attn_varlen_func_v3 = None if torch.cuda.get_device_capability()[0] != 9: flash_attn_func_v3 = None from triton.testing import do_bench # ── Autograd backward helper ──────────────────────────────────────────────── def _make_bwd_fn(fwd_fn, g, inputs): """Run fwd once, return a closure that benchmarks backward. Args: fwd_fn: zero-arg callable that runs the forward pass (with autograd). g: gradient tensor (b, seqlen, nheads, headdim_v). inputs: list of input tensors whose .grad should be cleared each iteration. 
""" out = fwd_fn() if isinstance(out, tuple): out = out[0] g_match = g.reshape(out.shape) if g.shape != out.shape else g # handle varlen: flatten (b, s, h, d) -> (b * s, h, d) to match the unpadded output def bwd_fn(): for x in inputs: x.grad = None out.backward(g_match, retain_graph=True) return bwd_fn # ── Backend definitions ───────────────────────────────────────────────────── # Each setup_* function takes a context dict and returns (fwd_fn, bwd_fn). # Either can be None if the backend doesn't support that direction for the # given config. fwd_fn / bwd_fn are zero-arg callables suitable for do_bench. def setup_standard(ctx): if ctx["dtype"] == torch.float8_e4m3fn: return None, None q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"] fwd_fn = lambda: attention_ref(q, k, v, causal=causal) bwd_fn = _make_bwd_fn(fwd_fn, g, [q, k, v]) if ctx["has_backward"] else None return fwd_fn, bwd_fn def setup_fa2(ctx): if flash_attn_func is None or ctx["dtype"] == torch.float8_e4m3fn: return None, None if ctx["headdim"] != ctx["headdim_v"]: return None, None q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"] dropout_p, window_size_fa, softcap = ctx["dropout_p"], ctx["window_size_fa"], ctx["softcap"] deterministic = ctx["deterministic"] if ctx["varlen"]: qu, ku, vu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"] csq, csk, sq, sk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"], ctx["seqlen_q"], ctx["seqlen"] fwd_fn = lambda: flash_attn_varlen_func(qu, ku, vu, csq, csk, sq, sk, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap) bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func(qu, ku, vu, csq, csk, sq, sk, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap, deterministic=deterministic), g, [qu, ku, vu]) if ctx["has_backward"] else None else: fwd_fn = lambda: flash_attn_func(q, k, v, dropout_p, causal=causal, window_size=window_size_fa, softcap=softcap) bwd_fn = _make_bwd_fn(lambda: flash_attn_func(q, k, v, dropout_p, causal=causal, window_size=window_size_fa, 
softcap=softcap, deterministic=deterministic), g, [q, k, v]) if ctx["has_backward"] else None return fwd_fn, bwd_fn def setup_cudnn(ctx): if cudnn is None or ctx["headdim"] > 256 or ctx["dtype"] == torch.float8_e4m3fn: return None, None q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"] window_size_left = ctx["window_size"][0] # cuDNN expects (batch, nheads, seqlen, headdim) layout qt, kt, vt, gt = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), g.transpose(1, 2) fwd_fn, o_gpu, lse_gpu = cudnn_fwd_setup(qt, kt, vt, causal=causal, window_size_left=window_size_left) bwd_fn = None if ctx["has_backward"]: fwd_fn() # populate o and lse for bwd graph bwd_fn = cudnn_bwd_setup(qt, kt, vt, o_gpu, gt, lse_gpu, causal=causal, window_size_left=window_size_left) return fwd_fn, bwd_fn def setup_fa3(ctx): if flash_attn_func_v3 is None: return None, None q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"] window_size_fa, softcap = ctx["window_size_fa"], ctx["softcap"] num_splits, pack_gqa, deterministic = ctx["num_splits"], ctx["pack_gqa"], ctx["deterministic"] k_use = ctx.get("k_paged", k) if ctx["page_size"] is not None else k v_use = ctx.get("v_paged", v) if ctx["page_size"] is not None else v if ctx["varlen"]: qu, ku, vu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"] csq, csk, sq, sk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"], ctx["seqlen_q"], ctx["seqlen"] fwd_fn = lambda: flash_attn_varlen_func_v3(qu, ku, vu, csq, csk, sq, sk, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) else: fwd_fn = lambda: flash_attn_func_v3(q, k_use, v_use, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) # FA3 bwd only supports headdim == headdim_v and non-fp8 bwd_fn = None if ctx["has_backward"] and ctx["dtype"] != torch.float8_e4m3fn and ctx["headdim"] == ctx["headdim_v"]: if ctx["varlen"]: bwd_fn = _make_bwd_fn(lambda: 
flash_attn_varlen_func_v3(qu, ku, vu, csq, csk, sq, sk, causal=causal, window_size=window_size_fa, softcap=softcap, deterministic=deterministic), g, [qu, ku, vu]) else: bwd_fn = _make_bwd_fn(lambda: flash_attn_func_v3(q, k, v, causal=causal, window_size=window_size_fa, softcap=softcap, deterministic=deterministic), g, [q, k, v]) return fwd_fn, bwd_fn def setup_fa4(ctx): if flash_attn_func_python is None: return None, None q, k, v, g, causal = ctx["q"], ctx["k"], ctx["v"], ctx["g"], ctx["causal"] window_size, softcap = ctx["window_size"], ctx["softcap"] pack_gqa, deterministic = ctx["pack_gqa"], ctx["deterministic"] sinks = ctx["sinks"] k_use = ctx.get("k_paged", k) if ctx["page_size"] is not None else k v_use = ctx.get("v_paged", v) if ctx["page_size"] is not None else v if ctx["varlen"]: qu = ctx["q_unpad"] ku = ctx.get("k_paged", ctx["k_unpad"]) if ctx["page_size"] is not None else ctx["k_unpad"] vu = ctx.get("v_paged", ctx["v_unpad"]) if ctx["page_size"] is not None else ctx["v_unpad"] csq, csk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"] pt = ctx["page_table"] fwd_fn = lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, page_table=pt, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa) else: fwd_fn = lambda: flash_attn_func_python(q, k_use, v_use, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa) bwd_fn = None if ctx["has_backward"] and ctx["dtype"] != torch.float8_e4m3fn: if ctx["varlen"]: qu, ku, vu = ctx["q_unpad"], ctx["k_unpad"], ctx["v_unpad"] csq, csk = ctx["cu_seqlens_q"], ctx["cu_seqlens_k"] bwd_fn = _make_bwd_fn(lambda: flash_attn_varlen_func_python(qu, ku, vu, csq, csk, causal=causal, softcap=softcap, deterministic=deterministic), g, [qu, ku, vu]) else: bwd_fn = _make_bwd_fn(lambda: flash_attn_func_python(q, k, v, causal=causal, softcap=softcap, deterministic=deterministic), g, [q, k, v]) return fwd_fn, bwd_fn # Ordered list of (display_name, cli_name, setup_fn) BACKENDS = [ ("Standard", "standard", setup_standard), ("FA2", 
"fa2", setup_fa2), ("cuDNN", "cudnn", setup_cudnn), ("FA3", "fa3", setup_fa3), ("FA4", "fa4", setup_fa4), ] def parse_int_k(s): """Parse an integer with optional k/K suffix, e.g. '8k' -> 8192.""" s = s.strip().lower() if s.endswith("k"): return int(s[:-1]) * 1024 return int(s) def csv_ints(s): """Parse comma-separated integers with optional k suffix, e.g. '512,1k,2k'.""" return [parse_int_k(x) for x in s.split(",")] def parse_headdims(s): """Parse comma-separated headdim specs. Each entry is hdim or hdim-hdim_v. Examples: '128' -> [(128, 128)] '192-128' -> [(192, 128)] '64,128,192' -> [(64, 64), (128, 128), (192, 192)] '64,128,192-128,192' -> [(64, 64), (128, 128), (192, 128), (192, 192)] """ result = [] for item in s.split(","): if "-" in item: parts = item.split("-") result.append((int(parts[0]), int(parts[1]))) else: hdim = int(item) result.append((hdim, hdim)) return result def csv_strs(s): """Parse comma-separated strings, e.g. 'fa3,fa4'.""" return [x.strip() for x in s.split(",")] def parse_args(): parser = argparse.ArgumentParser(description='Benchmark FlashAttention') parser.add_argument('--headdim', type=parse_headdims, default=[(128, 128)], help='Head dim(s), comma-separated. Each is hdim or hdim-hdim_v. E.g. 64,128,192-128') parser.add_argument('--fwd', action='store_true', help='Run forward only') parser.add_argument('--bwd', action='store_true', help='Run backward only') parser.add_argument('--varlen', action='store_true', default=False) parser.add_argument('--causal', type=str.lower, choices=['true', 'false', 'both'], default='both', help='Causal mode (default: both)') parser.add_argument('--seqlen', type=csv_ints, default=[8192], help='Sequence length(s), comma-separated. Supports k suffix, e.g. 
1k,2k,8k') parser.add_argument('--total-seqlen', type=parse_int_k, default='32k', help='Total sequence length for batch sizing (default: 32k)') parser.add_argument('--batch-size', type=int, default=None, help='Batch size (default: total_seqlen // seqlen)') parser.add_argument('--deterministic', action='store_true', default=False) parser.add_argument('--nheads', type=int, default=None, help='Number of Q heads (default: 32 for hdim<=64, 16 for hdim<=192, 8 for hdim>192)') parser.add_argument('--nheads-kv', type=int, default=None, help='Number of KV heads (default: nheads)') parser.add_argument('--gqa-ratio', type=int, default=None, help='GQA ratio (nheads // nheads_kv). Ignored if --nheads-kv is set.') parser.add_argument('--backend', type=csv_strs, default=['all'], help='Which backends to benchmark, comma-separated (choices: all,standard,fa2,fa3,fa4,cudnn)') parser.add_argument('--warmup', type=int, default=5, help='Warmup iterations (default: 5)') parser.add_argument('--rep', type=int, default=10, help='Repetitions per benchmark (default: 10)') return parser.parse_args() def main(): args = parse_args() headdim_pairs = args.headdim # list of (hdim, hdim_v) tuples # Parse fwd/bwd: if neither specified, do fwd only has_forward = args.fwd or not args.bwd has_backward = args.bwd # Parse causal if args.causal == 'true': causal_vals = [True] elif args.causal == 'false': causal_vals = [False] else: causal_vals = [False, True] seqlen_list = args.seqlen varlen = args.varlen # Filter backends to those requested and available enabled = set(args.backend) if 'all' in enabled: enabled = {cli for _, cli, _ in BACKENDS} active_backends = [(name, cli, fn) for name, cli, fn in BACKENDS if cli in enabled] # Parameters torch.manual_seed(0) dropout_p = 0.0 dtype = torch.bfloat16 dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype device = 'cuda' page_size = None softcap = 0.0 deterministic = args.deterministic warmup, rep = args.warmup, args.rep time_f = {} time_b = {} 
for headdim, headdim_v in headdim_pairs: nheads = args.nheads if args.nheads is not None else (32 if headdim <= 64 else 16 if headdim <= 192 else 8) if args.nheads_kv is not None: nheads_kv = args.nheads_kv elif args.gqa_ratio is not None: nheads_kv = nheads // args.gqa_ratio else: nheads_kv = nheads has_qv = headdim == 64 and headdim_v == 512 sinks = None num_splits = 0 window_size = (None, None) window_size_fa = (-1, -1) pack_gqa = None for seqlen in seqlen_list: batch_size = args.batch_size if args.batch_size is not None else max(1, args.total_seqlen // seqlen) seqlen_q = seqlen q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=has_backward) v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=has_backward) q, k, v = [x.detach().to(dtype).requires_grad_(has_backward) for x in [q, k, v]] g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) # Varlen tensors q_unpad = k_unpad = v_unpad = cu_seqlens_q = cu_seqlens_k = None if varlen: q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_(has_backward) for x in [q, k, v]] cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen if page_size is None else None # Paged KV tensors k_paged = v_paged = page_table = None if page_size is not None: assert seqlen % page_size == 0 k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]] page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), "(b s) -> b s", s=seqlen // page_size) for causal in causal_vals: cfg = (headdim, headdim_v, causal, seqlen, batch_size, nheads) # Build context dict 
shared by all backends ctx = dict( q=q, k=k, v=v, g=g, causal=causal, headdim=headdim, headdim_v=headdim_v, dtype=dtype, has_backward=has_backward, varlen=varlen, q_unpad=q_unpad, k_unpad=k_unpad, v_unpad=v_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, seqlen_q=seqlen_q, seqlen=seqlen, page_size=page_size, k_paged=k_paged, v_paged=v_paged, page_table=page_table, dropout_p=dropout_p, window_size=window_size, window_size_fa=window_size_fa, softcap=softcap, deterministic=deterministic, num_splits=num_splits, pack_gqa=pack_gqa, sinks=sinks, ) for display_name, cli_name, setup_fn in active_backends: fwd_fn, bwd_fn = setup_fn(ctx) if fwd_fn is not None and has_forward: time.sleep(1.0) print(f"Benchmarking {display_name} fwd, hdim={headdim}, seqlen={seqlen}, causal={causal}") ms = do_bench(fwd_fn, warmup=warmup, rep=rep) * 1e-3 time_f[cfg, display_name] = ms if bwd_fn is not None and has_backward: time.sleep(1.0) print(f"Benchmarking {display_name} bwd, hdim={headdim}, seqlen={seqlen}, causal={causal}") ms = do_bench(bwd_fn, warmup=warmup, rep=rep) * 1e-3 time_b[cfg, display_name] = ms # ── Print results table ────────────────────────────────────────────────── backend_names = [name for name, _, _ in BACKENDS] shown_backends = [b for b in backend_names if any(b == k[1] for k in list(time_f) + list(time_b))] if not shown_backends: return col_w = 16 for direction, times, flops_mult in [("FWD", time_f, 1.0), ("BWD", time_b, 2.5)]: if not times: continue configs = sorted(set(k[0] for k in times)) if not configs: continue header = f"{'hdim':>9} {'causal':>6} {'batch':>5} {'seqlen':>6}" for b in shown_backends: header += f" {b:>{col_w}}" print(f"\n{'=' * len(header)}") print(f" {direction} (ms / TFLOPS)") print(f"{'=' * len(header)}") print(header) print("-" * len(header)) for cfg in configs: headdim, headdim_v, causal, seqlen, batch_size, nheads = cfg nFLOPS = flops(batch_size, nheads, seqlen, seqlen, headdim, headdim_v, causal=causal) hdim_str = str(headdim) if 
headdim == headdim_v else f"{headdim}-{headdim_v}" row = f"{hdim_str:>9} {str(causal):>6} {batch_size:>5} {seqlen:>6}" for b in shown_backends: t = times.get((cfg, b)) if t is not None: tflops = flops_mult * nFLOPS / t * 1e-12 ms = t * 1e3 cell = f"{ms:.2f}/{tflops:.0f}" row += f" {cell:>{col_w}}" else: row += f" {'—':>{col_w}}" print(row) if __name__ == '__main__': main() ================================================ FILE: benchmarks/benchmark_causal.py ================================================ from functools import partial import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat # from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func # # from flash_attn.triton.fused_attention import attention as attention # from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func # from flash_attn.flash_attn_triton_og import attention as attention_og # from triton.ops.flash_attention import attention as attention_triton from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func def attention_pytorch(qkv, dropout_p=0.0, causal=True): """ Arguments: qkv: (batch_size, seqlen, 3, nheads, head_dim) dropout_p: float Output: output: (batch_size, seqlen, nheads, head_dim) """ batch_size, seqlen, _, nheads, d = qkv.shape q, k, v = qkv.unbind(dim=2) q = rearrange(q, 'b t h d -> (b h) t d') k = rearrange(k, 'b s h d -> (b h) d s') softmax_scale = 1.0 / math.sqrt(d) # Preallocate attn_weights for `baddbmm` scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), '(b h) t 
s -> b h t s', h=nheads) if causal: # "triu_tril_cuda_template" not implemented for 'BFloat16' # So we have to construct the mask in float causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) scores = scores + causal_mask.to(dtype=scores.dtype) attention = torch.softmax(scores, dim=-1) attention_drop = F.dropout(attention, dropout_p) output = torch.einsum('bhts,bshd->bthd', attention_drop , v) return output.to(dtype=qkv.dtype) torch.manual_seed(0) repeats = 30 batch_size = 8 seqlen = 2048 nheads = 12 headdim = 128 # nheads = 24 # headdim = 64 # batch_size = 64 # seqlen = 512 # nheads = 8 # headdim = 128 dropout_p = 0.0 causal = True dtype = torch.float16 device = 'cuda' qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True) cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device) qkv_unpad = rearrange(qkv, 'b s ... 
-> (b s) ...').detach().requires_grad_(True) # benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad, # cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention') # pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad, # cu_seqlens, seqlen, dropout_p, causal=causal, backward=True) benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2') pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False) # for dropout_p in [0.1, 0.0]: # for causal in [False, True]: # print(f"### {dropout_p = }, {causal = } ###") # pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True) # nheads_k = 2 # q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) # kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype, # requires_grad=True) # if fav2_kvpacked_func is not None: # benchmark_all(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, repeats=repeats, desc='Fav2') # pytorch_profiler(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True) # dropout_p = 0.0 # causal = False # benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, # repeats=repeats, desc='PyTorch Attention') # benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton') # pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True) # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, # requires_grad=True) for _ in range(3)] # benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG') # # pytorch_profiler(attention, q, k, v, 1.0, backward=True) # from src.ops.fftconv import fftconv_func # dim = nheads * headdim # u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True) # k = torch.randn(dim, seqlen, device=device, 
requires_grad=True) # D = torch.randn(dim, device=device, requires_grad=True) # benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv') # pytorch_profiler(fftconv_func, u, k, D, backward=True) # pytorch_profiler(torch.fft.rfft, u.float()) flops = 4 * batch_size * seqlen ** 2 * nheads * headdim ideal_a100_time = flops / 312 / 1e9 print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms") exit(0) def time_fwd_bwd(func, *args, **kwargs): time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs) return time_f[1].mean, time_b[1].mean bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] causal_vals = [False, True] headdim_vals = [64, 128] dim = 2048 dropout_p = 0.0 time_f = {} time_b = {} for causal in causal_vals: for headdim in headdim_vals: for batch_size, seqlen in bs_seqlen_vals: nheads = dim // headdim qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True) cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device) qkv_unpad = rearrange(qkv, 'b s ... 
-> (b s) ...').detach().requires_grad_(True) f, b = time_fwd_bwd( flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, verbose=False ) time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b qkv = qkv.detach().requires_grad_(True) f, b = time_fwd_bwd( fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False ) time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b # q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, # requires_grad=True) for _ in range(3)] # # Try both values of sequence_parallel and pick the faster one # f, b = time_fwd_bwd( # attention_triton, q, k, v, causal, headdim**(-0.5), # False, repeats=repeats, verbose=False # ) # _, b0 = time_fwd_bwd( # attention_triton, q, k, v, causal, headdim**(-0.5), # True, repeats=repeats, verbose=False # ) # time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f # time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0) if seqlen <= 8 * 1024: qkv = qkv.detach().requires_grad_(True) f, b = time_fwd_bwd( attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False ) else: f, b = float('nan'), float('nan') time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b # q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, # requires_grad=True) for _ in range(3)] # import xformers.ops as xops # f, b = time_fwd_bwd( # xops.memory_efficient_attention, q, k, v, # attn_bias=xops.LowerTriangularMask() if causal else None, # op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp) # ) # time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f # time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b import pickle with 
open('flash2_attn_time_h100.plk', 'wb') as fp: pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL) ================================================ FILE: benchmarks/benchmark_flash_attention.py ================================================ # Install the newest triton version with # pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python" import pickle import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined from flash_attn import flash_attn_qkvpacked_func try: from triton.ops.flash_attention import attention as attention_triton except ImportError: attention_triton = None try: import xformers.ops as xops except ImportError: xops = None def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): assert mode in ["fwd", "bwd", "fwd_bwd"] f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) def efficiency(flop, time): return (flop / time / 10**12) if not math.isnan(time) else 0.0 def attention_pytorch(qkv, dropout_p=0.0, causal=True): """ Arguments: qkv: (batch_size, seqlen, 3, nheads, head_dim) dropout_p: float Output: output: (batch_size, seqlen, nheads, head_dim) """ batch_size, seqlen, _, nheads, d = qkv.shape q, k, v = qkv.unbind(dim=2) q = rearrange(q, 'b t h d -> (b h) t d') k = rearrange(k, 'b s h d -> (b h) d s') softmax_scale = 1.0 / math.sqrt(d) # Preallocate attn_weights for `baddbmm` scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), '(b h) t s -> b h t s', h=nheads) if causal: # "triu_tril_cuda_template" not implemented for 'BFloat16' # So we have to construct the mask in float 
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) # Adding is faster than masked_fill_ scores = scores + causal_mask.to(dtype=scores.dtype) attention = torch.softmax(scores, dim=-1) attention_drop = F.dropout(attention, dropout_p) output = torch.einsum('bhts,bshd->bthd', attention_drop , v) return output.to(dtype=qkv.dtype) def time_fwd_bwd(func, *args, **kwargs): time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs) return time_f[1].mean, time_b[1].mean repeats = 30 device = 'cuda' dtype = torch.float16 bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] causal_vals = [False, True] headdim_vals = [64, 128] dim = 2048 dropout_p = 0.0 methods = (["Flash2", "Pytorch"] + (["Triton"] if attention_triton is not None else []) + (["xformers.c"] if xops is not None else []) + (["xformers.f"] if xops is not None else [])) time_f = {} time_b = {} time_f_b = {} speed_f = {} speed_b = {} speed_f_b = {} for causal in causal_vals: for headdim in headdim_vals: for batch_size, seqlen in bs_seqlen_vals: config = (causal, headdim, batch_size, seqlen) nheads = dim // headdim # FlashAttention 2 if "Flash2" in methods: qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True) f, b = time_fwd_bwd( flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False ) time_f[config, "Flash2"] = f time_b[config, "Flash2"] = b # PyTorch baseline if "Pytorch" in methods: try: # fresh tensor avoids grad-history reuse issues qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True) f, b = time_fwd_bwd( attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False ) except Exception: f, b = float('nan'), float('nan') time_f[config, "Pytorch"] = f time_b[config, "Pytorch"] = b # Triton if "Triton" in methods and attention_triton is not None: q, k, v = [torch.randn(batch_size, nheads, 
seqlen, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)] # Try both values of sequence_parallel and pick the faster backward try: f, b = time_fwd_bwd( attention_triton, q, k, v, causal, headdim**(-0.5), False, repeats=repeats, verbose=False ) except Exception: f, b = float('nan'), float('inf') try: _, b0 = time_fwd_bwd( attention_triton, q, k, v, causal, headdim**(-0.5), True, repeats=repeats, verbose=False ) except Exception: b0 = float('inf') time_f[config, "Triton"] = f time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan') # xFormers CUTLASS if "xformers.c" in methods and xops is not None: q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)] f, b = time_fwd_bwd( xops.memory_efficient_attention, q, k, v, attn_bias=xops.LowerTriangularMask() if causal else None, op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp) ) time_f[config, "xformers.c"] = f time_b[config, "xformers.c"] = b # xFormers Flash if "xformers.f" in methods and xops is not None: q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)] f, b = time_fwd_bwd( xops.memory_efficient_attention, q, k, v, attn_bias=xops.LowerTriangularMask() if causal else None, op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp) ) time_f[config, "xformers.f"] = f time_b[config, "xformers.f"] = b # Report print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") for method in methods: if (config, method) not in time_f or (config, method) not in time_b: continue time_f_b[config, method] = time_f[config, method] + time_b[config, method] speed_f[config, method] = efficiency( flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), time_f[config, method] ) speed_b[config, method] = efficiency( flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"), time_b[config, method] ) 
speed_f_b[config, method] = efficiency( flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"), time_f_b[config, method] ) print( f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, " f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, " f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s" ) # with open('flash2_attn_time.plk', 'wb') as fp: # pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) ================================================ FILE: benchmarks/benchmark_gemm.py ================================================ import time import torch import torch.utils.benchmark as benchmark from triton.testing import do_bench if torch.version.cuda: backendBLAS = "cuBLAS" elif torch.version.hip: backendBLAS = "hipBLAS" def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs): """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" if verbose: print(desc, '- Forward pass') t = benchmark.Timer( stmt='fn(*inputs, **kwinputs)', globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs}, num_threads=torch.get_num_threads(), ) m = t.timeit(repeats) if verbose: print(m) return t, m torch.manual_seed(0) repeats = 30 dtype = torch.bfloat16 device = 'cuda' verbose = False m, n = 8192, 8192 tflops_matmul = {} tflops_matmul1 = {} for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]: a = torch.randn(m, k, device=device, dtype=dtype) b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2) nFLOPS_matmul = 2 * m * n * k time.sleep(2) # to reduce power throttling timing = benchmark_forward(torch.matmul, a, b, desc=backendBLAS, verbose=verbose, repeats=repeats)[1] tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12 print(f'[torch.utils.benchmark] {backendBLAS}, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS') time.sleep(2) # to reduce power throttling ms = do_bench(lambda: torch.matmul(a, b), warmup=10, 
rep=repeats) tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9 print(f'[triton.test.do_bench] {backendBLAS}, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS') ================================================ FILE: csrc/flash_attn/flash_api.cpp ================================================ /****************************************************************************** * Copyright (c) 2024, Tri Dao. ******************************************************************************/ // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. #include #include #include #include #include // For at::Generator and at::PhiloxCudaState #include "philox_unpack.cuh" // For at::cuda::philox::unpack #include #include "namespace_config.h" #include "hardware_info.h" #include "flash.h" #include "static_switch.h" #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") namespace FLASH_NAMESPACE { void set_params_fprop(Flash_fwd_params &params, // sizes const size_t b, const size_t seqlen_q, const size_t seqlen_k, const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, // device pointers const at::Tensor q, const at::Tensor k, const at::Tensor v, at::Tensor out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *seqused_k, void *p_d, void *softmax_lse_d, float p_dropout, float softmax_scale, int window_size_left, int window_size_right, const float softcap, bool seqlenq_ngroups_swapped=false, const bool unpadded_lse=false) { // Reset the parameters params = {}; params.is_bf16 = q.dtype() == torch::kBFloat16; // Set the pointers and strides. 
params.q_ptr = q.data_ptr(); params.k_ptr = k.data_ptr(); params.v_ptr = v.data_ptr(); // All strides are in elements, not bytes. params.q_row_stride = q.stride(-3); params.k_row_stride = k.stride(-3); params.v_row_stride = v.stride(-3); params.q_head_stride = q.stride(-2); params.k_head_stride = k.stride(-2); params.v_head_stride = v.stride(-2); params.o_ptr = out.data_ptr(); params.o_row_stride = out.stride(-3); params.o_head_stride = out.stride(-2); if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = q.stride(0); params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); params.o_batch_stride = out.stride(0); if (seqlenq_ngroups_swapped) { params.q_batch_stride *= seqlen_q; params.o_batch_stride *= seqlen_q; } } params.cu_seqlens_q = static_cast(cu_seqlens_q_d); params.cu_seqlens_k = static_cast(cu_seqlens_k_d); params.seqused_k = static_cast(seqused_k); // P = softmax(QK^T) params.p_ptr = p_d; // Softmax sum params.softmax_lse_ptr = softmax_lse_d; // Set the dimensions. params.b = b; params.h = h; params.h_k = h_k; params.h_h_k_ratio = h / h_k; params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; params.seqlen_q_rounded = seqlen_q_rounded; params.seqlen_k_rounded = seqlen_k_rounded; params.d = d; params.d_rounded = d_rounded; // Set the different scale values. #ifdef FLASHATTENTION_DISABLE_SOFTCAP TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap."); #endif if (softcap > 0.0) { params.softcap = softmax_scale / softcap; params.scale_softmax = softcap; params.scale_softmax_log2 = softcap * M_LOG2E; } else{ // Remove potential NaN params.softcap = 0.0; params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; } // Set this to probability of keeping an element to simplify things. params.p_dropout = 1.f - p_dropout; // Convert p from float to int so we don't have to convert the random uint to float to compare. 
// [Minor] We want to round down since when we do the comparison we use <= instead of < // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); params.rp_dropout = 1.f / params.p_dropout; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; TORCH_CHECK(p_dropout < 1.f); #ifdef FLASHATTENTION_DISABLE_DROPOUT TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); #endif // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. params.is_causal = window_size_left < 0 && window_size_right == 0; if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } params.window_size_left = window_size_left; params.window_size_right = window_size_right; #ifdef FLASHATTENTION_DISABLE_LOCAL TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), "This flash attention build does not support local attention."); #endif params.is_seqlens_k_cumulative = true; #ifdef FLASHATTENTION_DISABLE_UNEVEN_K TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); #endif params.unpadded_lse = unpadded_lse; params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; } void set_params_dgrad(Flash_bwd_params &params, // sizes const size_t b, const size_t seqlen_q, const size_t seqlen_k, const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, // device pointers const at::Tensor q, const at::Tensor k, const at::Tensor v, const at::Tensor out, const at::Tensor dout, at::Tensor dq, 
at::Tensor dk, at::Tensor dv, void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *dq_accum_d, void *dk_accum_d, void *dv_accum_d, void *softmax_lse_d, void *dsoftmax_sum_d, float p_dropout, float softmax_scale, int window_size_left, int window_size_right, const float softcap, bool deterministic, const bool unpadded_lse) { set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, q, k, v, out, cu_seqlens_q_d, cu_seqlens_k_d, nullptr, nullptr, softmax_lse_d, p_dropout, softmax_scale, window_size_left, window_size_right, softcap, false, // seqlenq_ngroups_swapped unpadded_lse); // Set the pointers and strides. params.do_ptr = dout.data_ptr(); params.do_row_stride = dout.stride(-3); params.do_head_stride = dout.stride(-2); params.dq_ptr = dq.data_ptr(); params.dk_ptr = dk.data_ptr(); params.dv_ptr = dv.data_ptr(); params.dq_row_stride = dq.stride(-3); params.dk_row_stride = dk.stride(-3); params.dv_row_stride = dv.stride(-3); params.dq_head_stride = dq.stride(-2); params.dk_head_stride = dk.stride(-2); params.dv_head_stride = dv.stride(-2); if (cu_seqlens_q_d == nullptr) { params.do_batch_stride = dout.stride(0); params.dq_batch_stride = dq.stride(0); params.dk_batch_stride = dk.stride(0); params.dv_batch_stride = dv.stride(0); } params.dq_accum_ptr = dq_accum_d; params.dk_accum_ptr = dk_accum_d; params.dv_accum_ptr = dv_accum_d; // Softmax sum params.dsoftmax_sum = dsoftmax_sum_d; params.deterministic = deterministic; } void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 run_mha_fwd_(params, stream); } else { run_mha_fwd_splitkv_dispatch(params, stream); } }); }); }); } // Find the number of splits that maximizes the occupancy. 
For example, if we have // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is // better than having 3 splits (efficiency = 0.67). However, we also don't want too many // splits as that would incur more HBM reads/writes. // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } max_splits = std::min({max_splits, num_SMs, num_n_blocks}); float max_efficiency = 0.f; std::vector efficiency; efficiency.reserve(max_splits); auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks // (i.e. it's 11 splits anyway). // So we check if the number of blocks per split is the same as the previous num_splits. 
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); }; for (int num_splits = 1; num_splits <= max_splits; num_splits++) { if (!is_split_eligible(num_splits)) { efficiency.push_back(0.f); } else { float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; float eff = n_waves / ceil(n_waves); // printf("num_splits = %d, eff = %f\n", num_splits, eff); if (eff > max_efficiency) { max_efficiency = eff; } efficiency.push_back(eff); } } for (int num_splits = 1; num_splits <= max_splits; num_splits++) { if (!is_split_eligible(num_splits)) { continue; } if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { // printf("num_splits chosen = %d\n", num_splits); return num_splits; } } return 1; } std::tuple set_params_splitkv(Flash_fwd_params &params, const int batch_size, const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q, const int head_size_rounded, const float p_dropout, const int num_splits, const int num_sm, struct c10::TensorOptions opts) { // This needs to match with run_mha_fwd_splitkv_dispatch const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; params.num_splits = num_splits; at::Tensor softmax_lse_accum; at::Tensor out_accum; if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout if (num_splits < 1) { // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. 
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128); } if (params.num_splits > 1) { softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_ptr = out_accum.data_ptr(); } TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); } return std::make_tuple(softmax_lse_accum, out_accum); } void set_params_alibi(Flash_fwd_params &params, std::optional &alibi_slopes_, int batch_size, int num_heads){ #ifdef FLASHATTENTION_DISABLE_ALIBI TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi."); params.alibi_slopes_ptr = nullptr; #else if (alibi_slopes_.has_value()) { auto alibi_slopes = alibi_slopes_.value(); TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32"); CHECK_DEVICE(alibi_slopes); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads})); params.alibi_slopes_ptr = alibi_slopes.data_ptr(); params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? 
alibi_slopes.stride(0) : 0; } else { params.alibi_slopes_ptr = nullptr; } #endif } std::vector mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) std::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, const float softmax_scale, bool is_causal, int window_size_left, int window_size_right, const float softcap, const bool return_softmax, std::optional gen_) { // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x_min = cc_major >= 8; TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); const auto sizes = q.sizes(); const int batch_size = sizes[0]; int seqlen_q = sizes[1]; int num_heads = sizes[2]; const int head_size = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); 
TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } if (window_size_left >= seqlen_k) { window_size_left = -1; } if (window_size_right >= seqlen_k) { window_size_right = -1; } // causal=true is the same as causal=false in this case if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } if (is_causal) { window_size_right = 0; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); seqlen_q = ngroups; num_heads = num_heads_k; } CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); at::Tensor out; if (out_.has_value()) { out = out_.value(); TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size); if (seqlenq_ngroups_swapped) { out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); } } else { out = torch::empty_like(q); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 
32 : 64); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); auto opts = q.options(); auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor p; // Only return softmax if there's dropout to reduce compilation time if (return_softmax) { TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); } else { p = torch::empty({ 0 }, opts); } Flash_fwd_params params; set_params_fprop(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, q, k, v, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, /*seqused_k=*/nullptr, return_softmax ? p.data_ptr() : nullptr, softmax_lse.data_ptr(), p_dropout, softmax_scale, window_size_left, window_size_right, softcap ); // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts); // number of times random will be generated per thread, to offset philox counter in thc random // state // We use a custom RNG that increases the offset by batch_size * nheads * 32. int64_t counter_offset = params.b * params.h * 32; auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); // Forward kernel will populate memory with the seed and offset. 
params.rng_state = reinterpret_cast(rng_state.data_ptr()); if (p_dropout > 0.0) { auto gen = at::get_generator_or_default( gen_, at::cuda::detail::getDefaultCUDAGenerator()); // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); params.philox_args = gen->philox_cuda_state(counter_offset); } set_params_alibi(params, alibi_slopes_, batch_size, num_heads); if (seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream); } else { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. out.zero_(); softmax_lse.fill_(std::numeric_limits::infinity()); } if (seqlenq_ngroups_swapped) { out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); } return {out, softmax_lse, p, rng_state}; } std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. std::optional &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
               std::optional<const at::Tensor> &leftpad_k_, // batch_size
               std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool return_softmax,
               std::optional<at::Generator> gen_) {

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(cu_seqlens_q);
    CHECK_DEVICE(cu_seqlens_k);

    at::Tensor block_table;
    const bool paged_KV = block_table_.has_value();
    if (paged_KV) {
        block_table = block_table_.value();
        CHECK_DEVICE(block_table);
        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    }

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    const auto sizes = q.sizes();

    const int batch_size = cu_seqlens_q.numel() - 1;
    int num_heads = sizes[1];
    const int head_size = sizes[2];
    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
    const int num_blocks = !paged_KV ? 0 : k.size(0);
    const int page_block_size = !paged_KV ? 1 : k.size(1);
    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");

    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
    if (is_causal) { window_size_right = 0; }

    void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
    // H/t Daniel Haziza
    const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
    const int ngroups = num_heads / num_heads_k;
    if (seqlenq_ngroups_swapped) {
        q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
        max_seqlen_q = ngroups;
        num_heads = num_heads_k;
        cu_seqlens_q_d = nullptr;
    }

    const int total_q = q.sizes()[0];

    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

    CHECK_SHAPE(q, total_q, num_heads, head_size);
    if (!paged_KV) {
        const int total_k = k.size(0);
        CHECK_SHAPE(k, total_k, num_heads_k, head_size);
        CHECK_SHAPE(v, total_k, num_heads_k, head_size);
    } else {
        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
    }

    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
    if (seqused_k.has_value()){
        auto seqused_k_ = seqused_k.value();
        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
        CHECK_SHAPE(seqused_k_, batch_size);
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, sizes[0], sizes[1], head_size);
        if (seqlenq_ngroups_swapped) {
            out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
        }
    } else {
        out = torch::empty_like(q);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    auto opts = q.options();
    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
    at::Tensor p;
    // Only return softmax if there's dropout to reduce compilation time
    if (return_softmax) {
        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
    } else {
        p = torch::empty({ 0 }, opts);
    }

    if (zero_tensors) {
        out.zero_();
        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
        if (return_softmax) {p.zero_();}
    }

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k.data_ptr(),
                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                     return_softmax ? p.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     seqlenq_ngroups_swapped,
                     /*unpadded_lse*/true);
    params.total_q = total_q;

    if (paged_KV) {
        params.block_table = block_table.data_ptr<int>();
        params.block_table_batch_stride = block_table.stride(0);
        params.k_batch_stride = k.stride(0);
        params.v_batch_stride = v.stride(0);
    }
    params.page_block_size = page_block_size;
    // Keep references to these tensors to extend their lifetime
    at::Tensor softmax_lse_accum, out_accum;
    if (seqlenq_ngroups_swapped) {
        // Only apply split-k for decoding
        std::tie(softmax_lse_accum, out_accum) =
            set_params_splitkv(params, batch_size, num_heads, head_size,
                               max_seqlen_k, max_seqlen_q, head_size_rounded,
                               p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
    }

    if (leftpad_k_.has_value()) {
        auto leftpad_k = leftpad_k_.value();
        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
        CHECK_DEVICE(leftpad_k);
        CHECK_CONTIGUOUS(leftpad_k);
        CHECK_SHAPE(leftpad_k, batch_size);
        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
    }

    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
    // Forward kernel will populate memory with the seed and offset.
    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

    if (p_dropout > 0.0) {
        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
            gen_, at::cuda::detail::getDefaultCUDAGenerator());
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

    if (max_seqlen_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd(params, stream, paged_KV);
    } else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    if (seqlenq_ngroups_swapped) {
        int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
        int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
        out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
        q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
        softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
    }

    return {out, softmax_lse, p, rng_state};
}

void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
                run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
            });
        });
    });
}

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x multiple_of(head_size_og, 8)
        const at::Tensor &q,   // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,   // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,   // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,   // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,     // b x h x seqlen_q
        std::optional<at::Tensor> &dq_,   // batch_size x seqlen_q x num_heads x head_size
        std::optional<at::Tensor> &dk_,   // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &dv_,   // batch_size x seqlen_k x num_heads_k x head_size
        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,         // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        std::optional<at::Generator> gen_,
        std::optional<at::Tensor> &rng_state) {

    #ifdef FLASHATTENTION_DISABLE_BACKWARD
        TORCH_CHECK(false, "This flash attention build does not support backward.");
    #endif
    if (is_causal) { window_size_right = 0; }

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    // bool loop = seqlen_k > blocksize_c;
    // TODO: change later, for now set to true for simplicity
    bool loop = true;

    auto opts = q.options();
    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    at::Tensor dq_accum;
    at::Tensor dk_accum, dv_accum;
    if (loop) {
        if (!deterministic) {
            dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
        } else {
            const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
            dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
        }
        // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
        // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    }

    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) {  // MQA / GQA
        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }

    Flash_bwd_params params;

    set_params_dgrad(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout, dq, dk_expanded, dv_expanded,
                     nullptr,
                     nullptr,
                     loop ? dq_accum.data_ptr() : nullptr,
                     // loop ? dk_accum.data_ptr() : nullptr,
                     // loop ? dv_accum.data_ptr() : nullptr,
                     nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     deterministic,
                     /*unpadded_lse*/false);
    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

    auto launch = &run_mha_bwd;

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
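    /*
    Illustrative sketch, not part of this file's build. The comment above says each call advances
    the RNG offset by batch_size * nheads * 32; the toy model below shows what that buys. It assumes
    a splitmix64-style counter hash standing in for Philox (it is NOT the generator used here) and
    collapses Philox's per-thread (subsequence, offset) addressing into a single counter.
    ToyGenerator, toy_hash and dropout_mask are hypothetical names. Each call reserves a disjoint
    counter range, and replaying a saved (seed, offset) pair, which is what rng_state carries,
    regenerates the forward dropout mask exactly.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Toy counter-based hash (splitmix64 finalizer); a stand-in for Philox, for illustration only.
inline uint64_t toy_hash(uint64_t seed, uint64_t counter) {
    uint64_t z = seed + (counter + 1) * 0x9E3779B97F4A7C15ULL;
    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL;
    z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL;
    return z ^ (z >> 31);
}

// Hypothetical generator that hands each caller a disjoint range of counters.
struct ToyGenerator {
    uint64_t seed;
    uint64_t offset;
    explicit ToyGenerator(uint64_t s) : seed(s), offset(0) {}

    // Reserve `increment` counters; return (seed, first reserved offset).
    std::pair<uint64_t, uint64_t> reserve(uint64_t increment) {
        std::pair<uint64_t, uint64_t> state{seed, offset};
        offset += increment;
        return state;
    }
};

// Keep-mask for n elements: element i draws from counter (offset + i).
std::vector<bool> dropout_mask(uint64_t seed, uint64_t offset, int n, double p_drop) {
    assert(n >= 0 && p_drop >= 0.0 && p_drop < 1.0);
    std::vector<bool> keep(static_cast<std::size_t>(n));
    for (int i = 0; i < n; ++i) {
        // Top 53 bits of the hash mapped to a double in [0, 1).
        const uint64_t bits = toy_hash(seed, offset + static_cast<uint64_t>(i)) >> 11;
        const double u = static_cast<double>(bits) / 9007199254740992.0;
        keep[static_cast<std::size_t>(i)] = (u >= p_drop);
    }
    return keep;
}
```

    In this model the mask is a pure function of (seed, offset, element index), so the two int64
    values held in rng_state are enough for the backward pass to rebuild it.
    */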
    int64_t counter_offset = params.b * params.h * 32;

    if ( rng_state.has_value() ) {
        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
    } else if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
        auto seeds = at::cuda::philox::unpack(params.philox_args);
        params.rng_state[0] = std::get<0>(seeds);
        params.rng_state[1] = std::get<1>(seeds);
    }

    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

    if (seqlen_q > 0) {
        launch(params, stream);
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
    }

    return { dq, dk, dv, softmax_d };
}

std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
               const at::Tensor &q,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,   // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,    // h x total_q, softmax logsumexp
               std::optional<at::Tensor> &dq_,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &dk_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &dv_,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int max_seqlen_q,
               const int max_seqlen_k,          // max sequence length to choose the kernel
               const float p_dropout,         // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool deterministic,
               std::optional<at::Generator> gen_,
               std::optional<at::Tensor> &rng_state) {

    #ifdef FLASHATTENTION_DISABLE_BACKWARD
        TORCH_CHECK(false, "This flash attention build does not support backward.");
    #endif
    if (is_causal) { window_size_right = 0; }

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    const auto sizes = q.sizes();

    const int total_q = sizes[0];
    const int batch_size = cu_seqlens_q.numel() - 1;
    const int num_heads = sizes[1];
    const int head_size = sizes[2];
    const int total_k = k.size(0);
    const int num_heads_k = k.size(1);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
    if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

    CHECK_SHAPE(q, total_q, num_heads, head_size);
    CHECK_SHAPE(k, total_k, num_heads_k, head_size);
    CHECK_SHAPE(v, total_k, num_heads_k, head_size);
    CHECK_SHAPE(out, total_q, num_heads, head_size);
    CHECK_SHAPE(dout, total_q, num_heads, head_size);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, total_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    // bool loop = max_seqlen_k > blocksize_c;
    // TODO: change later, for now set to true for simplicity
    bool loop = true;

    auto opts = q.options();
    auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
    at::Tensor dq_accum;
    if (loop) {
        // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
        // because that would be too large if there is a very long sequence and the rest of the sequences are short.
        // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
        // Note that 128 is the max block size on the seqlen_q dimension.
        // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
        // cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
        // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
        // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
        // Same holds for softmax_d, since LSE is stored in unpadded format.
        if (!deterministic) {
            dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
        } else {
            const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
            dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
        }
    }

    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) {  // MQA / GQA
        dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
        dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }

    if( zero_tensors ) {
        dq.zero_();
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    Flash_bwd_params params;

    set_params_dgrad(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout, dq, dk_expanded, dv_expanded,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     loop ? dq_accum.data_ptr() : nullptr,
                     nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     deterministic,
                     /*unpadded_lse*/true);
    params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
    params.total_q = total_q;

    auto launch = &run_mha_bwd;

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;

    if ( rng_state.has_value() ) {
        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
    } else if( is_dropout ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
        auto seeds = at::cuda::philox::unpack(params.philox_args);
        params.rng_state[0] = std::get<0>(seeds);
        params.rng_state[1] = std::get<1>(seeds);
    }

    set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

    if (max_seqlen_q > 0) {
        launch(params, stream);
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
        at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
    }

    return { dq, dk, dv, softmax_d };
}

std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q,                 // batch_size x seqlen_q x num_heads x head_size
                const at::Tensor &kcache,            // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                const at::Tensor &vcache,            // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
                std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
                std::optional<const at::Tensor> &seqlens_k_, // batch_size
                std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
                std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
                std::optional<const at::Tensor> &leftpad_k_, // batch_size
                std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                std::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                const float softmax_scale,
                bool is_causal,
                int window_size_left,
                int window_size_right,
                const float softcap,
                bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                int num_splits
                ) {

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x_min = cc_major >= 8;
    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    at::Tensor block_table;
    const bool paged_KV = block_table_.has_value();
    if (paged_KV) {
        TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
        block_table = block_table_.value();
        CHECK_DEVICE(block_table);
        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    }

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    int seqlen_q = sizes[1];
    int num_heads = sizes[2];
    const int head_size_og = sizes[3];

    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
    const int num_blocks = !paged_KV ? 0 : kcache.size(0);
    const int page_block_size = !paged_KV ? 1 : kcache.size(1);
    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
    const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
    const int num_heads_k = kcache.size(2);
    const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // causal=true is the same as causal=false in this case
    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
    if (is_causal) { window_size_right = 0; }

    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
    // H/t Daniel Haziza
    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
    if (seqlenq_ngroups_swapped) {
        const int ngroups = num_heads / num_heads_k;
        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
        seqlen_q = ngroups;
        num_heads = num_heads_k;
    }

    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    if (!paged_KV) {
        CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
        CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
    } else {
        CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
        CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
    }

    at::Tensor q_padded, kcache_padded, vcache_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        kcache_padded = kcache;
        vcache_padded = vcache;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    auto opts = q.options();

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, kcache_padded, vcache_padded, out,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     /*seqused_k=*/nullptr,
                     /*p_d=*/nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap
                     );

    at::Tensor k, v, k_padded, v_padded;
    if (k_.has_value()) {
        TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
        TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
        TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
        k = k_.value();
        v = v_.value();
        TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
        TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
        CHECK_DEVICE(k); CHECK_DEVICE(v);
        TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
        TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
        int seqlen_knew = k.size(1);
        CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
        CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
        if (head_size_og % 8 != 0) {
            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        } else {
            k_padded = k;
            v_padded = v;
        }
        params.seqlen_knew = seqlen_knew;
        params.knew_ptr = k_padded.data_ptr();
        params.vnew_ptr = v_padded.data_ptr();
        // All strides are in elements, not bytes.
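        /*
        Illustrative sketch, not part of this file's build: what "strides in elements" means for
        the knew_* / vnew_* fields set next. It assumes a row-major (batch, seqlen_knew, nheads_k,
        headdim) view, so stride(0) is the batch stride, stride(-3) the row (sequence) stride and
        stride(-2) the head stride; the unit-stride last dimension is guaranteed by the
        stride(-1) == 1 checks above. contiguous_strides and element_offset are hypothetical helpers.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Row-major strides, in elements, for a tensor of shape (batch, seqlen, nheads, headdim).
std::array<int64_t, 4> contiguous_strides(const std::array<int64_t, 4>& shape) {
    std::array<int64_t, 4> stride{};
    stride[3] = 1;  // headdim is the innermost, unit-stride dimension
    for (int i = 2; i >= 0; --i) {
        stride[i] = stride[i + 1] * shape[i + 1];
    }
    return stride;
}

// Element offset of index (b, s, h, d). Multiply by sizeof(element) only when a byte
// address is needed; the params fields above store element counts.
int64_t element_offset(const std::array<int64_t, 4>& stride,
                       int64_t b, int64_t s, int64_t h, int64_t d) {
    return b * stride[0] + s * stride[1] + h * stride[2] + d * stride[3];
}
```

        For a contiguous (2, 3, 4, 8) tensor the strides are (96, 32, 8, 1), element (1, 2, 3, 5)
        sits at element offset 189, and for a 2-byte fp16/bf16 element that is byte offset 378.
        Reading the strides from the tensor, as this code does, rather than recomputing them also
        keeps non-contiguous inputs (for example a slice along the sequence dimension) correct.
        */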
params.knew_batch_stride = k_padded.stride(0); params.vnew_batch_stride = v_padded.stride(0); params.knew_row_stride = k_padded.stride(-3); params.vnew_row_stride = v_padded.stride(-3); params.knew_head_stride = k_padded.stride(-2); params.vnew_head_stride = v_padded.stride(-2); } if (seqlens_k_.has_value()) { auto seqlens_k = seqlens_k_.value(); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); CHECK_DEVICE(seqlens_k); CHECK_CONTIGUOUS(seqlens_k); CHECK_SHAPE(seqlens_k, batch_size); params.cu_seqlens_k = static_cast(seqlens_k.data_ptr()); } params.is_seqlens_k_cumulative = !(seqlens_k_.has_value()); if (leftpad_k_.has_value()) { TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); auto leftpad_k = leftpad_k_.value(); TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); CHECK_SHAPE(leftpad_k, batch_size); params.leftpad_k = static_cast(leftpad_k.data_ptr()); } if (rotary_cos_.has_value()) { TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); auto rotary_cos = rotary_cos_.value(); CHECK_DEVICE(rotary_cos); params.rotary_dim = rotary_cos.size(1) * 2; TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); const int seqlen_ro = rotary_cos.size(0); TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); CHECK_CONTIGUOUS(rotary_cos); TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); auto rotary_sin = rotary_sin_.value(); CHECK_DEVICE(rotary_sin); 
CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); CHECK_CONTIGUOUS(rotary_sin); TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query"); params.rotary_cos_ptr = rotary_cos.data_ptr(); params.rotary_sin_ptr = rotary_sin.data_ptr(); params.is_rotary_interleaved = is_rotary_interleaved; } else { params.rotary_dim = 0; } if (cache_batch_idx_.has_value()) { auto cache_batch_idx = cache_batch_idx_.value(); CHECK_DEVICE(cache_batch_idx); CHECK_CONTIGUOUS(cache_batch_idx); TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); params.cache_batch_idx = reinterpret_cast(cache_batch_idx.data_ptr()); } // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, head_size_rounded, /*dropout*/ 0.f, num_splits, get_num_sm(get_current_device()), opts); if (paged_KV) { params.block_table = block_table.data_ptr(); params.block_table_batch_stride = block_table.stride(0); } params.page_block_size = page_block_size; set_params_alibi(params, alibi_slopes_, batch_size, num_heads); auto stream = at::cuda::getCurrentCUDAStream().stream(); // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx, // or paged KV cache run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV); if (head_size_og % 8 != 0) { out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); if (out_.has_value()) { out_.value().copy_(out); } if (k_.has_value()) { // It's expensive to copy the KV cache here for the case where head size is not divisible by 8, // but we don't expect to get this case in practice. This is just so that the code works for that case.
kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); } } if (seqlenq_ngroups_swapped) { out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); } return {out, softmax_lse}; } } // namespace FLASH_NAMESPACE PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashAttention"; m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)"); m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)"); m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache"); } ================================================ FILE: csrc/flash_attn/src/alibi.h ================================================ #include #include "namespace_config.h" #include #include #include #include "utils.h" namespace FLASH_NAMESPACE { using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Alibi { const float alibi_slope; const int max_seqlen_k, max_seqlen_q; __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q) : alibi_slope(alibi_slope) , max_seqlen_k(max_seqlen_k) , max_seqlen_q(max_seqlen_q) { }; template __forceinline__ __device__ void apply_alibi(Tensor &tensor, const int col_idx_offset_, const int row_idx_offset, const int warp_row_stride) { // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; if constexpr (Is_causal) { // Simpler, we add the 
same bias vector to all rows #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; } } } } else { // Bias depends on both row_idx and col_idx #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { const int row_idx_base = row_idx_offset + mi * warp_row_stride; #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { const int row_idx = row_idx_base + i * 8; #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); } } } } } } }; } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/block_info.h ================================================ /****************************************************************************** * Copyright (c) 2023, Tri Dao. ******************************************************************************/ #pragma once #include "namespace_config.h" namespace FLASH_NAMESPACE { //////////////////////////////////////////////////////////////////////////////////////////////////// template struct BlockInfo { template __device__ BlockInfo(const Params &params, const int bidb) : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ?
params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k) , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { } template __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; } template __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride; } const int sum_s_q; const int sum_s_k; const int actual_seqlen_q; // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. const int leftpad_k; const int seqlen_k_cache; const int actual_seqlen_k; }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/dropout.h ================================================ /****************************************************************************** * Copyright (c) 2024, Tri Dao. 
******************************************************************************/ #pragma once #include "namespace_config.h" #include "philox.cuh" #include "utils.h" namespace FLASH_NAMESPACE { struct Dropout { const unsigned long long seed, offset; const uint8_t p_dropout_in_uint8_t; __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset, const uint8_t p_dropout_in_uint8_t, const int bid, const int hid, const int tid, const int nheads) : seed(seed) , offset(offset + (bid * nheads + hid) * 32 + tid % 32) , p_dropout_in_uint8_t(p_dropout_in_uint8_t) { } template __forceinline__ __device__ void apply_dropout(Tensor &tensor_, int block_row_start, int block_col_start, int block_row_stride) { // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2) Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout())); using T = typename Engine::value_type; auto encode_dropout = [](bool keep, T val) { return keep ? val : (encode_dropout_in_sign_bit ? 
-val : T(0)); }; static_assert(decltype(size<2>(tensor))::value % 2 == 0); const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } #pragma unroll for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { uint2 rowcol = make_uint2(block_row_start, block_col_start); #pragma unroll for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast(rowcol), offset); // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); // Special implementation for 16-bit types: we duplicate the threshold to the // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, // and the high 16 bits will be either 0xffff or 0x0000, depending on whether // the random value is less than or equal to the threshold. // We then do a bit-wise AND between the mask and the original value (in 32-bit). // We're exploiting the fact that floating point comparison is equivalent to integer // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit && (std::is_same::value || std::is_same::value)) { uint16_t rnd_16[16]; #pragma unroll for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); #pragma unroll for (int j = 0; j < 2; j++) { Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } #pragma unroll for (int i = 0; i < 4; i++) { uint32_t mask; asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); tensor_uint32(i) &= mask; } // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } } } else { #pragma unroll for (int j = 0; j < 2; j++) { #pragma unroll for (int i = 0; i < 8; i++) { tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); } Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } } } // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); // // } } } } }; } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash.h ================================================ /****************************************************************************** * Copyright (c) 2023, Tri Dao. 
******************************************************************************/ #pragma once #include "namespace_config.h" #include #include #include // For at::Generator and at::PhiloxCudaState namespace FLASH_NAMESPACE { constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; constexpr int D_DIM = 2; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Qkv_params { using index_t = int64_t; // The QKV matrices. void *__restrict__ q_ptr; void *__restrict__ k_ptr; void *__restrict__ v_ptr; // The stride between rows of the Q, K and V matrices. index_t q_batch_stride; index_t k_batch_stride; index_t v_batch_stride; index_t q_row_stride; index_t k_row_stride; index_t v_row_stride; index_t q_head_stride; index_t k_head_stride; index_t v_head_stride; // The number of heads. int h, h_k; // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be // different from nheads (query). int h_h_k_ratio; // precompute h / h_k, }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_fwd_params : public Qkv_params { // The O matrix (output). void * __restrict__ o_ptr; void * __restrict__ oaccum_ptr; // The stride between rows of O. index_t o_batch_stride; index_t o_row_stride; index_t o_head_stride; // The pointer to the P matrix. void * __restrict__ p_ptr; // The pointer to the softmax sum. void * __restrict__ softmax_lse_ptr; void * __restrict__ softmax_lseaccum_ptr; // The dimensions. int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; // The scaling factors for the kernel. float scale_softmax; float scale_softmax_log2; // array of length b+1 holding starting offset of each sequence. int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; int * __restrict__ leftpad_k; // If provided, the actual length of each k sequence. 
int * __restrict__ seqused_k; int *__restrict__ blockmask; // The K_new and V_new matrices. void * __restrict__ knew_ptr; void * __restrict__ vnew_ptr; // The stride between rows of the Q, K and V matrices. index_t knew_batch_stride; index_t vnew_batch_stride; index_t knew_row_stride; index_t vnew_row_stride; index_t knew_head_stride; index_t vnew_head_stride; // The cos and sin matrices for rotary embedding. void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; // The indices to index into the KV cache. int * __restrict__ cache_batch_idx; // Paged KV cache int * __restrict__ block_table; index_t block_table_batch_stride; int page_block_size; // The dropout probability (probability of keeping an activation). float p_dropout; // uint32_t p_dropout_in_uint; // uint16_t p_dropout_in_uint16_t; uint8_t p_dropout_in_uint8_t; // Scale factor of 1 / (1 - p_dropout). float rp_dropout; float scale_softmax_rp_dropout; // Local window size int window_size_left, window_size_right; float softcap; // Random state. at::PhiloxCudaState philox_args; // Pointer to the RNG seed (idx 0) and offset (idx 1). uint64_t * rng_state; bool is_bf16; bool is_causal; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. bool is_seqlens_k_cumulative; bool is_rotary_interleaved; int num_splits; // For split-KV version void * __restrict__ alibi_slopes_ptr; index_t alibi_slopes_batch_stride; bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_bwd_params : public Flash_fwd_params { // The dO and dQKV matrices. 
void *__restrict__ do_ptr; void *__restrict__ dq_ptr; void *__restrict__ dk_ptr; void *__restrict__ dv_ptr; // To accumulate dQ void *__restrict__ dq_accum_ptr; void *__restrict__ dk_accum_ptr; void *__restrict__ dv_accum_ptr; // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ // dv_accum_ptr; // The stride between rows of the dO, dQ, dK and dV matrices. // TD [2022-04-16]: We're using 32-bit indexing to save registers. // The code probably won't work for arrays larger than 2GB. index_t do_batch_stride; index_t do_row_stride; index_t do_head_stride; index_t dq_batch_stride; index_t dk_batch_stride; index_t dv_batch_stride; index_t dq_row_stride; index_t dk_row_stride; index_t dv_row_stride; index_t dq_head_stride; index_t dk_head_stride; index_t dv_head_stride; // The pointer to the softmax d sum. void *__restrict__ dsoftmax_sum; bool deterministic; index_t dq_accum_split_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream); } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim128(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim192(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim256(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim32(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim64(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE { template<> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream) { run_mha_bwd_hdim96(params, stream); } } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_bwd_kernel.h ================================================ /*************************************************************************************************** * Copyright (c) 2024, Tri Dao.
******************************************************************************/ #pragma once #include "namespace_config.h" #include #include #include #include #include "block_info.h" #include "kernel_traits.h" #include "utils.h" #include "softmax.h" #include "mask.h" #include "dropout.h" #include "alibi.h" namespace FLASH_NAMESPACE { using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTE_HOST_DEVICE auto make_tiled_copy_B_warpcontiguousN(Copy_Atom const& copy_atom, TiledMMA const& tiled_mma) { constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value; constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value; using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; // Divide by 2 because right now we always use 2 for the ValLayout constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; // This gives the correct layout, idk why. 
// auto t = make_tile(Layout, _2>, // Stride, _8> >{}, // auto t = make_tile(Layout, // Stride<_1, _64, _8> >{}, auto t = make_tile(Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) Stride<_1, Int, _8> >{}, // (1, 64, 8) or (1, 32, 8) make_layout(Int{})); // if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); } return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t); } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTE_HOST_DEVICE auto make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, TiledMMA const& tiled_mma) { constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value; constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value; using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; // Divide by 2 because right now we always use 2 for the ValLayout constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; auto t = make_tile(make_layout(Int{}), Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) Stride<_1, Int, _8> >{}); // (1, 64, 8) or (1, 32, 8) // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); } return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; // Shared memory. extern __shared__ char smem_[]; // The thread index.
const int tidx = threadIdx.x; constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value; constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; const BlockInfo binfo(params, bidb); if (n_block * kBlockN >= binfo.actual_seqlen_k) return; int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); if (Is_local) { m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); } const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride; const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride; const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; const index_t dq_accum_batch_stride = static_cast(params.seqlen_q_rounded) * params.h * params.d_rounded; const index_t dq_accum_row_stride = 
static_cast(params.h) * params.d_rounded; const index_t row_offset_dq_accum = binfo.q_offset(dq_accum_batch_stride, dq_accum_row_stride, bidb) + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * dq_accum_row_stride + bidh * params.d_rounded // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride); const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM; // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{})); Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, make_stride(params.v_row_stride, _1{})); Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), Shape, Int>{}, make_stride(params.do_row_stride, _1{})); Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), Shape, Int>{}, make_stride(params.o_row_stride, _1{})); Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), Shape, Int>{}, make_stride(params.dq_row_stride, _1{})); Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), Shape, Int>{}, make_stride(params.h * 
params.d_rounded, _1{})); Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), Shape>{}, Stride<_1>{}); Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQdO{}); Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); // Double buffer for sQ Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{}); Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{}); Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{}); Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}); Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? 
sV.data() + size(sV) : sK.data() + size(sK), typename Kernel_traits::SmemLayoutPdS{}); Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{}); Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); // sP and sdQ share the same memory so be careful Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{}); typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); using GmemTiledCopydO = std::conditional_t< Is_first, typename Kernel_traits::GmemTiledCopydO, typename Kernel_traits::GmemTiledCopyQKV >; GmemTiledCopydO gmem_tiled_copy_dO; auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); using GmemLayoutAtomdQaccum = std::conditional_t< !Seq_parallel, typename Kernel_traits::GmemTiledCopydQaccum, typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd >; GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum; auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO); Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); Tensor tdQsdQ 
= gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); } // __syncthreads(); // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) { // printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data()); // } typename Kernel_traits::TiledMmaSdP tiled_mma_sdp; auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx); Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K) Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K) Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K) typename Kernel_traits::TiledMmadKV tiled_mma_dkv; auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx); Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N) Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N) Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N) Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N) typename Kernel_traits::TiledMmadQ tiled_mma_dq; auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx); Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N) Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N) Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K // // Copy Atom retiling // auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); 
Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); // auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx); auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_KV.partition_S(sK); // if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); } // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); } Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); // Partition sP and sdS to match the accumulator partitioning // This has to be tiled_mma_sdp, not tiled_mma_dkv // auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx); auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp); auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx); Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) // if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); } // if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); } // if (n_block == 0 && blockIdx.x == 0 && blockIdx.y == 0 && tidx < 64) { // printf("tidx=%d, tPsP = 0x%p\n", tidx, tPsP.data()); // } Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx); Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); auto smem_thr_copy_QdOt = 
smem_tiled_copy_QdOt.get_thread_slice(tidx); Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq); auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx); Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq); auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx); Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) // // PREDICATES // Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ); Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV); // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); // Set predicates for k bounds if (!Is_even_K) { #pragma unroll for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } #pragma unroll for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } } // Prologue // We'll advance gdQ and gdQaccum before the 1st read/write. tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride; tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded; int m_block = m_block_max - 1; int m_block_min = (!Is_causal && !Is_local) ? 
0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM); // If not local, we're guaranteed that m_block_min <= m_block: // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case, // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q. // So m_block_min <= (actual_seqlen_q - 1) / kBlockM. // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM. // So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM. // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop. // However, if local, then it is possible to have some blocks of K & V not attending to any query. // We might need to exit early and write 0 to dK and dV for those blocks. // Otherwise we get wrong results for the case where we don't enter the for loop. // And we might read OOB elements from gQ and gdO. // This also covers the case where actual_seqlen_q == 0 if ((Is_local || !Is_even_MN) && m_block < m_block_min) { const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), Shape, Int>{}, make_stride(params.dk_row_stride, _1{})); Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), Shape, Int>{}, make_stride(params.dv_row_stride, _1{})); typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); Tensor tdKrdK
= make_tensor(shape(tdKgdK)); Tensor tdVrdV = make_tensor(shape(tdVgdV)); clear(tdKrdK); clear(tdVrdV); Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); #pragma unroll for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } // Clear_OOB_K must be false since we don't want to write zeros to gmem FLASH_NAMESPACE::copy( gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::copy( gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN ); return; } if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ tQsQ.data() = tQsQ.data() + size(sQ); tSsQ.data() = tSsQ.data() + size(sQ); tdKsQt.data() = tdKsQt.data() + size(sQ); } if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); } if (Kernel_traits::Is_V_in_regs) { // Clear the smem tiles to account for predicated off loads FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); FLASH_NAMESPACE::cp_async_fence(); } Tensor tdOrdO = make_fragment_like(tdOgdO); Tensor tdOrO = make_fragment_like(tdOgO); if (!Is_first) { // Clear the smem tiles to account for predicated off loads FLASH_NAMESPACE::copy( gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM ); } else { FLASH_NAMESPACE::copy( gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM ); FLASH_NAMESPACE::copy( gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM ); } FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM ); Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> 
(blk_m,blk_n) Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N) static_assert(decltype(size<0>(taccScS))::value == 4); // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices. Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); Tensor lse = make_tensor(Shape>{}); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccScS_row(mi)); lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; } // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply // with V (which would be zero), we're fine. However, with ALiBi, we might modify these // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. // Tensor tKrK = make_fragment_like(tKsK); // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK); // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK); // // if (cute::thread(1, 0)) { print(tKrK); } FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); if (!Kernel_traits::Is_V_in_regs) { FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } FLASH_NAMESPACE::cp_async_fence(); // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); } if (Is_first) { cute::copy(tdOrdO, tdOsdO); dot_do_o(tdOrdO, tdOrO, gdPsum, Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); } if (Kernel_traits::Is_V_in_regs) { cute::cp_async_wait<1>(); __syncthreads(); Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV); CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view)); // M cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view); } FLASH_NAMESPACE::Dropout dropout(params.rng_state[0], params.rng_state[1], 
params.p_dropout_in_uint8_t, bidb, bidh, tidx, params.h); clear(acc_dv); clear(acc_dk); const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; FLASH_NAMESPACE::Alibi alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q); for (; m_block >= m_block_min; --m_block) { Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) clear(acc_s); cute::cp_async_wait<0>(); __syncthreads(); Tensor dP_sum = make_fragment_like(lse); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } // if (cute::thread0()) { print(sK); } // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK); // #pragma unroll // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) { // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); // } // if (cute::thread0()) { print(tSrK); } FLASH_NAMESPACE::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); if constexpr (Is_softcap) { FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); // if (cute::thread(32, 0)) { print(scores); } // Softcapping - calculating dTanh and scaling dS later with it [[maybe_unused]] Tensor dtanh = make_tensor_like(scores); if constexpr (Is_softcap) { FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap); } // Alibi if (Has_alibi) { alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16); } // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond // actual_seqlen_k, because acc_s would be 
some finite value for those indices. // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, // so the result would still be correct. // However, it's possible that the values in acc_s are so large that they overflow // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. // So we need to mask out the elements beyond actual_seqlen_k. if (!Is_causal && !Is_local) { if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) { FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16); } } else if (Is_causal) { // Putting this causal masking right after acc_s is *much* slower for some reason. // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking. // But we still want to mask out elements beyond actual_seqlen_k. if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { FLASH_NAMESPACE::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), binfo.actual_seqlen_q, // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, AtomLayoutMS * 16); } } else if (Is_local) { if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { FLASH_NAMESPACE::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), 
binfo.actual_seqlen_q, AtomLayoutMS * 16, params.window_size_left, params.window_size_right); } } // if (cute::thread(32, 0)) { print(scores); } // Compute the exponential value. FLASH_NAMESPACE::scale_apply_exp2(scores, lse, params.scale_softmax_log2); if constexpr (Is_dropout) { int warp_id = tidx / 32; int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 static_assert(MMA_N_SdP % 2 == 0); int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); dropout.template apply_dropout( acc_s, block_row_idx, block_col_idx, AtomLayoutMS ); } // Convert scores from fp32 to fp16/bf16 Tensor rP = !Is_dropout ? FLASH_NAMESPACE::convert_type(acc_s) : FLASH_NAMESPACE::convert_type_relu(acc_s); // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2) // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8. Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); // if (cute::thread0()) { print(tPaP); } // __syncthreads(); // if (cute::thread0()) { print(sP); } Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA clear(acc_dp); // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dp.layout())); // #pragma unroll // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) { // #pragma unroll // for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) { // acc_dp_reshaped(mi, ni) = -dP_sum(mi); // } // } // if (cute::thread0()) { print(dP_sum); } FLASH_NAMESPACE::gemm( acc_dp, tdPrdO, tdPrV, tdPsdO, 
tdPsV, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV ); // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) Tensor dS = make_tensor(acc_dp.data(), scores.layout()); auto pointwise_mult = [](float p, float dp, float d) { return p * (!Is_dropout || p >= 0 ? dp - d : d); }; #pragma unroll for (int mi = 0; mi < size<0>(dS); ++mi) { #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); } dS(mi, ni) = scaled_ds; } } // if (cute::thread0()) { print(dS); } Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded)); if (Is_first || Seq_parallel) { clear(acc_dq); } else { // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum Tensor acc_dq_reshaped = make_tensor(acc_dq.data(), make_layout(get<0>(acc_dq.layout()), get<2>(acc_dq.layout()), get<1>(acc_dq.layout()))); cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped); } if (Double_buffer && m_block > m_block_min) { // Double buffer for sQ const int sQ_offset = m_block % 2 == 0 ? 
size(sQ) : -size(sQ); tQsQ.data() = tQsQ.data() + sQ_offset; tSsQ.data() = tSsQ.data() + sQ_offset; // Advance gQ tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride)); FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ); FLASH_NAMESPACE::cp_async_fence(); } Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); // Convert dS from fp32 to fp16 Tensor tdSrdS = FLASH_NAMESPACE::convert_type(dS_reshaped); // if (cute::thread0()) { print(tPrP); } Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); __syncthreads(); // Layout p_l = tPrP.layout(); // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); FLASH_NAMESPACE::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } // if (cute::thread0()) { print(acc_dv); } __syncthreads(); // Need syncthreads since we're writing to the same sdO location if (m_block > m_block_min) { // Advance gdO tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride)); if (Is_first) { tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride)); FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ); FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ); } else { FLASH_NAMESPACE::copy(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ); FLASH_NAMESPACE::cp_async_fence(); } } FLASH_NAMESPACE::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt); // 
if (cute::thread0()) { print(acc_dq); } if (m_block > m_block_min) { gLSE.data() = gLSE.data() + (-int(kBlockM)); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); } gdPsum.data() = gdPsum.data() + (-int(kBlockM)); } if (!Is_last) { // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum Tensor acc_dq_reshaped = make_tensor(acc_dq.data(), make_layout(get<0>(acc_dq.layout()), get<2>(acc_dq.layout()), get<1>(acc_dq.layout()))); if (!Seq_parallel) { cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum); } else { // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); } CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); #pragma unroll for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); } } } else { #pragma unroll for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } // Convert acc_dq from fp32 to fp16 Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); } FLASH_NAMESPACE::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); // if (cute::thread0()) { print(acc_dk); } if (Double_buffer) { // Double buffer for sQ tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? 
size(sQ) : -size(sQ));
        }

        if (!Double_buffer && m_block > m_block_min) {
            __syncthreads();
            // Advance gQ
            tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
            FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
            FLASH_NAMESPACE::cp_async_fence();
        }

        if (Is_first && m_block > m_block_min) {
            cute::copy(tdOrdO, tdOsdO);
            dot_do_o(tdOrdO, tdOrO, gdPsum,
                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
        }

        if (Is_last) {
            __syncthreads();
            Tensor tdQrdQ = make_tensor(shape(tdQgdQ));
            cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
            tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));
            Tensor cdQ = make_identity_tensor(Shape, Int>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
            Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
            #pragma unroll
            for (int m = 0; m < size<1>(tdQgdQ); ++m) {
                if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
                    cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
                }
            }
        }

    }

    // Epilogue

    if (Is_dropout) {
        #pragma unroll
        for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; }
    }
    #pragma unroll
    for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; }

    // Convert acc_dv from fp32 to fp16
    Tensor rdK = FLASH_NAMESPACE::convert_type(acc_dk);
    Tensor rdV = FLASH_NAMESPACE::convert_type(acc_dv);

    Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{});  // (SMEM_N, SMEM_K)
    Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)

    // Partition sdV and sdK to match the accumulator partitioning
    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);       // ((Atom,AtomNum), MMA_N, MMA_N)
    Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);   // ((Atom,AtomNum),PIPE_M,PIPE_N)
    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);       // ((Atom,AtomNum), MMA_N, MMA_N)
    Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);    // ((Atom,AtomNum),PIPE_M,PIPE_N)

    // We need syncthreads here since we're writing to the same location as sK and sV.
    // Without syncthreads, some thread might modify the location of sK while another thread
    // is reading it for dQ gemm, leading to a race condition.
    // If Is_last, there's already a __syncthreads() at the end of the loop.
    if (!Is_last) { __syncthreads(); }

    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);

    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk),
                             Shape, Int>{},
                             make_stride(params.dk_row_stride, _1{}));
    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv),
                             Shape, Int>{},
                             make_stride(params.dv_row_stride, _1{}));

    typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);   // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);

    __syncthreads();
    Tensor tdKrdK = make_tensor(shape(tdKgdK));
    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
    Tensor tdVrdV = make_tensor(shape(tdVgdV));
    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
    Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
    Tensor tdKVcdKV =
gmem_thr_copy_dKV.partition_D(cdKV);
    Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK)));
    #pragma unroll
    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    FLASH_NAMESPACE::copy(
        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
    FLASH_NAMESPACE::copy(
        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );

}

////////////////////////////////////////////////////////////////////////////////////////////////////

template inline __device__ void compute_dq_dk_dv(const Params &params) {

    // The block index for the batch.
    const int bidb = blockIdx.x;
    // const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.y;
    // const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    if (n_block_max == 1) {
        compute_dq_dk_dv_1colblock(params, bidb, bidh, 0);
    } else {
        // Iterating backward from n_block_max - 1 to 0 might save 1 register
        compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1);
        for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
            compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block);
        }
        compute_dq_dk_dv_1colblock(params, bidb, bidh, 0);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;

    // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
    for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
        compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash

================================================
FILE: csrc/flash_attn/src/flash_bwd_launch_template.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "namespace_config.h"
#include  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include "static_switch.h"
#include "hardware_info.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"

namespace FLASH_NAMESPACE {

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
    #if defined(ARCH_SUPPORTS_FLASH)
        FLASH_NAMESPACE::compute_dq_dk_dv(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
    #if defined(ARCH_SUPPORTS_FLASH)
        static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
        FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}


template __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
    FLASH_NAMESPACE::compute_dot_do_o(params);
}

template __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
    FLASH_NAMESPACE::clear_dKVaccum(params);
}

template __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
    FLASH_NAMESPACE::convert_dQ(params, nsplits);
}

template __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    FLASH_NAMESPACE::convert_dKV(params);
}

template void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid_m(num_m_block, params.b, params.h);
    const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
    int gridDimx = num_n_block;
    if (params.deterministic) {
        int num_sm = get_num_sm(get_current_device());
        gridDimx = (num_sm + params.b * params.h - 1) / (params.b * params.h);
    }
    dim3 grid_n(gridDimx, params.b, params.h);

    if (!params.deterministic) {
        flash_bwd_dot_do_o_kernel<<>>(params);
    } else {
        flash_bwd_dot_do_o_kernel<<>>(params);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
    // a multiple of kBlockN, we'll need to apply mask in the loop.
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
    // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                        // If Is_local, set Is_causal to false
                        auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
                        // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel;
                        if (smem_size_dq_dk_dv >= 48 * 1024)  {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
                        }
                        kernel<<>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });

    auto kernel_dq = &flash_bwd_convert_dq_kernel;
    if (Kernel_traits::kSmemdQSize >= 48 * 1024)  {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
    }
    kernel_dq<<>>(params, !params.deterministic ?
1 : gridDimx);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    run_flash_bwd_seqk_parallel(params, stream);
#endif
}

template void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
            if constexpr(!Is_dropout) {  // We can afford more registers to keep V in registers
                run_flash_bwd, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_bwd, Is_dropout, Is_causal>(params, stream);
            }
        } else {  // 96 KB
            run_flash_bwd, Is_dropout, Is_causal>(params, stream);
        }
    });
}

template void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // Changing AtomLayoutMdQ from 2 to 4 takes the same time
        // run_flash_bwd>(params, stream);
        // run_flash_bwd>(params, stream);
        // run_flash_bwd>(params, stream);
        // run_flash_bwd, Is_dropout>(params, stream);
        // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd, Is_dropout, Is_causal>(params, stream);
            // This has a lot of register spilling
            // run_flash_bwd, Is_dropout>(params, stream);
        } else {
            // if (params.h == params.h_k) {
              // run_flash_bwd, Is_dropout>(params, stream);
              run_flash_bwd, Is_dropout, Is_causal>(params, stream);
              // run_flash_bwd, Is_dropout>(params, stream);
              // run_flash_bwd, Is_dropout>(params, stream);
            // } else {
            // }
        }
    });
    // run_flash_bwd>(params, stream);
    // run_flash_bwd>(params, stream);
    // run_flash_bwd>(params, stream);
    // run_flash_bwd>(params, stream);
    // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
    // run_flash_bwd>(params, stream);
    // run_flash_bwd>(params, stream);
    // run_flash_bwd>(params, stream);
    // run_flash_bwd>(params, stream);
}

template void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 116 * 1024) {
            if constexpr(!Is_dropout) {  // 92KB
                run_flash_bwd, Is_dropout, Is_causal>(params, stream);
            } else {  // 116 KB
                // This is faster for dropout since we don't have many registers to spare
                run_flash_bwd, Is_dropout, Is_causal>(params, stream);
            }
        } else {
            run_flash_bwd, Is_dropout, Is_causal>(params, stream);
        }
    });
}

template void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // run_flash_bwd>(params, stream);
        // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
        // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
        // run_flash_bwd>(params, stream);
        if (max_smem_per_block >= 144 * 1024) {
            run_flash_bwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream);
            // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream);
            // run_flash_bwd, Is_dropout>(params, stream);
            // run_flash_bwd, Is_dropout>(params, stream);
            // run_flash_bwd, Is_dropout>(params, stream);
        } else {
            // run_flash_bwd, Is_dropout>(params, stream);
            run_flash_bwd, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_bwd>(params, stream);

        // run_flash_bwd>(params, stream);
    });
}

template void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 136 * 1024) {
            run_flash_bwd, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_bwd, Is_dropout, Is_causal>(params, stream);
        }
    });
}

template void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (max_smem_per_block >= 176 * 1024) {  // H100
            run_flash_bwd, Is_dropout, Is_causal>(params, stream);
        } else if
(max_smem_per_block >= 144 * 1024) {  // A100, we don't do double buffering to save smem
            run_flash_bwd, Is_dropout, Is_causal>(params, stream);
        } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
            if constexpr (!Is_dropout) {
                run_flash_bwd, false, Is_causal>(params, stream);
            }
        }
    });
}

} // namespace FLASH_NAMESPACE {

================================================
FILE: csrc/flash_attn/src/flash_bwd_preprocess_kernel.h
================================================
/***************************************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "namespace_config.h"
#include

#include
#include
#include

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"

namespace FLASH_NAMESPACE {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template inline __device__ void dot_do_o(Tensor const &do_, Tensor const &o,
                                         Tensor &dP_sum, const int gdP_col_stride, const float scale) {
    static_assert(Layout0::rank == 3, "Only support 3D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
    // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
    // The last coordinate is the "page".
    Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
                                                             make_layout(get<0>(do_.layout()),
                                                                         get<2>(do_.layout()))));
    Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
    Tensor do_fp32 = FLASH_NAMESPACE::convert_type(do_reshaped);
    Tensor o_fp32 = FLASH_NAMESPACE::convert_type(o_reshaped);
    #pragma unroll
    for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
        float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
        #pragma unroll
        for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
            dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
        }
        FLASH_NAMESPACE::SumOp sum_op;
        dP_sum_cur = FLASH_NAMESPACE::Allreduce::run(dP_sum_cur, sum_op) * scale;
        if (threadIdx.x % THREADS_PER_ROW == 0) {
            dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template inline __device__ void compute_dot_do_o(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
        + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t dq_accum_batch_stride = static_cast(params.seqlen_q_rounded) * params.h * params.d_rounded;
    const index_t dq_accum_row_stride = static_cast(params.h) * params.d_rounded;
    const index_t row_offset_dq_accum = binfo.q_offset(dq_accum_batch_stride, dq_accum_row_stride, bidb)
        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * dq_accum_row_stride + bidh * params.d_rounded;
    // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
    const index_t row_offset_dpsum = (params.unpadded_lse ?
        (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;

    Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do),
                             Shape, Int>{},
                             make_stride(params.do_row_stride, _1{}));
    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o),
                            Shape, Int>{},
                            make_stride(params.o_row_stride, _1{}));
    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum),
                                  Shape, Int>{},
                                  make_stride(params.h * params.d_rounded, _1{}));
    Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum),
                                Shape>{}, Stride<_1>{});

    typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
    auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
    // TODO: careful, we're zeroing out dQaccum with type float4, but when
    // we do atomicAdds, we use type float. The layouts are different. Check this.
    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);

    Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
    Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);

    Tensor cdO = make_identity_tensor(Shape, Int>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);

    // Allocate predicate tensors for k
    Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOgdO)));
    // Set predicates for k bounds
    #pragma unroll
    for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}

    Tensor tdOrdO = make_fragment_like(tdOgdO);
    Tensor tdOrO = make_fragment_like(tdOgO);
    FLASH_NAMESPACE::copy(
        gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    FLASH_NAMESPACE::copy(
        gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
    );
    // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
    // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
    // so that (dP - dP_sum) is on the same scale.
    dot_do_o(tdOrdO, tdOrO, dP_sum,
             Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
    if (Clear_dQaccum) {
        // We're actually not zero'ing out all of dQaccum, but only the part that we're going to
        // do atomicAdds on.
        Tensor zero = make_fragment_like(tdQgdQaccum);
        clear(zero);
        cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template inline __device__ void clear_dKVaccum(const Params &params) {
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;

    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape, Int>{}, Stride, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape, Int>{}, Stride, _1>{});

    typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
    Tensor zero = make_fragment_like(tdKgdKaccum);
    clear(zero);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template inline __device__ void convert_dQ(const Params &params, const int nsplits) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
        + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
    const index_t dq_accum_batch_stride = static_cast(params.seqlen_q_rounded) * params.h * params.d_rounded;
    const index_t dq_accum_row_stride = static_cast(params.h) * params.d_rounded;
    const index_t row_offset_dq_accum = binfo.q_offset(dq_accum_batch_stride, dq_accum_row_stride, bidb)
        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * dq_accum_row_stride + bidh * params.d_rounded;

    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq),
                             Shape, Int>{},
                             make_stride(params.dq_row_stride, _1{}));
    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum),
                                  Shape, Int>{},
                                  make_stride(params.h * params.d_rounded, _1{}));

    Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)),
                             typename Kernel_traits::SmemLayoutdQ{});

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);

    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
    Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ);  // ((Atom,AtomNum),PIPE_M,PIPE_N)

    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
    Tensor tdQgdQaccum =
gmem_thr_copy_dQaccum.partition_S(gdQaccum);

    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{});  // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));

    Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
    clear(acc_dq);
    for (int s = 0; s < nsplits; ++s) {
        cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
        #pragma unroll
        for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
        tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
    }
    #pragma unroll
    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
    // Convert acc_dq from fp32 to fp16
    Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq);
    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ);  // ((Atom,AtomNum), MMA_N, MMA_N)
    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
    __syncthreads();
    Tensor tdQrdQ = make_tensor(shape(tdQgdQ));
    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);

    Tensor cdQ = make_identity_tensor(Shape, Int>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
    Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ)));
    #pragma unroll
    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    FLASH_NAMESPACE::copy(
        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
template inline __device__ void convert_dKV(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
                                          + n_block * kBlockN) * params.d_rounded;

    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk),
                             Shape, Int>{},
                             make_stride(params.dk_row_stride, _1{}));
    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv),
                             Shape, Int>{},
                             make_stride(params.dv_row_stride, _1{}));
    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape, Int>{},
                                  Stride, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape, Int>{},
                                  Stride, _1>{});

    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)),
                             typename Kernel_traits::SmemLayoutdKV{});
    Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);

    typename
Kernel_traits::TiledMmadKV tiled_mma_dkv;
    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
    Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
    Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);  // ((Atom,AtomNum),PIPE_M,PIPE_N)

    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV);    // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);

    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{});  // MMA, MMA_N, MMA_K
    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{});  // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));

    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
    #pragma unroll
    for (int i = 0; i < size(acc_dk); ++i) {
        acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
    }
    #pragma unroll
    for (int i = 0; i < size(acc_dv); ++i) {
        acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
    }
    // Convert acc_dk from fp32 to fp16
    Tensor rdK = FLASH_NAMESPACE::convert_type(acc_dk);
    Tensor rdV = FLASH_NAMESPACE::convert_type(acc_dv);
    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK);  // ((Atom,AtomNum), MMA_N, MMA_N)
    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);  // ((Atom,AtomNum), MMA_N, MMA_N)
    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
    __syncthreads();
    Tensor tdKrdK = make_tensor(shape(tdKgdK));
    Tensor tdVrdV = make_tensor(shape(tdVgdV));
    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);

    Tensor cdKV = make_identity_tensor(Shape, Int>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
    Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK)));
    #pragma unroll
    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    FLASH_NAMESPACE::copy(
        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
    FLASH_NAMESPACE::copy(
        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
}

} // namespace flash

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim128(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim128(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim128(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim128(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim192(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim192(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim192(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim192(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim96(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim96(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim96(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim96(params, stream);
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_kernel.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "namespace_config.h"
#include "philox_unpack.cuh" // For at::cuda::philox::unpack

#include
#include
#include
#include

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "mask.h"
#include "dropout.h"
#include "rotary.h"

namespace FLASH_NAMESPACE {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo &binfo) {
    // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
    // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
    // Otherwise, it's written as (h, b, seqlen_q).
    const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
    auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
    auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + lse_offset);

    auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
    auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
        params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
    );

    auto lse_layout = make_layout(lse_shape, lse_stride);
    Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
    auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
    return local_tile(mLSE_slice, Shape>{}, make_coord(m_block));
}

template
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNWarps = Kernel_traits::kNWarps;

    auto seed_offset = at::cuda::philox::unpack(params.philox_args);
    FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
                                     bidb, bidh, tidx, params.h);

    // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
    // exit early and no one saves the rng states.
    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
        params.rng_state[0] = std::get<0>(seed_offset);
        params.rng_state[1] = std::get<1>(seed_offset);
    }

    const BlockInfo binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
    if (Is_causal || Is_local) {
        n_block_max = std::min(n_block_max,
                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
        //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
        // }
    }
    // We exit early and write 0 to gO and INFINITY to gLSE. This also covers the case where actual_seqlen_k == 0.
    // Otherwise we might read OOB elements from gK and gV.
    if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
        Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr)
                                              + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
                                make_shape(binfo.actual_seqlen_q, params.h, params.d),
                                make_stride(params.o_row_stride, params.o_head_stride, _1{}));
        Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{},
                               make_coord(m_block, 0));  // (kBlockM, kHeadDim)

        Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo);

        typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
        Tensor tOrO = make_tensor(shape(tOgO));
        clear(tOrO);
        // Construct identity layout for sO
        Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor(make_shape(size<2>(tOgO)));
        if (!Is_even_K) {
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
        }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        FLASH_NAMESPACE::copy(
            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
        );
        #pragma unroll
        for (int m = 0; m < size<1>(tOgO); ++m) {
            const int row = get<0>(tOcO(0, m, 0));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
        }
        return;
    }
    // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }

    // We iterate over the blocks in reverse order. This is because the last block is the only one
    // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
    // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

    const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
        + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;

    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));
    Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{},
                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
    Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
                            make_stride(params.k_row_stride, params.k_head_stride, _1{}));
    Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{},
                           make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
    Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
                            make_stride(params.v_row_stride, params.v_head_stride, _1{}));
    Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{},
                           make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
    Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p),
                            Shape, Int>{},
                            make_stride(params.seqlen_k_rounded, _1{}));

    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)),
                            typename Kernel_traits::SmemLayoutQ{});
    // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
                            typename Kernel_traits::SmemLayoutKV{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);

    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

    typename Kernel_traits::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);                // (MMA,MMA_M,MMA_K)
    Tensor tSrK = thr_mma.partition_fragment_B(sK);                // (MMA,MMA_N,MMA_K)
    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);     // (MMA, MMA_K,MMA_N)

    Tensor tSgS = thr_mma.partition_C(gP);

    Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{});  // MMA, MMA_M, MMA_K

    //
    // Copy Atom retiling
    //

    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
    // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}

    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);

    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

    //
    // PREDICATES
    //

    // // Allocate predicate tensors for m and n
    // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
    // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});

    // Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));   // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
    // Tensor tScQ = thr_mma.partition_A(cQ);                           // (MMA,MMA_M,MMA_K)
    // if (cute::thread0()) {
    //     print(tScQ.layout()); printf("\n");
    //     for (int i = 0; i < size(tScQ); ++i) {
    //         printf("%d ", get<0>(tScQ(i)));
    //     }
    //     printf("\n");
    //     for (int i = 0; i < size(tScQ); ++i) {
    //         printf("%d ", get<1>(tScQ(i)));
    //     }
    //     printf("\n");
    // }

    // Repeat the partitioning with identity layouts
    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);      // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);   // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)

    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK)));

    // Set predicates for k bounds
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
    }

    // Prologue

    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
    FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                          binfo.actual_seqlen_q - m_block * kBlockM);
    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

    // // if (cute::thread(1, 0)) { print(tQsQ); }
    // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
    // // if (cute::thread0()) { print(sQNoSwizzle); }

    if (Kernel_traits::Share_Q_K_smem) {
        FLASH_NAMESPACE::cp_async_wait<0>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
        __syncthreads();
    }

    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
    FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
                          binfo.actual_seqlen_k - n_block * kBlockN);
    cute::cp_async_fence();
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
    // __syncthreads();

    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
        FLASH_NAMESPACE::cp_async_wait<1>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
    }

    clear(acc_o);

    FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;

    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
    FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);

    // For performance reasons, we separate out two kinds of iterations:
    // those that need masking on S, and those that don't.
    // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
    // We will have at least 1 "masking" iteration.

    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
    constexpr int n_masking_steps = (!Is_causal && !Is_local)
        ? 1
        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        FLASH_NAMESPACE::cp_async_wait<0>();
        __syncthreads();

        // Advance gV
        if (masking_step > 0) {
            FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
        } else {
            // Clear the smem tiles to account for predicated off loads
            FLASH_NAMESPACE::copy(
                gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
            );
        }
        cute::cp_async_fence();

        FLASH_NAMESPACE::gemm(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
        // if (cute::thread0()) { print(acc_s); }
        if constexpr (Is_softcap){
            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
        }

        mask.template apply_mask(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
        );

        FLASH_NAMESPACE::cp_async_wait<0>();
        __syncthreads();
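        // Worked example (illustrative only; assumes the SM80 m16n8 accumulator layout and
        // warps tiled along M, so each warp owns 16 consecutive rows per MMA_M step).
        // The row offset passed to apply_mask above is
        //   m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
        // i.e. tile start + the warp's 16-row slab + the lane's row within an 8-row group.
        // With kBlockM = 128, kNWarps = 4, m_block = 1, tidx = 37 (warp 1, lane 5):
        //   128 + 1 * 16 + 5 / 4 = 145.
        // Stepping by kNWarps * 16 = 64 per MMA_M index, plus 8 for the lower half of each
        // 16x8 tile, this thread's score rows are 145, 153, 209 and 217 (MMA_M = 2 here).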
        if (n_block > n_block_min) {
            FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        // TODO: when we have key_padding_mask we'll need to Check_inf
        masking_step == 0
            ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2)
            : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2);

        // Convert acc_s from fp32 to fp16/bf16
        Tensor rP = FLASH_NAMESPACE::convert_type(acc_s);
        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
        int block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor rP_drop = make_fragment_like(rP);
            cute::copy(rP, rP_drop);
            dropout.template apply_dropout(
                rP_drop, block_row_idx, block_col_idx, kNWarps
            );
            cute::copy(rP_drop, tSgS);
            tSgS.data() = tSgS.data() + (-kBlockN);
        }
        if (Is_dropout) {
            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
        }

        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
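        // Note (a sketch of the reasoning, assuming SM80 mma.sync.m16n8k16 fragment layouts):
        // per 16x8 accumulator tile a thread holds c0, c1 at row lane/4 and c2, c3 at row
        // lane/4 + 8, both at columns 2 * (lane % 4) + {0, 1}. The A operand of m16n8k16
        // holds a0..a3 at those same positions for k = 0..7 and a4..a7 at k = 8..15.
        // Two adjacent MMA_N tiles of P therefore already form one A fragment, so the
        // conversion below only reinterprets the layout of rP; no data moves between threads.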
        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout()));
        // if (cute::thread0()) { print(tOrP); }
        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
        // if (cute::thread0()) { print(scores); }

        // This check is at the end of the loop since we always have at least 1 iteration
        if (n_masking_steps > 1 && n_block <= n_block_min) {
            --n_block;
            break;
        }
    }

    // These are the iterations where we don't need masking on S
    for (; n_block >= n_block_min; --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        FLASH_NAMESPACE::cp_async_wait<0>();
        __syncthreads();
        FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
        cute::cp_async_fence();

        FLASH_NAMESPACE::gemm(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
        if constexpr (Is_softcap){
            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
        }

        FLASH_NAMESPACE::cp_async_wait<0>();
        __syncthreads();
        if (n_block > n_block_min) {
            FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        mask.template apply_mask(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
        );
        softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2);

        Tensor rP = FLASH_NAMESPACE::convert_type(acc_s);
        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
        int block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor rP_drop = make_fragment_like(rP);
            cute::copy(rP, rP_drop);
            dropout.template apply_dropout(
                rP_drop, block_row_idx, block_col_idx, kNWarps
            );
            cute::copy(rP_drop, tSgS);
            tSgS.data() = tSgS.data() + (-kBlockN);
        }
        if (Is_dropout) {
            dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
        }

        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout()));
        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
    }

    // Epilogue

    Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout);

    // Convert acc_o from fp32 to fp16/bf16
    Tensor rO = FLASH_NAMESPACE::convert_type(acc_o);
    Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
    auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
    Tensor taccOrO = smem_thr_copy_O.retile_S(rO);     // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsO = smem_thr_copy_O.partition_D(sO);  // ((Atom,AtomNum),PIPE_M,PIPE_N)

    // sO has the same size as sQ, so we don't need to sync here.
    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }

    cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

    Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
                            make_stride(params.o_row_stride, params.o_head_stride, _1{}));
    Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{},
                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
    Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo);

    typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
    Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

    __syncthreads();

    Tensor tOrO = make_tensor(shape(tOgO));
    cute::copy(gmem_tiled_copy_O, tOsO, tOrO);

    Tensor caccO = make_identity_tensor(Shape, Int>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor taccOcO = thr_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
    static_assert(decltype(size<0>(taccOcO))::value == 4);
    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M
    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
        }
    }

    // Construct identity layout for sO
    Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    // Repeat the partitioning with identity layouts
    Tensor tOcO = gmem_thr_copy_O.partition_D(cO);                           // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tOpO = make_tensor(make_shape(size<2>(tOgO)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    FLASH_NAMESPACE::copy(
        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template
inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {

    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNWarps = Kernel_traits::kNWarps;

    using GmemTiledCopyO = std::conditional_t<
        !Split,
        typename Kernel_traits::GmemTiledCopyO,
        typename Kernel_traits::GmemTiledCopyOaccum
    >;
    using ElementO = std::conditional_t;

    const BlockInfo binfo(params, bidb);
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
    const int n_block_min = !Is_local
        ? n_split_idx * n_blocks_per_split
        : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
    int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
    if (Is_causal || Is_local) {
        n_block_max = std::min(n_block_max,
                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
    }
    if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
        // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
        // Otherwise we might read OOB elements from gK and gV,
        // or get wrong results when we combine gOaccum from different blocks.
        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
        const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
                                         + m_block * kBlockM) * params.d_rounded;
        const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                     Shape, Int>{},
                                     make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                       Shape>{}, Stride<_1>{});

        GmemTiledCopyO gmem_tiled_copy_Oaccum;
        auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
        Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
        Tensor tOrOaccum = make_tensor(shape(tOgOaccum));
        clear(tOrOaccum);
        // Construct identity layout for sO
        Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
        Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum)));
        if (!Is_even_K) {
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
        }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        FLASH_NAMESPACE::copy(
            gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
        );
        #pragma unroll
        for (int m = 0; m < size<1>(tOgOaccum); ++m) {
            const int row = get<0>(tOcO(0, m, 0));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
        }
        return;
    }

    // We iterate over the blocks in reverse order. This is because the last block is the only one
    // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
    // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

    // We move K and V to the last block.
    const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
    const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
    const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
    const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
    const index_t row_offset_k = block_table == nullptr ?
        binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
          + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
        : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = block_table == nullptr ?
        binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
          + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
        : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;

    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));
    Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{},
                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k),
                            Shape, Int>{},
                            make_stride(params.k_row_stride, _1{}));
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v),
                            Shape, Int>{},
                            make_stride(params.v_row_stride, _1{}));

    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)),
                            typename Kernel_traits::SmemLayoutQ{});
    Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);

    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV =
        gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

    typename Kernel_traits::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);                           // (MMA,MMA_M,MMA_K)
    Tensor tSrK = thr_mma.partition_fragment_B(sK);                           // (MMA,MMA_N,MMA_K)
    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                // (MMA, MMA_K,MMA_N)

    Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{});  // MMA, MMA_M, MMA_K

    //
    // Copy Atom retiling
    //

    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);

    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);

    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

    // PREDICATES
    //

    // // Allocate predicate tensors for m and n
    // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
    // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});

    // Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));   // (BLK_N,BLK_K) -> (blk_n,blk_k)

    // Repeat the partitioning with identity layouts
    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);   // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)

    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK)));

    // Set predicates for k bounds
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
    }

    // Prologue

    // Copy from Knew to K, optionally apply rotary embedding.
    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
    auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
    if constexpr (Append_KV) {
        // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
        // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
        // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ?
            0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
                                  Shape, Int>{},
                                  make_stride(params.rotary_dim / 2, _1{}));
        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
                                  Shape, Int>{},
                                  make_stride(params.rotary_dim / 2, _1{}));
        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
                                      Shape, Int>{},
                                      make_stride(params.rotary_dim / 2, _1{}));
        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
                                      Shape, Int>{},
                                      make_stride(params.rotary_dim / 2, _1{}));
        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
        // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
        // if (cute::thread(8, 0)) { print_tensor(gCos); }
        // if (cute::thread(0, 0)) { print_tensor(tRgCos); }

        // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
        const index_t row_offset_knew = bidb * params.knew_batch_stride
            + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
        // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
        const index_t row_offset_vnew = bidb * params.vnew_batch_stride
            + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
        // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
        // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
        // This maps to accessing the first 64 rows of knew_ptr.
        Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.knew_ptr)
                                                 + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
                                   Shape, Int>{},
                                   make_stride(params.knew_row_stride, _1{}));
        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
        Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast(params.vnew_ptr)
                                                 + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
                                   Shape, Int>{},
                                   make_stride(params.vnew_row_stride, _1{}));
        Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
        Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)

        const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
        auto tKgK_data = tKgK.data();
        auto tVgV_data = tVgV.data();
        for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
            FLASH_NAMESPACE::copy_w_min_idx(
                tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
            );
            tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
            if (params.rotary_dim == 0) {
                FLASH_NAMESPACE::copy_w_min_idx(
                    tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
                );
            } else {
                if (params.is_rotary_interleaved) {
                    // Don't clear OOB_K because we're writing to global memory
                    FLASH_NAMESPACE::copy_rotary_interleaved(
                        tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
                        binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
                    );
                    tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
                    tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
                } else {
                    // Don't clear OOB_K because we're writing to global memory
                    FLASH_NAMESPACE::copy_rotary_contiguous(
                        tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
                        binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
                    );
                    tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
                    tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
                }
            }
            tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
            if (block_table == nullptr) {
                tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            } else {
                if (n_block > n_block_copy_min) {
                    const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
                    const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
                    const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
                    const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                    const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
                    const int offset_diff = block_table_offset_next - block_table_offset_cur;
                    tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
                    tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
                }
            }
        }
        // Need this before we can read in K again, so that we'll see the updated K values.
        __syncthreads();
        tKgK.data() = tKgK_data;
        tVgV.data() = tVgV_data;
    }

    // Read Q from gmem to smem, optionally apply rotary embedding.
    if (!Append_KV || params.rotary_dim == 0) {
        // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
        FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                              binfo.actual_seqlen_q - m_block * kBlockM);
    } else {
        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ?
            0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
        // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
        // We do this by setting the row stride of gCos / gSin to 0.
        Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
                                  Shape, Int>{},
                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
        Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
                                  Shape, Int>{},
                                  make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
        Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin),
                                      Shape, Int>{},
                                      make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
        Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin),
                                      Shape, Int>{},
                                      make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
        if (params.is_rotary_interleaved) {
            FLASH_NAMESPACE::copy_rotary_interleaved(
                tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
                0, params.d, params.rotary_dim
            );
        } else {
            FLASH_NAMESPACE::copy_rotary_contiguous(
                tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
                0, params.d, params.rotary_dim
            );
        }
    }

    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
    FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                          binfo.actual_seqlen_k - n_block * kBlockN);
    cute::cp_async_fence();

    // FLASH_NAMESPACE::cp_async_wait<0>();
    // __syncthreads();
    // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
    // __syncthreads();

    clear(acc_o);

    FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;

    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
    FLASH_NAMESPACE::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);

    // For performance reasons, we separate out two kinds of iterations:
    // those that need masking on S, and those that don't.
    // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
    // We will have at least 1 "masking" iteration.

    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
    constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 : ((Is_even_MN && Is_causal) ?
        cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        FLASH_NAMESPACE::cp_async_wait<0>();
        __syncthreads();

        // Advance gV
        if (masking_step > 0) {
            if (block_table == nullptr) {
                tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
            } else {
                const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
                const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
                const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
                const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
                tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
            }
            FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        } else {
            // Clear the smem tiles to account for predicated off loads
            FLASH_NAMESPACE::copy(
                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
            );
        }
        cute::cp_async_fence();

        FLASH_NAMESPACE::gemm(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
        // if (cute::thread0()) { print(acc_s); }
        if constexpr (Is_softcap){
            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
        }

        mask.template apply_mask(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
        );

        FLASH_NAMESPACE::cp_async_wait<0>();
        __syncthreads();
        // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
        // __syncthreads();

        if (n_block > n_block_min) {
            // Advance gK
            if (block_table == nullptr) {
                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            } else {
                const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
                const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
                const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
                const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
            }
            FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        // We have key_padding_mask so we'll need to Check_inf
        masking_step == 0
            ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2)
            : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2);
        // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }

        // Convert acc_s from fp32 to fp16/bf16
        Tensor rP = FLASH_NAMESPACE::convert_type(acc_s);
        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout()));
        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);

        // This check is at the end of the loop since we always have at least 1 iteration
        if (n_masking_steps > 1 && n_block <= n_block_min) {
            --n_block;
            break;
        }
    }

    // These are the iterations where we don't need masking on S
    for (; n_block >= n_block_min; --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        FLASH_NAMESPACE::cp_async_wait<0>();
        __syncthreads();
        // Advance gV
        if (block_table == nullptr) {
            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
        } else {
            const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
            const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
            const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
            const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
            tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
        }
        FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        cute::cp_async_fence();

        FLASH_NAMESPACE::gemm(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
            smem_thr_copy_Q, smem_thr_copy_K
        );
        if constexpr (Is_softcap){
            FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
        }

        FLASH_NAMESPACE::cp_async_wait<0>();
        __syncthreads();
        if (n_block > n_block_min) {
            // Advance gK
            if (block_table == nullptr) {
                tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            } else {
                const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
                const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
                const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
                const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
                tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
            }
            FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        mask.template apply_mask(
            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
        );
        softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2);

        Tensor rP = FLASH_NAMESPACE::convert_type(acc_s);
        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout()));
        FLASH_NAMESPACE::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
    }

    // Epilogue

    Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax);
    // if (cute::thread0()) { print(lse); }

    Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
    using SmemTiledCopyO = std::conditional_t<
        !Split,
        typename Kernel_traits::SmemCopyAtomO,
        typename Kernel_traits::SmemCopyAtomOaccum
    >;
    auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
    auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor rO = FLASH_NAMESPACE::convert_type(acc_o);
    Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

    // sOaccum is larger than sQ, so we need to syncthreads here
    // TODO: allocate enough smem for sOaccum
    if constexpr (Split) { __syncthreads(); }

    cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);

    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
                                     + m_block * kBlockM) * params.d_rounded;
    const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
            ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
        ) + m_block * kBlockM;

    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
                                 Shape, Int>{},
                                 make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
                                   Shape>{}, Stride<_1>{});
    // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }

    GmemTiledCopyO gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);

    __syncthreads();

    Tensor tOrOaccum = make_tensor(shape(tOgOaccum));
    cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);

    Tensor caccO = make_identity_tensor(Shape, Int>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor taccOcO = thr_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
    static_assert(decltype(size<0>(taccOcO))::value == 4);
    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M
    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
        }
    }

    // Construct identity layout for sO
    Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    // Repeat the partitioning with identity layouts
    Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);                           // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    FLASH_NAMESPACE::copy(
        gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO,
binfo.actual_seqlen_q - m_block * kBlockM ); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. const int bidb = blockIdx.y; // The block index for the head. const int bidh = blockIdx.z; // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting // them to have the same number of threads or have to traverse the attention matrix // in the same order. // In the Philox RNG, we use the offset to store the batch, head, and the lane id // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; // The block index for the head. const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? 
gridDim.y : 1; FLASH_NAMESPACE::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { using Element = typename Kernel_traits::Element; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; constexpr int kMaxSplits = 1 << Log_max_splits; constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNThreads = Kernel_traits::kNThreads; static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); static_assert(kNThreads == 128, "We assume that each block has 128 threads"); // Shared memory. // kBlockM + 1 instead of kBlockM to reduce bank conflicts. __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; // The thread and block index. const int tidx = threadIdx.x; const int bidx = blockIdx.x; const index_t lse_size = params.b * params.h * params.seqlen_q; const index_t row_offset_lse = bidx * kBlockM; Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), Shape, Int>{}, make_stride(lse_size, _1{})); // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. Layout flat_layout = make_layout(lse_size); Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); auto transposed_stride = params.seqlenq_ngroups_swapped ? 
make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; // Read the LSE values from gmem and store them in shared memory, then transpose them. constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; #pragma unroll for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; const int col = tidx % kBlockM; ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; if (row < kMaxSplits) { sLSE[row][col] = lse; } // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } } // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } __syncthreads(); Tensor lse_accum = make_tensor(Shape>{}); constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, // kBlockM rows, so each time we load we can load 128 / kBlockM rows). 
    // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
    // static_assert(kThreadsPerSplit <= 32);
    static_assert(kRowsPerLoadTranspose <= 32);
    static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
    }

    // Compute the logsumexp of the LSE along the split dimension.
    ElementAccum lse_max = lse_accum(0);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
    MaxOp max_op;
    lse_max = Allreduce::run(lse_max, max_op);
    lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
    float lse_sum = expf(lse_accum(0) - lse_max);
    #pragma unroll
    for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
    SumOp sum_op;
    lse_sum = Allreduce::run(lse_sum, sum_op);
    // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
    // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
    ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
    if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
        if (params.unpadded_lse) {
            const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
            if (lse_offset < lse_size) {
                gLSE_unpadded(lse_offset) = lse_logsum;
            }
        } else {
            gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
        }
    }
    // Store the scales exp(lse - lse_logsum) in shared memory.
    #pragma unroll
    for (int l = 0; l < kNLsePerThread; ++l) {
        const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
        const int col = tidx / kRowsPerLoadTranspose;
        if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); }
    }
    __syncthreads();

    const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
    Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum),
                                 Shape, Int>{},
                                 Stride, _1>{});
    constexpr int kBlockN = kNThreads / kBlockM;
    using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout>{}));  // Val layout, 4 vals per store
    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
    Tensor tOrO = make_tensor(shape(tOgOaccum));
    Tensor tOrOaccum = make_tensor(shape(tOgOaccum));
    clear(tOrO);

    // Predicates
    Tensor cOaccum = make_identity_tensor(Shape, Int>{});
    // Repeat the partitioning with identity layouts
    Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
    Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
    }
    // Load Oaccum in then scale and accumulate to O
    for (int split = 0; split < params.num_splits; ++split) {
        FLASH_NAMESPACE::copy(
            gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
        );
        #pragma unroll
        for (int m = 0; m < size<1>(tOrOaccum); ++m) {
            int row = get<0>(tOcOaccum(0, m, 0));
            ElementAccum lse_scale = sLSE[split][row];
            #pragma unroll
            for (int k = 0; k < size<2>(tOrOaccum); ++k) {
                #pragma unroll
                for (int i = 0; i < size<0>(tOrOaccum); ++i) {
                    tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
                }
            }
            // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); }
        }
        tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
    }
    // if (cute::thread0()) { print_tensor(tOrO); }

    Tensor rO = FLASH_NAMESPACE::convert_type(tOrO);
    // Write to gO
    #pragma unroll
    for (int m = 0; m < size<1>(rO); ++m) {
        const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
        if (idx < params.b * params.h * params.seqlen_q) {
            const int batch_idx = idx / (params.h * params.seqlen_q);
            const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
            // The index to the rows of Q
            const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
            auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride
                + head_idx * params.o_head_stride + row * params.o_row_stride;
            #pragma unroll
            for (int k = 0; k < size<2>(rO); ++k) {
                if (Is_even_K || tOpOaccum(k)) {
                    const int col = get<1>(tOcOaccum(0, m, k));
                    Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
                                            Shape(rO))::value>>{}, Stride<_1>{});
                    // TODO: Should check if this is using vectorized store, but it seems pretty fast
                    copy(rO(_, m, k), gO);
                    // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
                    // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k);
                }
            }
        }
    }
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_launch_template.h
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include "namespace_config.h"
#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include "static_switch.h"
#include "hardware_info.h"
#include "flash.h"
#include "flash_fwd_kernel.h"

namespace FLASH_NAMESPACE {

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
template \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
    #if defined(ARCH_SUPPORTS_FLASH)
        static_assert(!(Is_causal && Is_local)); // Enforce constraints
        FLASH_NAMESPACE::compute_attn(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
    #if defined(ARCH_SUPPORTS_FLASH)
        FLASH_NAMESPACE::compute_attn_splitkv(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
    static_assert(Log_max_splits >= 1);
    FLASH_NAMESPACE::combine_attn_seqk_parallel(params);
}

template
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    // printf("smem_size = %d\n", smem_size);

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21

    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.b, params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                            // Will only return softmax if dropout, to reduce compilation time.
                            // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                            // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                            // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                            // If Is_local, set Is_causal to false
                            auto kernel = &flash_fwd_kernel;
                            // auto kernel = &flash_fwd_kernel;
                            // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                            // auto kernel = &flash_fwd_kernel;
                            if (smem_size >= 48 * 1024) {
                                C10_CUDA_CHECK(cudaFuncSetAttribute(
                                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                            }
                            // int ctas_per_sm;
                            // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                            //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                            // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                            kernel<<>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();
                        });
                    });
                });
            });
        });
    });
}

template
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
    static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
                        ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                            SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                                // If Is_local, set Is_causal to false
                                auto kernel = &flash_fwd_splitkv_kernel;
                                // auto kernel = &flash_fwd_splitkv_kernel;
                                // auto kernel = &flash_fwd_splitkv_kernel;
                                if (smem_size >= 48 * 1024) {
                                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                                }
                                kernel<<>>(params);
                                C10_CUDA_KERNEL_LAUNCH_CHECK();
                            });
                        });
                    });
                });
            });
        });
    });
    if (params.num_splits > 1) {
        // We want kBlockM to be as small as possible for more parallelism.
        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
        constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            if (params.num_splits <= 2) {
                flash_fwd_splitkv_combine_kernel<<>>(params);
            } else if (params.num_splits <= 4) {
                flash_fwd_splitkv_combine_kernel<<>>(params);
            } else if (params.num_splits <= 8) {
                flash_fwd_splitkv_combine_kernel<<>>(params);
            } else if (params.num_splits <= 16) {
                flash_fwd_splitkv_combine_kernel<<>>(params);
            } else if (params.num_splits <= 32) {
                flash_fwd_splitkv_combine_kernel<<>>(params);
            } else if (params.num_splits <= 64) {
                flash_fwd_splitkv_combine_kernel<<>>(params);
            } else if (params.num_splits <= 128) {
                flash_fwd_splitkv_combine_kernel<<>>(params);
            }
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    }
}

template
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int kBlockM = 64;  // Fixed for all head dimensions
    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
    // and for headdim 192 with block size 64 x 128.
    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
    run_flash_splitkv_fwd, Is_causal>(params, stream);
}

template
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        run_flash_fwd, Is_dropout, Is_causal>(params, stream);
    });
}

template
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if constexpr(!Is_dropout) {
            // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
            // Using block size (64 x 256) is 27% slower for seqlen=2k
            // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        }
    });
}

template
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x = cc_major == 8 && cc_minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
        if (is_sm8x) {
            if constexpr(!Is_causal) {
                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            }
        } else {
            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        // These two are always slower
        // run_flash_fwd>(params, stream);
        // run_flash_fwd>(params, stream);
    });
}

template
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
    bool is_sm8x = cc_major == 8 && cc_minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if constexpr(!Is_dropout) {
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
            if (is_sm8x) {
                if constexpr(!Is_causal) {
                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // 1st ones are good for H100, A100
            // 2nd one is good for A6000 bc we get slightly better occupancy
        } else {
            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        }
    });
}

template
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if constexpr(!Is_dropout) {
            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        }
        // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        // run_flash_fwd>(params, stream);
        // run_flash_fwd>(params, stream);
        // run_flash_fwd>(params, stream);
    });
}

template
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_sm, max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
    status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
      C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        // For A100, we want to run with 128 x 64 (128KB smem).
        // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        }
        // 64 KB
        // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
        // 96 KB
        // run_flash_fwd, Is_dropout, Is_causal>(params, stream);
    });
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated.
See "generate_kernels.py" #include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu ================================================ // Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "namespace_config.h" #include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); } // namespace FLASH_NAMESPACE ================================================ FILE: csrc/flash_attn/src/generate_kernels.py ================================================ import argparse import itertools from dataclasses import dataclass from pathlib import Path from typing import List, Optional DTYPE_MAP = { "fp16": "cutlass::half_t", "bf16": "cutlass::bfloat16_t", } SM = [80] # Sm80 kernels support up to HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256] IS_CAUSAL = ["false", "true"] NAMESPACE_INCLUDE = '#include "namespace_config.h"\n' def get_fwd_template() -> str: return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE {{ template<> void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); }} }} // namespace FLASH_NAMESPACE""" def get_fwd_split_template() -> str: return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h" namespace FLASH_NAMESPACE {{ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); }} // namespace FLASH_NAMESPACE""" def get_bwd_template() -> str: return NAMESPACE_INCLUDE + 
"""#include "flash_bwd_launch_template.h" namespace FLASH_NAMESPACE {{ template<> void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); }} }} // namespace FLASH_NAMESPACE""" @dataclass class Kernel: sm: int dtype: str head_dim: int is_causal: bool direction: str @property def template(self) -> str: template_funcs = { "fwd": get_fwd_template, "bwd": get_bwd_template, "fwd_split": get_fwd_split_template } template_func = template_funcs[self.direction] return template_func().format( DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal ) @property def filename(self) -> str: return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" def get_all_kernels() -> List[Kernel]: for direction in ["fwd", "fwd_split", "bwd"]: for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction) def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: prelude = """// Copyright (c) 2024, Tri Dao. // Splitting the different head dimensions to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py"\n""" content = prelude + kernel.template (autogen_dir / kernel.filename).write_text(content) def main(output_dir: Optional[str]) -> None: if output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) for kernel in get_all_kernels(): write_kernel(kernel, output_dir) if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate_kernels", description="Generate the flash_attention kernels template instantiations", ) parser.add_argument( "-o", "--output_dir", required=False, help="Where to generate the kernels " " will default to the current directory ", ) args = parser.parse_args() main(args.output_dir) ================================================ FILE: csrc/flash_attn/src/hardware_info.h ================================================ /****************************************************************************** * Copyright (c) 2024, Tri Dao. ******************************************************************************/ #pragma once #include #if !defined(__CUDACC_RTC__) #include "cuda_runtime.h" #endif #define CHECK_CUDA(call) \ do { \ cudaError_t status_ = call; \ if (status_ != cudaSuccess) { \ fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, \ cudaGetErrorString(status_)); \ exit(1); \ } \ } while (0) inline int get_current_device() { int device; CHECK_CUDA(cudaGetDevice(&device)); return device; } inline std::tuple get_compute_capability(int device) { int capability_major, capability_minor; CHECK_CUDA(cudaDeviceGetAttribute(&capability_major, cudaDevAttrComputeCapabilityMajor, device)); CHECK_CUDA(cudaDeviceGetAttribute(&capability_minor, cudaDevAttrComputeCapabilityMinor, device)); return {capability_major, capability_minor}; } inline int get_num_sm(int device) { int multiprocessor_count; CHECK_CUDA(cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device)); return multiprocessor_count; } ================================================ FILE: 
csrc/flash_attn/src/kernel_traits.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include

using namespace cute;

template
struct Flash_kernel_traits {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = int64_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
#else
    using MMA_Atom_Arch = MMA_Atom;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom;
    using SmemCopyAtomTransposed = Copy_Atom;
#else
    using SmemCopyAtom = Copy_Atom;
    using SmemCopyAtomTransposed = Copy_Atom;
#endif
};

// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
        Tile, _16, _16>>;

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape, Int>{}));

    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape, Int>{}));

    // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

    using SmemLayoutAtomO = decltype(
        composition(Swizzle{},
                    Layout, Int>,
                           Stride, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape, Int>{}));
    using SmemCopyAtomO = Copy_Atom, Element>;
    using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>;

    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout, Int>,
                                  Stride, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL,
        AutoVectorizingCopyWithAssumedAlignment<128>
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per store

    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout>{}));  // Val layout, 4 vals per store
    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
    using GmemTiledCopyRotcossin = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout>{}));  // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinCont = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout>{}));  // Val layout, 8 vals per load
};

// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
// No_double_buffer is another option to reduce smem usage, but will slow things down.
template >
struct Flash_bwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Is_V_in_regs = Is_V_in_regs_;
    static constexpr bool No_double_buffer = No_double_buffer_;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
    static_assert(kNWarps % AtomLayoutMSdP == 0);
    static_assert(kNWarps % AtomLayoutNdKV == 0);
    static_assert(kNWarps % AtomLayoutMdQ == 0);

    using TiledMmaSdP = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout, Int, _1>>,
        Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;

    using TiledMmadKV = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout, Int, _1>>,
        Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;

    using TiledMmadQ = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout, Int, _1>>,  // 2x4x1 or 4x2x1 thread group
        Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;

    using SmemLayoutAtomQdO = decltype(
        composition(Swizzle{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutQdO = decltype(tile_to_shape(
        SmemLayoutAtomQdO{},
        make_shape(Int{}, Int{})));

    using SmemLayoutAtomKV = decltype(
        composition(Swizzle{},
                    Layout, Int>,
                           Stride, _1>>{}));
    using SmemLayoutKV = decltype(tile_to_shape(
        // SmemLayoutAtomQdO{},
        SmemLayoutAtomKV{},
        make_shape(Int{}, Int{})));

    using SmemLayoutKtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));

    // TODO: generalize to other values of kBlockN
    // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
    // static constexpr int kPBlockN = kBlockN;
    // Temporarily disabling this for hdim 256 on sm86 and sm89
    // static_assert(kBlockN >= 64);
    static_assert(kBlockN >= 32);
    // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
    static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
    static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
    // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
    static constexpr int kSwizzlePdS = 3;
    using SmemLayoutAtomPdS = decltype(
        composition(Swizzle{},
                    Layout, Int>,
                           Stride, _1>>{}));
    using SmemLayoutPdS = decltype(tile_to_shape(
        SmemLayoutAtomPdS{},
        make_shape(Int{}, Int{})));
    using SmemLayoutPdStransposed = decltype(
        composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));

    using SmemCopyAtomPdS = Copy_Atom, elem_type>;

    using SmemLayoutQdOtransposed = decltype(
        composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{})));
    using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));

    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutdKV = decltype(tile_to_shape(
        SmemLayoutAtomdKV{},
        make_shape(Int{}, Int{})));
    using SmemCopyAtomdKV = Copy_Atom, elem_type>;

    using SmemLayoutAtomdQ = decltype(
        composition(Swizzle{},
                    Layout>,
                           Stride, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int{}, Int{})));
    using SmemCopyAtomdQ = Copy_Atom, elem_type>;

    // Double buffer for sQ
    static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
    static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
    static constexpr int kSmemSize = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
    static constexpr int kSmemSize1colblock = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + kSmemPSize
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout, Int>,
                                  Stride, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL,
        AutoVectorizingCopyWithAssumedAlignment<128>
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per read
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom, elem_type>{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydKV = decltype(
        make_tiled_copy(Copy_Atom, elem_type>{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom, elem_type>{},
                        GmemLayoutAtom{},
                        Layout>{}));  // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout>{}));  // Val layout, 4 vals per store

    using GmemTiledCopydQaccumAtomicAdd = decltype(
        make_tiled_copy(Copy_Atom, ElementAccum>{},
                        Layout,  // Thread layout, 8 threads per row
                               Stride<_32, _1>>{},
                        Layout>{}));  // Val layout, 1 val per store

};

////////////////////////////////////////////////////////////////////////////////////////////////////

================================================
FILE: csrc/flash_attn/src/mask.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once
#include "namespace_config.h"
#include

namespace FLASH_NAMESPACE {

using namespace cute;

template
__forceinline__ __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k,
                                           const int col_idx_offset_ = 0) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
        const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
            const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                }
            }
        }
    }
}

template
__forceinline__ __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_,
                                                 const int max_seqlen_k, const int row_idx_offset,
                                                 const int max_seqlen_q, const int warp_row_stride,
                                                 const int window_size_left, const int window_size_right) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
            // if (cute::thread0()) {
            //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
            //     print(tensor(make_coord(i, mi), _));
            //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
            // }
        }
    }
}

template
__forceinline__ __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_,
                                                  const int max_seqlen_k, const int row_idx_offset,
                                                  const int max_seqlen_q, const int warp_row_stride) {
    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
    apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                     max_seqlen_q, warp_row_stride, -1, 0);
}

template
__forceinline__ __device__ void apply_mask_causal_w_idx(
    Tensor &tensor, Tensor const &idx_rowcol,
    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
                tensor(mi, ni) = -INFINITY;
            }
        }
        // if (cute::thread0()) {
        //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
        //     print(tensor(_, make_coord(j, ni)));
        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
        // }
    }
}

template
struct Mask {

    const int max_seqlen_k, max_seqlen_q;
    const int window_size_left, window_size_right;
    const float alibi_slope;

    __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
                                    const int window_size_left, const int window_size_right,
                                    const float alibi_slope=0.f)
        : max_seqlen_k(max_seqlen_k)
        , max_seqlen_q(max_seqlen_q)
        , window_size_left(window_size_left)
        , window_size_right(window_size_right)
        , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
    };

    // Causal_mask: whether this particular iteration needs causal masking
    template
    __forceinline__ __device__ void apply_mask(Tensor &tensor_,
                                               const int col_idx_offset_,
                                               const int row_idx_offset,
                                               const int warp_row_stride) {
        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
        // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
        if constexpr (Need_masking) {
            // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
            Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout()));
            // Do we need both row and column indices, or just column indices?
            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
            const int lane_id = threadIdx.x % 32;
            const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
            if constexpr (Col_idx_only) {
                #pragma unroll
                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                    const int col_idx_base = col_idx_offset + nj * 8;
                    #pragma unroll
                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
                        const int col_idx = col_idx_base + j;
                        #pragma unroll
                        for (int mi = 0; mi < size<0>(tensor); ++mi) {
                            // No causal, no local
                            if constexpr (Has_alibi) {
                                tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                            }
                            if constexpr (!Is_even_MN) {
                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
                            }
                        }
                    }
                }
            } else {
                #pragma unroll
                for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                    #pragma unroll
                    for (int i = 0; i < size<0, 0>(tensor); ++i) {
                        const int row_idx = row_idx_base + i * 8;
                        const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
                        const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                        #pragma unroll
                        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                            const int col_idx_base = col_idx_offset + nj * 8;
                            #pragma unroll
                            for (int j = 0; j < size<1, 0>(tensor); ++j) {
                                const int col_idx = col_idx_base + j;
                                if constexpr (Has_alibi) {
                                    if constexpr (Is_causal) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
                                    } else {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                                    }
                                }
                                if constexpr (Causal_mask) {
                                    if (col_idx >= col_idx_limit_right) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                                if constexpr (Is_local) {
                                    if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                                if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
                                    // Causal and Local already handle MN masking
                                    if (col_idx >= max_seqlen_k) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    };

};

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/namespace_config.h
================================================
/**
 * @file namespace_config.h
 * @brief Configuration file for Flash namespace management and isolation
 *
 * This header provides configuration macros for managing the Flash namespace
 * across a codebase. It allows for flexible namespace naming and provides
 * utilities for namespace declaration and scoping.
 *
 * Usage Examples:
 *
 * 1. Basic namespace wrapping:
 * @code
 * namespace FLASH_NAMESPACE {
 * class FlashDevice {
 *   // Implementation
 * };
 * } // namespace FLASH_NAMESPACE
 * @endcode
 *
 * 2. Accessing types within the namespace:
 * @code
 * FLASH_NAMESPACE_ALIAS(FlashDevice) device;
 * @endcode
 *
 * 3. Defining content within namespace scope:
 * @code
 * FLASH_NAMESPACE_SCOPE(
 *   struct Configuration {
 *     uint32_t size;
 *     bool enabled;
 *   };
 * )
 * @endcode
 *
 * 4. Custom namespace name:
 * @code
 * #define FLASH_NAMESPACE custom_flash
 * #include "namespace_config.h"
 * @endcode
 *
 * Configuration:
 * - The default namespace is 'flash' if FLASH_NAMESPACE is not defined
 * - Define FLASH_NAMESPACE before including this header to customize the
 *   namespace name
 *
 * Best Practices:
 * - Include this header in all files that need access to the Flash namespace
 *
 */
#pragma once

#ifndef FLASH_NAMESPACE_CONFIG_H
#define FLASH_NAMESPACE_CONFIG_H

// Set default namespace to flash
#ifndef FLASH_NAMESPACE
#define FLASH_NAMESPACE flash
#endif

#define FLASH_NAMESPACE_ALIAS(name) FLASH_NAMESPACE::name

#define FLASH_NAMESPACE_SCOPE(content) \
    namespace FLASH_NAMESPACE {        \
    content                            \
    }

#endif // FLASH_NAMESPACE_CONFIG_H

================================================
FILE: csrc/flash_attn/src/philox.cuh
================================================
// PyTorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
#pragma once
// Philox CUDA.

#include "namespace_config.h"

namespace FLASH_NAMESPACE {

struct ull2 {
    unsigned long long x;
    unsigned long long y;
};

__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
    uint2 *res;
    unsigned long long tmp;
    asm ("mul.wide.u32 %0, %1, %2;\n\t"
          : "=l"(tmp)
          : "r"(a), "r"(b));
    res = (uint2*)(&tmp);
    return *res;
}

__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
    constexpr unsigned long kPhiloxSA = 0xD2511F53;
    constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
    uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
    uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
    uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
    return ret;
}

__forceinline__ __device__ uint4 philox(unsigned long long seed,
                                        unsigned long long subsequence,
                                        unsigned long long offset) {
    constexpr unsigned long kPhilox10A = 0x9E3779B9;
    constexpr unsigned long kPhilox10B = 0xBB67AE85;
    uint2 key = reinterpret_cast<uint2&>(seed);
    uint4 counter;
    ull2 *tmp = reinterpret_cast<ull2*>(&counter);
    tmp->x = offset;
    tmp->y = subsequence;
    #pragma unroll
    for (int i = 0; i < 6; i++) {
        counter = philox_single_round(counter, key);
        key.x += (kPhilox10A);
        key.y += (kPhilox10B);
    }
    uint4 output = philox_single_round(counter, key);
    return output;
}

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/philox_unpack.cuh
================================================
// This is purely so that it works with torch 2.1. For torch 2.2+ we can include ATen/cuda/PhiloxUtils.cuh

#pragma once
#include

================================================
FILE: csrc/flash_attn/src/rotary.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include

#include "namespace_config.h"
#include "utils.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace FLASH_NAMESPACE {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S,
                                                        Tensor &D,
                                                        Tensor const &Cos,
                                                        Tensor const &Sin,
                                                        Tensor const &identity_MN,
                                                        const int max_MN, const int min_MN,
                                                        const int dim, const int rotary_dim) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));  // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));  // MMA_K
    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));  // MMA_K
    static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
    Tensor rCos = make_fragment_like(Cos);
    Tensor rSin = make_fragment_like(Sin);
    Tensor rS = make_fragment_like(S);
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
                    cute::copy(S(_, m, k), rS(_, m, k));
                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
                        cute::copy(Cos(_, m, k), rCos(_, m, k));
                        cute::copy(Sin(_, m, k), rSin(_, m, k));
                        Tensor S_fp32 = convert_type(rS(_, m, k));
                        Tensor cos_fp32 = convert_type(rCos(_, m, k));
                        Tensor sin_fp32 = convert_type(rSin(_, m, k));
                        #pragma unroll
                        for (int i = 0; i < size<0>(rS) / 2; ++i) {
                            float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
                            float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
                            S_fp32(2 * i) = real;
                            S_fp32(2 * i + 1) = imag;
                        }
                        // Idk but I need to copy for the convert_type to work
                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
                        cute::copy(S_fp32, S_fp32_copy);
                        using T = typename Engine0::value_type;
                        Tensor S_og_type = convert_type(S_fp32_copy);
                        cute::copy(S_og_type, rS(_, m, k));
                    }
                    cute::copy(rS(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S,
                                                       Tensor &D,
                                                       Tensor const &Cos,
                                                       Tensor const &Sin,
                                                       Tensor const &identity_MN,
                                                       const int max_MN, const int min_MN,
                                                       const int dim, const int rotary_dim) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));  // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));  // MMA_K
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));  // MMA
    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
    Tensor rCos = make_fragment_like(Cos);
    Tensor rSin = make_fragment_like(Sin);
    Tensor rS = make_fragment_like(S);
    Tensor rS_other = make_fragment_like(rS(_, 0, 0));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
                    cute::copy(S(_, m, k), rS(_, m, k));
                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
                        const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
                        Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
                        cute::copy(gS_other, rS_other);
                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
                        Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
                        Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
                        cute::copy(gCos, rCos(_, m, k));
                        cute::copy(gSin, rSin(_, m, k));
                        // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
                        Tensor S_fp32 = convert_type(rS(_, m, k));
                        Tensor S_other_fp32 = convert_type(rS_other);
                        Tensor cos_fp32 = convert_type(rCos(_, m, k));
                        Tensor sin_fp32 = convert_type(rSin(_, m, k));
                        #pragma unroll
                        for (int i = 0; i < size<0>(rS); ++i) {
                            S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
                        }
                        // Idk but I need to copy for the convert_type to work
                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
                        cute::copy(S_fp32, S_fp32_copy);
                        using T = typename Engine0::value_type;
                        Tensor S_og_type = convert_type(S_fp32_copy);
                        cute::copy(S_og_type, rS(_, m, k));
                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
                    }
                    cute::copy(rS(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/softmax.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include
#include
#include

#include "namespace_config.h"
#include "philox.cuh"
#include "utils.h"

namespace FLASH_NAMESPACE {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); mi++) {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}

template
__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int i = 0; i < size(dst); i++){
        dst(i) = Allreduce<4>::run(src(i), op);
    }
}

template
__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) {
    thread_reduce_(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}

template
__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){
    MaxOp max_op;
    reduce_(tensor, max, max_op);
}

template
__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){
    SumOp sum_op;
    thread_reduce_(tensor, sum, sum_op);
}

// Apply the exp to all the elements.
template
__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            // The following macro will disable the use of fma.
            // See: https://github.com/pytorch/pytorch/issues/121558 for more details
            // This macro is set in PyTorch and not FlashAttention
            #ifdef UNFUSE_FMA
                tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
            #else
                tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            #endif
        }
    }
}

// Apply the exp to all the elements.
template
__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        MaxOp max_op;
        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            max(mi) = max_op(max(mi), tensor(mi, ni));
        }
        max(mi) = Allreduce<4>::run(max(mi), max_op);
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
        sum(mi) = 0;
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            sum(mi) += tensor(mi, ni);
        }
        SumOp sum_op;
        sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template
struct Softmax {

    using TensorT = decltype(make_tensor(Shape>{}));
    TensorT row_max, row_sum;

    __forceinline__ __device__ Softmax() {};

    template
    __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        if (Is_first) {
            FLASH_NAMESPACE::template reduce_max(scores, row_max);
            FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            FLASH_NAMESPACE::reduce_sum(scores, row_sum);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            FLASH_NAMESPACE::template reduce_max(scores, row_max);
            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
            Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout()));
            static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
            #pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi) {
                float scores_max_cur = !Check_inf
                    ? row_max(mi)
                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                row_sum(mi) *= scores_scale;
                #pragma unroll
                for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
            }
            FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            FLASH_NAMESPACE::reduce_sum(scores, row_sum);
        }
    };

    template
    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
        SumOp sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        TensorT lse = make_fragment_like(row_sum);
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
        #pragma unroll
        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
            float sum = row_sum(mi);
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
        }
        return lse;
    };
};

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn/src/static_switch.h
================================================
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

#pragma once

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function(...);
/// });
/// ```

#define BOOL_SWITCH(COND, CONST_NAME, ...) \
  [&] { \
    if (COND) { \
      constexpr static bool CONST_NAME = true; \
      return __VA_ARGS__(); \
    } else { \
      constexpr static bool CONST_NAME = false; \
      return __VA_ARGS__(); \
    } \
  }()

#ifdef FLASHATTENTION_DISABLE_DROPOUT
  #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
  [&] { \
    constexpr static bool CONST_NAME = false; \
    return __VA_ARGS__(); \
  }()
#else
  #define DROPOUT_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_ALIBI
  #define ALIBI_SWITCH(COND, CONST_NAME, ...) \
  [&] { \
    constexpr static bool CONST_NAME = false; \
    return __VA_ARGS__(); \
  }()
#else
  #define ALIBI_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  #define EVENK_SWITCH(COND, CONST_NAME, ...) \
  [&] { \
    constexpr static bool CONST_NAME = true; \
    return __VA_ARGS__(); \
  }()
#else
  #define EVENK_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_SOFTCAP
  #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
  [&] { \
    constexpr static bool CONST_NAME = false; \
    return __VA_ARGS__(); \
  }()
#else
  #define SOFTCAP_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_LOCAL
  #define LOCAL_SWITCH(COND, CONST_NAME, ...) \
  [&] { \
    constexpr static bool CONST_NAME = false; \
    return __VA_ARGS__(); \
  }()
#else
  #define LOCAL_SWITCH BOOL_SWITCH
#endif

#define FP16_SWITCH(COND, ...) \
  [&] { \
    if (COND) { \
      using elem_type = cutlass::half_t; \
      return __VA_ARGS__(); \
    } else { \
      using elem_type = cutlass::bfloat16_t; \
      return __VA_ARGS__(); \
    } \
  }()

#define HEADDIM_SWITCH(HEADDIM, ...) \
  [&] { \
    if (HEADDIM <= 32) { \
      constexpr static int kHeadDim = 32; \
      return __VA_ARGS__(); \
    } else if (HEADDIM <= 64) { \
      constexpr static int kHeadDim = 64; \
      return __VA_ARGS__(); \
    } else if (HEADDIM <= 96) { \
      constexpr static int kHeadDim = 96; \
      return __VA_ARGS__(); \
    } else if (HEADDIM <= 128) { \
      constexpr static int kHeadDim = 128; \
      return __VA_ARGS__(); \
    } else if (HEADDIM <= 192) { \
      constexpr static int kHeadDim = 192; \
      return __VA_ARGS__(); \
    } else if (HEADDIM <= 256) { \
      constexpr static int kHeadDim = 256; \
      return __VA_ARGS__(); \
    } \
  }()

================================================
FILE: csrc/flash_attn/src/utils.h
================================================
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include
#include
#include
#include

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include
#endif

#include
#include
#include
#include
#include

#include "namespace_config.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace FLASH_NAMESPACE {

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ uint32_t relu2(const uint32_t x);

template<>
__forceinline__ __device__ uint32_t relu2(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
    asm volatile( \
        "{\n" \
        "\t .reg .f16x2 sela;\n" \
        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
        "\t and.b32 %0, sela, %1;\n"
        "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
    return res;
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__forceinline__ __device__ uint32_t relu2(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
    asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
    return res;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

template
__forceinline__ __device__ uint32_t convert_relu2(const float2 x);

template<>
__forceinline__ __device__ uint32_t convert_relu2(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast(x.x);
    const uint32_t b = reinterpret_cast(x.y);
    asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}

template<>
__forceinline__ __device__ uint32_t convert_relu2(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast(x.x);
    const uint32_t b = reinterpret_cast(x.y);
    asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}

#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                                     Tensor4 const& tCsB, TiledMma tiled_mma,
                                     TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                                     ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));  // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                        TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                        ThrCopy smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));  // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));  // MMA_K
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
template
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
    using X = Underscore;
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
    if constexpr (mma_shape_K == 8) {
        return acc_layout;
    } else {
        auto l = logical_divide(acc_layout, Shape{});  // (4, MMA_M, (2, MMA_N / 2)))
        return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
template
__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
    using X = Underscore;
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape{});  // (4, MMA_M, (2, MMA_N / 2)))
    return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ auto convert_type(Tensor const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast *>(tensor.data()));
    return make_tensor(make_rmem_ptr(&frag), tensor.layout());
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ void relu_(Tensor &tensor) {
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
    using value_t = typename Engine::value_type;
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_uint32 = recast(tensor);
    #pragma unroll
    for (int i = 0; i < size(tensor_uint32); ++i) {
        tensor_uint32(i) = relu2(tensor_uint32(i));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template
__forceinline__ __device__ auto convert_type_relu(Tensor const &tensor) {
    using From_type = typename Engine::value_type;
    static_assert(std::is_same_v || std::is_same_v);
    static_assert(std::is_same_v);
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_float2 = recast(tensor);
    Tensor out_uint32 = make_tensor(tensor_float2.layout());
    #pragma unroll
    for (int i = 0; i < size(out_uint32); ++i) {
        out_uint32(i) = convert_relu2(tensor_float2(i));
    }
    Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout());
#else
    Tensor out = FLASH_NAMESPACE::convert_type(tensor);
    FLASH_NAMESPACE::relu_(out);
#endif
    return out;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S,
                                     Tensor &D, Tensor const &identity_MN,
                                     Tensor const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
    // TD [2023-04-13]: Strange that the code below can cause race condition.
    // I think it's because the copies are under an if statement.
    // if (Is_even_K) {
    //     #pragma unroll
    //     for (int m = 0; m < size<1>(S); ++m) {
    //         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
    //             copy(tiled_copy, S(_, m, _), D(_, m, _));
    //         } else if (Clear_OOB_MN) {
    //             clear(D(_, m, _));
    //         }
    //     }
    // } else {  // It's slightly faster in this case if iterate over K first
    //     #pragma unroll
    //     for (int k = 0; k < size<2>(S); ++k) {
    //         if (predicate_K(k)) {
    //             #pragma unroll
    //             for (int m = 0; m < size<1>(S); ++m) {
    //                 if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
    //                     copy(tiled_copy, S(_, m, k), D(_, m, k));
    //                 } else if (Clear_OOB_MN) {
    //                     clear(D(_, m, k));
    //                 }
    //             }
    //         } else if (Clear_OOB_K) {  // There's no case where !Clear_OOB_K && Clear_OOB_MN
    //             if (Clear_OOB_MN || Is_even_MN) {
    //                 clear(D(_, _, k));
    //             } else {
    //                 #pragma unroll
    //                 for (int m = 0; m < size<1>(S); ++m) {
    //                     if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
    //                         clear(D(_, m, k));
    //                     }
    //                 }
    //             }
    //         }
    //     }
    // }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ void copy_w_min_idx(Tensor const &S,
                                               Tensor &D, Tensor const &identity_MN,
                                               Tensor const &predicate_K,
                                               const int max_MN=0, const int min_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(S(_, m, k), D(_, m, k));
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template
__forceinline__ __device__ void apply_softcap(Tensor &tensor, const float softcap){
    #pragma unroll
    for (int i = 0; i < size(tensor); ++i) {
        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
    }
}

template
__forceinline__ __device__ void calculate_dtanh(Tensor &src_tensor, Tensor &dst_tensor, const float softcap){
    #pragma unroll
    for (int i = 0; i < size(src_tensor); ++i) {
        dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace FLASH_NAMESPACE

================================================
FILE: csrc/flash_attn_ck/flash_api.cpp
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"

std::vector mha_fwd(at::Tensor &q,  // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
        const at::Tensor &k,  // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
        const at::Tensor &v,  // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
        std::optional &out_,  // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
        std::optional &alibi_slopes_,  // num_heads or batch_size x num_heads
        const float p_dropout,
        const float softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool return_softmax,
        std::optional gen_);

std::vector mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        std::optional &out_,  // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &cu_seqlens_q,  // b+1
        const at::Tensor &cu_seqlens_k,  // b+1
        std::optional &seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
        std::optional &leftpad_k_,  // batch_size
        std::optional &block_table_,  // batch_size x max_num_blocks_per_seq
        std::optional &alibi_slopes_,  // num_heads or b x num_heads
        int max_seqlen_q,
        const int max_seqlen_k,
        const float p_dropout,
        const float softmax_scale,
        const bool zero_tensors,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool return_softmax,
        std::optional gen_);

std::vector mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
        const at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,  // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,  // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,  // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,  // b x h x seqlen_q
        std::optional &dq_,  // batch_size x seqlen_q x num_heads x head_size
        std::optional &dk_,  // batch_size x seqlen_k x num_heads_k x head_size
        std::optional &dv_,  // batch_size x seqlen_k x num_heads_k x head_size
        std::optional &alibi_slopes_,  // num_heads or batch_size x num_heads
        const float p_dropout,  // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        std::optional gen_,
        std::optional &rng_state);

std::vector mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
        const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &out,  // total_q x num_heads x head_size
        const at::Tensor &softmax_lse,  // b x h x s softmax logsumexp
        std::optional &dq_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        std::optional &dk_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
        std::optional &dv_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &cu_seqlens_q,  // b+1
        const at::Tensor &cu_seqlens_k,  // b+1
        std::optional &alibi_slopes_,  // num_heads or b x num_heads
        const int max_seqlen_q,
        const int max_seqlen_k,  // max sequence length to choose the kernel
        const float p_dropout,  // probability to drop
        const float softmax_scale,
        const bool zero_tensors,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        std::optional gen_,
        std::optional &rng_state);

std::vector mha_fwd_kvcache(at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &kcache,  // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
        const at::Tensor &vcache,  // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
std::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size std::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size std::optional &seqlens_k_, // batch_size std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) std::optional &cache_batch_idx_, // indices to index into the KV cache std::optional &leftpad_k_, // batch_size std::optional &block_table_, // batch_size x max_num_blocks_per_seq std::optional &alibi_slopes_, // num_heads or batch_size x num_heads std::optional &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, int window_size_left, int window_size_right, const float softcap, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashAttention"; m.def("fwd", &mha_fwd, "Forward pass"); m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)"); m.def("bwd", &mha_bwd, "Backward pass"); m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)"); m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache"); } ================================================ FILE: csrc/flash_attn_ck/flash_common.cpp ================================================ /****************************************************************************** * Copyright (c) 2024, Tri Dao. 
 ******************************************************************************/

#include "flash_common.hpp"

namespace flash {
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess)
        return num_splits;

    hipDeviceProp_t props{};
    status = hipGetDeviceProperties(&props, device);
    if(status != hipSuccess)
        return num_splits;

    // TODO - tile size should match the TileFmhaShape, hardcode for now
    const int kM0 = 128;
    const int kN1 = hdim_v;

    const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
    // Note: because kN1 == hdim_v, num_n_blocks always evaluates to 1, so
    // num_splits_heuristic_ck() below currently always returns a single split.
    const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;

    if(num_splits < 1 && p_drop == 0.0f)
        return num_splits_heuristic_ck(
            batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);

    return num_splits;
}

} // namespace flash

================================================
FILE: csrc/flash_attn_ck/flash_common.hpp
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include
#include
#include
#include

#ifdef OLD_GENERATOR_PATH
#include
#else
#include
#endif

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

namespace flash {
inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state)
{
    // Adapted from PyTorch
    // https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
    if (arg.captured_) {
        rng_state[0] = static_cast(*arg.seed_.ptr);
        rng_state[1] = static_cast(*(arg.offset_.ptr) + arg.offset_intragraph_);
    } else {
        rng_state[0] = arg.seed_.val;
        rng_state[1] = arg.offset_.val;
    }
}

inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float max_efficiency = 0.f;
    std::vector efficiency;
    efficiency.reserve(max_splits);
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
    // (i.e. it's 11 splits anyway).
    // So we check if the number of blocks per split is the same as the previous num_splits.
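    // Illustrative worked example of the wave-efficiency rule applied below
    // (numbers chosen for illustration only): with batch_nheads_mblocks = 16,
    // num_SMs = 64 and num_n_blocks = 128, num_splits = 1, 2, 3, 4, 5 give
    // n_waves = 0.25, 0.5, 0.75, 1.0, 1.25 and eff = n_waves / ceil(n_waves)
    // = 0.25, 0.5, 0.75, 1.0, 0.625. The best efficiency is 1.0, and the smallest
    // eligible num_splits with eff >= 0.85 * 1.0 is 4, so 4 splits are chosen.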
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) {
            efficiency.push_back(0.f);
        } else {
            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
            float eff = n_waves / ceil(n_waves);
            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
            if (eff > max_efficiency) { max_efficiency = eff; }
            efficiency.push_back(eff);
        }
    }
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}

int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);

} // namespace flash

================================================
FILE: csrc/flash_attn_ck/mha_bwd.cpp
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"
#include "fmha_bwd.hpp"
#include "mask.hpp"

fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                                       std::string dtype,
                                       int seqlen_q,
                                       int seqlen_k,
                                       int batch,
                                       int head_size,
                                       int nhead_q,
                                       int nhead_k,
                                       bool has_dropout,
                                       bool enable_alibi,
                                       bool deterministic)
{
    return fmha_bwd_traits{seqlen_q,
                           seqlen_k,
                           batch,
                           seqlen_q,  // max_seqlen_q
                           seqlen_k,  // max_seqlen_k
                           head_size, // hdim_q
                           head_size, // hdim_k
                           nhead_q,
                           nhead_k,
                           dtype,
                           false, // is_group_mode
                           mask.type,
                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                           false, // has_dbias
                           has_dropout,
                           false, // s_randval
                           deterministic};
}

fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                   // sizes
                                   const int b,
                                   const int seqlen_q,
                                   const int seqlen_k,
                                   const int h,
                                   const int h_k,
                                   const int hdim,
                                   // device pointers
                                   const at::Tensor q,
                                   const at::Tensor k,
                                   const at::Tensor v,
                                   std::optional &alibi_slopes_,
                                   const at::Tensor out,
                                   const at::Tensor softmax_lse,
                                   const at::Tensor dout,
                                   at::Tensor dq_acc,
                                   at::Tensor d,
                                   at::Tensor dq,
                                   at::Tensor dk,
                                   at::Tensor dv,
                                   float softmax_scale,
                                   float p_dropout,
                                   std::pair drop_seed_offset)
{
    // q: (batch_size, seqlen_q, nheads, hdim)
    ck_tile::index_t batch_stride_q = q.stride(0);
    ck_tile::index_t stride_q = q.stride(1);
    ck_tile::index_t nhead_stride_q = q.stride(2);

    // k: (batch_size, seqlen_k, nheads_k, hdim)
    ck_tile::index_t batch_stride_k = k.stride(0);
    ck_tile::index_t stride_k = k.stride(1);
    ck_tile::index_t nhead_stride_k = k.stride(2);

    // v: (batch_size, seqlen_k, nheads_k, hdim)
    ck_tile::index_t batch_stride_v = v.stride(0);
    ck_tile::index_t stride_v = v.stride(1);
    ck_tile::index_t nhead_stride_v = v.stride(2);

    // o: (batch_size, seqlen_q, nheads, hdim)
    ck_tile::index_t batch_stride_o = out.stride(0);
    ck_tile::index_t stride_o = out.stride(1);
    ck_tile::index_t nhead_stride_o = out.stride(2);

    // lse: (batch_size, nheads, seqlen_q)
    ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);

    // do: (batch_size, seqlen_q, nheads, hdim)
    ck_tile::index_t batch_stride_do = dout.stride(0);
    ck_tile::index_t stride_do = dout.stride(1);
    ck_tile::index_t nhead_stride_do = dout.stride(2);

    // d: (batch_size, nheads, seqlen_q)
    // CK assumes d shares the same strides as lse

    // dq: (batch_size, seqlen_q, nheads, hdim)
    ck_tile::index_t batch_stride_dq = dq.stride(0);
    ck_tile::index_t stride_dq = dq.stride(1);
    ck_tile::index_t nhead_stride_dq = dq.stride(2);

    // dk_expanded: (batch_size, seqlen_k, nheads, hdim)
    ck_tile::index_t batch_stride_dk = dk.stride(0);
    ck_tile::index_t stride_dk = dk.stride(1);
    ck_tile::index_t nhead_stride_dk = dk.stride(2);

    // dv_expanded: (batch_size, seqlen_k, nheads, hdim)
    ck_tile::index_t batch_stride_dv = dv.stride(0);
    ck_tile::index_t stride_dv = dv.stride(1);
    ck_tile::index_t nhead_stride_dv = dv.stride(2);

    // dq_acc: (batch_size, nheads, split, seqlen_q, hdim)
    ck_tile::long_index_t batch_stride_dq_acc = dq_acc.stride(0);
    ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(1);
    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(2);
    ck_tile::index_t stride_dq_acc = dq_acc.stride(3);

    float p_undrop = 1.0 - p_dropout;

    void *alibi_slopes_ptr = nullptr;
    ck_tile::index_t stride_alibi_slopes = 0;

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
        alibi_slopes_ptr = alibi_slopes.data_ptr();
        // alibi_slopes:(batch_size, nheads) or (nhead)
        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
    }

    return fmha_bwd_args{q.data_ptr(),
                         k.data_ptr(),
                         v.data_ptr(),
                         alibi_slopes_ptr, // bias
                         out.data_ptr(),
                         softmax_lse.data_ptr(),
                         dout.data_ptr(),
                         d.data_ptr(),
                         nullptr, // rand_val
                         dq.data_ptr(),
                         dk.data_ptr(),
                         dv.data_ptr(),
                         nullptr,           // dbias
                         dq_acc.data_ptr(), // dq_acc
                         nullptr,           // seqstart_q_ptr
                         nullptr,           // seqstart_k_ptr
                         nullptr,           // seqlen_q_ptr
                         nullptr,           // seqlen_k_ptr
                         nullptr,           // cu_seqlen_q_ptr
                         nullptr,           // cu_seqlen_k_ptr
                         seqlen_q,
                         seqlen_k,
                         b,
                         seqlen_q, // max_seqlen_q
                         seqlen_k, // max_seqlen_k
                         hdim,     // hdim_q
                         hdim,     // hdim_v
                         h,        // nhead
                         h_k,      // nhead_k
                         softmax_scale,
                         stride_q,
                         stride_k,
                         stride_v,
                         stride_alibi_slopes,
                         stride_o,
                         0, // stride_randval
                         stride_do,
                         stride_dq_acc,
                         stride_dq,
                         stride_dk,
                         stride_dv,
                         0, // stride_dbias, FA without bias
                         nhead_stride_q,
                         nhead_stride_k,
                         nhead_stride_v,
                         0, // nhead_stride_bias, FA without bias
                         nhead_stride_o,
                         0, // nhead_stride_randval
                         nhead_stride_do,
                         nhead_stride_lse,
                         nhead_stride_dq_acc,
                         nhead_stride_dq,
                         nhead_stride_dk,
                         nhead_stride_dv,
                         0, // nhead_stride_dbias, FA without dbias
                         batch_stride_q,
                         batch_stride_k,
                         batch_stride_v,
                         0, // batch_stride_bias, FA without bias
                         batch_stride_o,
                         0, // batch_stride_randval
                         batch_stride_do,
                         batch_stride_lse,
                         batch_stride_dq_acc,
                         batch_stride_dq,
                         batch_stride_dk,
                         batch_stride_dv,
                         0, // batch_stride_dbias, FA without dbias
                         split_stride_dq_acc,
                         mask.left,
                         mask.right,
                         static_cast(mask.type),
                         p_dropout,
                         p_undrop,
                         drop_seed_offset};
}

std::vector mha_bwd(const at::Tensor &dout,        // batch_size x seqlen_q x num_heads x multiple_of(head_size, 8)
                    const at::Tensor &q,           // batch_size x seqlen_q x num_heads x head_size
                    const at::Tensor &k,           // batch_size x seqlen_k x num_heads_k x head_size
                    const at::Tensor &v,           // batch_size x seqlen_k x num_heads_k x head_size
                    const at::Tensor &out,         // batch_size x seqlen_q x num_heads x head_size
                    const at::Tensor &softmax_lse, // b x h x seqlen_q
                    std::optional &dq_,            // batch_size x seqlen_q x num_heads x head_size
                    std::optional &dk_,            // batch_size x seqlen_k x num_heads_k x head_size
                    std::optional &dv_,            // batch_size x seqlen_k x num_heads_k x head_size
                    std::optional &alibi_slopes_,  // num_heads or batch_size x num_heads
                    const float p_dropout,         // probability to drop
                    const float softmax_scale,
                    const bool is_causal,
                    int window_size_left,
                    int window_size_right,
                    const float /*softcap*/,
                    const bool deterministic,
                    std::optional gen_,
                    std::optional &rng_state_)
{
#ifdef FLASHATTENTION_DISABLE_BACKWARD
    TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
    if (is_causal) { window_size_right = 0; }

    const bool is_dropout = p_dropout > 0.0;
#ifdef HIPIFY_V2
    auto stream = at::cuda::getCurrentCUDAStream().stream();
#else
    auto stream = at::cuda::getCurrentHIPStream().stream();
#endif

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

    const std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 256, "CK FlashAttention backward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }

    mask_info mask;
    if (is_causal) {
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
    }
    else if (window_size_left == -1 && window_size_right == -1) {
        mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
    }
    else {
        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
    }

    // q, k, v and out have already been padded in mha_fwd;
    // dq_, dk_ and dv_ are padded tensors as well.
    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    const auto traits = get_ck_fmha_bwd_traits(mask, q_dtype_str, seqlen_q, seqlen_k, batch_size, head_size,
                                               num_heads, num_heads_k, is_dropout,
                                               alibi_slopes_.has_value(), deterministic);
    fmha_bwd_launcher launcher(traits);
    const ck_tile::index_t nsplits = launcher.dq_acc_splits;

    at::cuda::CUDAGuard device_guard{q.device()};

    auto opts = q.options();
    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size}, opts.dtype(at::kFloat));

    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) { // MQA / GQA
        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }

    auto gen = at::get_generator_or_default(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
    at::Tensor rng_state;

    if (rng_state_.has_value()) {
        rng_state = rng_state_.value();
    } else if(is_dropout) {
        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
        // See Note [Acquire lock when using random generators]
        std::lock_guard lock(gen->mutex_);
        auto philox_args = gen->philox_cuda_state(counter_offset);
        hipLaunchKernelGGL(
            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
            philox_args, reinterpret_cast(rng_state.data_ptr()));
    }

    if (seqlen_q > 0) {
        auto rng_state_ptr = reinterpret_cast(rng_state.data_ptr());
        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
        ck_tile::stream_config stream_config{stream};

        auto args = get_ck_fmha_bwd_args(mask,
                                         batch_size, seqlen_q, seqlen_k, num_heads, num_heads_k, head_size,
                                         q, k, v, alibi_slopes_, out, softmax_lse, dout,
                                         dq_accum, softmax_d, dq, dk_expanded, dv_expanded,
                                         softmax_scale, p_dropout, drop_seed_offset);

        float t = fmha_bwd(traits, args, stream_config);
        TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
    }
    return { dq, dk, dv, softmax_d };
}

================================================
FILE: csrc/flash_attn_ck/mha_fwd.cpp
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"
#include "fmha_fwd.hpp"
#include "mask.hpp"

fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
                                       std::string dtype,
                                       int head_size,
                                       bool has_dropout,
                                       bool has_lse,
                                       bool enable_alibi)
{
    return fmha_fwd_traits{head_size,
                           head_size,
                           dtype,
                           false, // is_group_mode
                           true,  // is_v_rowmajor
                           false, // has_logits_soft_cap
                           mask.type,
                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                           has_lse,
                           has_dropout,
                           quant_scale_enum::no_scale}; // qscale_type
}

fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                                   bool has_dropout_randval,
                                   const mask_info &mask,
                                   // sizes
                                   const int b,
                                   const int seqlen_q,
                                   const int seqlen_k,
                                   const int h,
                                   const int h_k,
                                   const int d,
                                   // device pointers
                                   const at::Tensor q,
                                   const at::Tensor k,
                                   const at::Tensor v,
                                   std::optional &alibi_slopes_,
                                   at::Tensor out,
                                   at::Tensor softmax_lse,
                                   at::Tensor dropout_randval,
                                   float softmax_scale,
                                   float p_dropout,
                                   std::pair drop_seed_offset)
{
    // q: (batch_size, seqlen_q, nheads, d)
    // k: (batch_size, seqlen_k, nheads_k, d)
    // v: (batch_size, seqlen_k, nheads_k, d)
    // o: (batch_size, seqlen_q, nheads, d)

    // alibi_slopes:(batch_size, nheads) or (nhead)
    // lse: (batch_size, nheads, seqlen_q)
    // randval: (batch_size, nheads, seqlen_q, seqlen_k)

    ck_tile::index_t stride_q = q.stride(1);
    ck_tile::index_t stride_k = k.stride(1);
    ck_tile::index_t stride_v = v.stride(1);
    ck_tile::index_t stride_o = out.stride(1);
    ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0;

    ck_tile::index_t nhead_stride_q = q.stride(2);
    ck_tile::index_t nhead_stride_k = k.stride(2);
    ck_tile::index_t nhead_stride_v = v.stride(2);
    ck_tile::index_t nhead_stride_o = out.stride(2);
    ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
    ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;

    ck_tile::index_t batch_stride_q = q.stride(0);
    ck_tile::index_t batch_stride_k = k.stride(0);
    ck_tile::index_t batch_stride_v = v.stride(0);
    ck_tile::index_t batch_stride_o = out.stride(0);
    ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
    ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;

    void *alibi_slopes_ptr = nullptr;
    ck_tile::index_t stride_alibi_slopes = 0;

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
        alibi_slopes_ptr = alibi_slopes.data_ptr();
        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
    }

    return fmha_fwd_args{q.data_ptr(),
                         k.data_ptr(),
                         v.data_ptr(),
                         alibi_slopes_ptr, // bias
                         nullptr,          // q_descale_ptr
                         nullptr,          // k_descale_ptr
                         nullptr,          // v_descale_ptr
                         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
                         has_lse ? softmax_lse.data_ptr() : nullptr,
                         out.data_ptr(),
                         nullptr, // seqstart_q_ptr
                         nullptr, // seqstart_k_ptr
                         nullptr, // seqlen_q_ptr
                         nullptr, // seqlen_k_ptr
                         nullptr, // cu_seqlen_q_ptr
                         nullptr, // cu_seqlen_k_ptr
                         nullptr, // block_scale_seqstart_q_ptr
                         nullptr, // block_scale_seqstart_k_ptr
                         nullptr, // seqstart_v_scale_ptr
                         nullptr, // sink_ptr
                         seqlen_q,
                         seqlen_k,
                         b,
                         seqlen_q,      // max_seqlen_q
                         d,             // hdim_q
                         d,             // hdim_v
                         h,             // nhead
                         h_k,           // nhead_k
                         softmax_scale, // scale_s
                         0.0f,          // logits_soft_cap
                         stride_q,
                         stride_k,
                         stride_v,
                         stride_alibi_slopes,
                         stride_randval,
                         stride_o,
                         0, // stride_q_descale
                         0, // stride_k_descale
                         0, // stride_v_descale
                         nhead_stride_q,
                         nhead_stride_k,
                         nhead_stride_v,
                         0, // nhead_stride_bias, FA without bias
                         nhead_stride_randval,
                         nhead_stride_lse,
                         nhead_stride_o,
                         0, // nhead_stride_q_descale
                         0, // nhead_stride_k_descale
                         0, // nhead_stride_v_descale
                         batch_stride_q,
                         batch_stride_k,
                         batch_stride_v,
                         0, // batch_stride_bias, FA without bias
                         batch_stride_randval,
                         batch_stride_lse,
                         batch_stride_o,
                         0, // batch_stride_q_descale
                         0, // batch_stride_k_descale
                         0, // batch_stride_v_descale
                         mask.left,
                         mask.right,
                         0, // sink_size
                         static_cast(mask.type),
                         0, // min_seqlen_q
                         p_dropout,
                         has_dropout_randval,
                         drop_seed_offset,
                         0,  // block_scale_size_q
                         0}; // block_scale_size_kv
}

std::vector mha_fwd(at::Tensor &q,                // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
                    const at::Tensor &k,          // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
                    const at::Tensor &v,          // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
                    std::optional &out_,          // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
                    std::optional &alibi_slopes_, // num_heads or batch_size x num_heads
                    const float p_dropout,
                    const float softmax_scale,
                    bool is_causal,
                    int window_size_left,
                    int window_size_right,
                    const float /*softcap*/,
                    const bool return_dropout_randval,
                    std::optional gen_)
{
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");

    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    int seqlen_q = sizes[1];
    int num_heads = sizes[2];
    const int head_size = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256");
    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }

    // causal=true is the same as causal=false when seqlen_q == 1
    if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }

    mask_info mask;
    if (is_causal) {
        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
        window_size_right = 0;
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
    }
    else if (window_size_left == -1 && window_size_right == -1) {
        mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
    }
    else {
        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
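        // Illustrative example of the mask string built below: window_size_left = 256
        // with window_size_right = 0 yields mask_identify == "b:256,0". A side that was
        // reset to -1 above (because it was >= seqlen_k) is encoded as "-1", e.g. "b:-1,0".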
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value(); const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); seqlen_q = ngroups; num_heads = num_heads_k; } CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); at::Tensor out; if (out_.has_value()) { out = out_.value(); TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size); if (seqlenq_ngroups_swapped) { out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); } } else { out = torch::empty_like(q); } // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; auto opts = q.options(); bool has_lse = true; bool has_dropout = p_dropout > 0.0f; at::Tensor softmax_lse; // TODO - check gradient, only training require lse softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(torch::kFloat32)); at::Tensor p; if (return_dropout_randval) { TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0"); p = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(torch::kUInt8)); } else { p = torch::empty({ 0 }, opts); } int64_t counter_offset = batch_size * 
num_heads * ck_tile::get_warp_size(); auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64)); auto rng_state_ptr = reinterpret_cast(rng_state.data_ptr()); if (p_dropout > 0.0) { auto gen = at::get_generator_or_default( gen_, at::cuda::detail::getDefaultCUDAGenerator()); // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); auto philox_args = gen->philox_cuda_state(counter_offset); hipLaunchKernelGGL( flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr); } if (seqlen_k > 0) { auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); #ifdef HIPIFY_V2 auto stream = at::cuda::getCurrentCUDAStream().stream(); #else auto stream = at::cuda::getCurrentHIPStream().stream(); #endif ck_tile::stream_config stream_config{stream}; auto traits = get_ck_fmha_fwd_traits( mask, q_dtype_str, head_size, has_dropout, has_lse, alibi_slopes_.has_value()); auto args = get_ck_fmha_fwd_args( has_lse, return_dropout_randval, mask, batch_size, seqlen_q, seqlen_k, num_heads, num_heads_k, head_size, q, k, v, alibi_slopes_, out, softmax_lse, p, softmax_scale, p_dropout, drop_seed_offset); float t = fmha_fwd(traits, args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd"); } else { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
out.zero_(); softmax_lse.fill_(std::numeric_limits::infinity()); } if (seqlenq_ngroups_swapped) { out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); } return {out, softmax_lse, p, rng_state}; } ================================================ FILE: csrc/flash_attn_ck/mha_fwd_kvcache.cpp ================================================ /****************************************************************************** * Copyright (c) 2024, Tri Dao. ******************************************************************************/ #include "flash_common.hpp" #include "fmha_fwd.hpp" #include "rotary.hpp" fmha_fwd_appendkv_traits get_ck_fmha_fwd_appendkv_traits(std::string dtype, int head_size, int rotary_dim, bool is_rotary_interleaved) { rope_enum rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved : rope_enum::half_rotated) : rope_enum::none); return fmha_fwd_appendkv_traits{head_size, head_size, dtype, true, // is_v_rowmajor rope_type}; } fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &mask, std::string dtype, int head_size, bool has_lse, bool enable_alibi) { return fmha_fwd_splitkv_traits{head_size, head_size, dtype, false, // is_group_mode true, // is_v_rowmajor false, // has_logits_soft_cap mask.type, enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, has_lse, false, // do_fp8_static_quant false}; // has_sink } fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b, const int seqlen_q, const int seqlen_knew, const int h, const int h_k, const int d, const int rotary_dim, const bool has_mask, const int page_block_size, // device pointers const at::Tensor q, const at::Tensor kcache, const at::Tensor vcache, const at::Tensor knew, const at::Tensor vnew, std::optional &seqlens_k_, std::optional &rotary_cos_, std::optional &rotary_sin_, std::optional &cache_batch_idx_, std::optional &block_table_) { // q: (batch_size, seqlen_q, nheads, d) // kcache: (batch_size_c, seqlen_k, nheads_k, d) or (num_blocks, page_block_size, nheads_k, d) // vcache: (batch_size_c, seqlen_k, nheads_k, d) or (num_blocks, page_block_size, nheads_k, d) // knew: (batch_size, seqlen_knew, nheads_k, d) // vnew: (batch_size, seqlen_knew, nheads_k, d) // seqlens_k: (batch_size) // rotary_cos: (seqlen_ro, rotary_dim / 2) // rotary_sin: (seqlen_ro, rotary_dim / 2) // block_table: (batch_size, max_num_blocks_per_seq) fmha_fwd_appendkv_args args; args.q_ptr = q.data_ptr(); args.k_ptr = kcache.data_ptr(); args.knew_ptr = knew.data_ptr(); args.v_ptr = vcache.data_ptr(); args.vnew_ptr = vnew.data_ptr(); args.seqlen_k_ptr = seqlens_k_.has_value() ? seqlens_k_.value().data_ptr() : nullptr; args.seqlen_q = seqlen_q; args.seqlen_knew = seqlen_knew; args.batch = b; args.hdim_q = d; args.hdim_v = d; args.nhead_q = h; args.nhead_k = h_k; args.rotary_cos_ptr = rotary_cos_.has_value() ? rotary_cos_.value().data_ptr() : nullptr; args.rotary_sin_ptr = rotary_sin_.has_value() ? 
rotary_sin_.value().data_ptr() : nullptr; args.rotary_dim = rotary_dim; args.has_mask = has_mask; if (block_table_.has_value()) { auto block_table = block_table_.value(); args.block_table_ptr = block_table.data_ptr(); args.batch_stride_block_table = block_table.stride(0); args.page_block_size = page_block_size; } else { args.block_table_ptr = nullptr; args.batch_stride_block_table = 0; args.page_block_size = 0; } args.cache_batch_idx = cache_batch_idx_.has_value() ? reinterpret_cast<int *>(cache_batch_idx_.value().data_ptr()) : nullptr; args.batch_stride_q = q.stride(0); args.stride_q = q.stride(1); args.nhead_stride_q = q.stride(2); args.batch_stride_k = kcache.stride(0); args.stride_k = kcache.stride(1); args.nhead_stride_k = kcache.stride(2); args.batch_stride_knew = knew.stride(0); args.stride_knew = knew.stride(1); args.nhead_stride_knew = knew.stride(2); args.batch_stride_v = vcache.stride(0); args.stride_v = vcache.stride(1); args.nhead_stride_v = vcache.stride(2); args.batch_stride_vnew = vnew.stride(0); args.stride_vnew = vnew.stride(1); args.nhead_stride_vnew = vnew.stride(2); return args; } fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse, const mask_info &mask, const int b, const int seqlen_q, const int seqlen_k, const int h, const int h_k, const int d, const int page_block_size, const int num_splits, float softmax_scale, // device pointers const at::Tensor q, const at::Tensor k, const at::Tensor v, const at::Tensor seqlens_k, std::optional<at::Tensor> &cache_batch_idx_, std::optional<at::Tensor> &block_table_, std::optional<at::Tensor> &alibi_slopes_, at::Tensor out, at::Tensor lse, at::Tensor lse_acc, at::Tensor out_acc) { // q: (batch_size, seqlen_q, nheads, d) // k: (batch_size, seqlen_k, nheads_k, d) // v: (batch_size, seqlen_k, nheads_k, d) // o: (batch_size, seqlen_q, nheads, d) // alibi_slopes:(batch_size, nheads) or (nhead) // lse: (batch_size, nheads, seqlen_q) // lse_acc: (split, batch_size, nheads, seqlen_q) // o_acc: (split, batch_size, nheads, seqlen_q, d)
fmha_fwd_splitkv_args args; args.q_ptr = q.data_ptr(); args.k_ptr = k.data_ptr(); args.v_ptr = v.data_ptr(); args.bias_ptr = nullptr; args.lse_acc_ptr = lse_acc.data_ptr(); args.o_acc_ptr = out_acc.data_ptr(); args.lse_ptr = nullptr; args.o_ptr = out.data_ptr(); args.sink_ptr = nullptr; if (block_table_.has_value()) { auto block_table = block_table_.value(); args.block_table_ptr = block_table.data_ptr(); args.batch_stride_block_table = block_table.stride(0); args.page_block_size = page_block_size; } else { args.block_table_ptr = nullptr; args.batch_stride_block_table = 0; args.page_block_size = 0; } args.cache_batch_idx = cache_batch_idx_.has_value() ? cache_batch_idx_.value().data_ptr() : nullptr; args.seqstart_q_ptr = nullptr; args.seqstart_k_ptr = nullptr; args.seqlen_k_ptr = seqlens_k.data_ptr(); args.seqlen_q = seqlen_q; args.seqlen_k = seqlen_k; args.batch = b; args.max_seqlen_q = seqlen_q; args.hdim_q = d; args.hdim_v = d; args.nhead_q = h; args.nhead_k = h_k; args.num_splits = num_splits; args.scale_s = softmax_scale; args.scale_p = 1; args.scale_o = 1; args.batch_stride_q = q.stride(0); args.stride_q = q.stride(1); args.nhead_stride_q = q.stride(2); args.batch_stride_k = k.stride(0); args.stride_k = k.stride(1); args.nhead_stride_k = k.stride(2); args.batch_stride_v = v.stride(0); args.stride_v = v.stride(1); args.nhead_stride_v = v.stride(2); args.batch_stride_o = out.stride(0); args.stride_o = out.stride(1); args.nhead_stride_o = out.stride(2); args.batch_stride_bias = 0; args.stride_bias = 0; args.nhead_stride_bias = 0; args.batch_stride_lse = 0; args.nhead_stride_lse = 0; args.split_stride_lse_acc = lse_acc.stride(0); args.batch_stride_lse_acc = lse_acc.stride(1); args.nhead_stride_lse_acc = lse_acc.stride(2); args.split_stride_o_acc = out_acc.stride(0); args.batch_stride_o_acc = out_acc.stride(1); args.nhead_stride_o_acc = out_acc.stride(2); args.stride_o_acc = out_acc.stride(3); if (has_lse) { args.lse_ptr = lse.data_ptr(); args.batch_stride_lse = 
lse.stride(0); args.nhead_stride_lse = lse.stride(1); } if (alibi_slopes_.has_value()) { auto alibi_slopes = alibi_slopes_.value(); CHECK_DEVICE(alibi_slopes); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h})); args.bias_ptr = alibi_slopes.data_ptr(); args.stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; } args.window_size_left = mask.left; args.window_size_right = mask.right; args.sink_size = 0; args.mask_type = static_cast<ck_tile::index_t>(mask.type); return args; } std::vector<at::Tensor> mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
std::optional<at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size std::optional<at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size std::optional<at::Tensor> &seqlens_k_, // batch_size std::optional<at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2) std::optional<at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2) std::optional<at::Tensor> &cache_batch_idx_, // indices to index into the KV cache std::optional<at::Tensor> & /*leftpad_k_*/, // batch_size std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, int window_size_left, int window_size_right, const float /*softcap*/, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits) { auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashAttention only supports fp16 and bf16 data types"); TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype"); std::string q_dtype_str = q_dtype == torch::kFloat16 ?
"fp16" : "bf16"; CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); at::Tensor block_table; const bool paged_KV = block_table_.has_value(); if (paged_KV) { TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx"); block_table = block_table_.value(); CHECK_DEVICE(block_table); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); } const auto sizes = q.sizes(); const int batch_size = sizes[0]; int seqlen_q = sizes[1]; int num_heads = sizes[2]; const int head_size_og = sizes[3]; const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : kcache.size(0); const int page_block_size = !paged_KV ? 1 : kcache.size(1); TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128"); const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; const int num_heads_k = kcache.size(2); const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); // causal=true is the same as causal=false in this case if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } if (is_causal) { window_size_right = 0; } mask_info mask; if (is_causal) { // Causal is the special case where window_size_right == 0 and window_size_left < 0. 
window_size_right = 0; std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal } else if (window_size_left == -1 && window_size_right == -1) { mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask } else { // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); if (seqlenq_ngroups_swapped) { const int ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); seqlen_q = ngroups; num_heads = num_heads_k; } if (window_size_left >= seqlen_k) { window_size_left = -1; } if (window_size_right >= seqlen_k) { window_size_right = -1; } CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); if (!paged_KV) { CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); } else { CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og); CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og); CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); } at::Tensor q_padded, kcache_padded, vcache_padded; if (head_size_og % 8 != 0) { q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); vcache_padded =
torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); } else { q_padded = q; kcache_padded = kcache; vcache_padded = vcache; } at::Tensor out; if (out_.has_value()) { out = out_.value(); TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { out = torch::empty_like(q_padded); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_8x = round_multiple(head_size_og, 8); // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; auto opts = q.options(); // TODO: check gradient; only training requires lse bool has_lse = true; auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); int seqlen_knew = 0; at::Tensor k, v, k_padded, v_padded; if (k_.has_value()) { TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in"); TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in"); TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache"); k = k_.value(); v = v_.value(); TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query"); TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query"); CHECK_DEVICE(k); CHECK_DEVICE(v); TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension"); seqlen_knew = k.size(1); CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og); CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og); if (head_size_og % 8 != 0) { k_padded =
torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); } else { k_padded = k; v_padded = v; } } if (seqlens_k_.has_value()) { auto seqlens_k = seqlens_k_.value(); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); CHECK_DEVICE(seqlens_k); CHECK_CONTIGUOUS(seqlens_k); CHECK_SHAPE(seqlens_k, batch_size); } int rotary_dim = 0; if (rotary_cos_.has_value()) { TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); auto rotary_cos = rotary_cos_.value(); CHECK_DEVICE(rotary_cos); rotary_dim = rotary_cos.size(1) * 2; TORCH_CHECK(rotary_dim <= head_size_og, "rotary_dim must be <= headdim"); TORCH_CHECK(rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); const int seqlen_ro = rotary_cos.size(0); TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); CHECK_SHAPE(rotary_cos, seqlen_ro, rotary_dim / 2); CHECK_CONTIGUOUS(rotary_cos); TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query"); TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); auto rotary_sin = rotary_sin_.value(); CHECK_DEVICE(rotary_sin); CHECK_SHAPE(rotary_sin, seqlen_ro, rotary_dim / 2); CHECK_CONTIGUOUS(rotary_sin); TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query"); } if (cache_batch_idx_.has_value()) { auto cache_batch_idx = cache_batch_idx_.value(); CHECK_DEVICE(cache_batch_idx); CHECK_CONTIGUOUS(cache_batch_idx); TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); } num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, 0, num_splits);
TORCH_CHECK(num_splits > 0, "num_splits should be greater than 0"); TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported"); // Keep references to these tensors to extend their lifetime auto softmax_lse_accum = torch::empty({num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); auto out_accum = torch::empty({num_splits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); auto stream = at::cuda::getCurrentCUDAStream().stream(); ck_tile::stream_config stream_config{stream}; if (seqlen_knew > 0 || rotary_dim > 0) { auto appendkv_traits = get_ck_fmha_fwd_appendkv_traits(q_dtype_str, head_size_8x, rotary_dim, is_rotary_interleaved); auto appendkv_args = get_ck_fmha_fwd_appendkv_args( batch_size, seqlen_q, seqlen_knew, num_heads, num_heads_k, head_size_8x, rotary_dim, mask.type != mask_enum::no_mask, page_block_size, q_padded, kcache_padded, vcache_padded, k_padded, v_padded, seqlens_k_, rotary_cos_, rotary_sin_, cache_batch_idx_, block_table_); fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config); } // seqlens_k_ is the seqlen of kvcache.
We need to add seqlen_knew before running attention auto append_seqlens_k = torch::empty({batch_size}, opts.dtype(torch::kInt32)); if (seqlens_k_.has_value()) append_seqlens_k = seqlens_k_.value() + seqlen_knew; else append_seqlens_k.fill_(seqlen_knew); // we use splitkv even if num_splits == 1, because fmha_fwd() does not support seqlen_k_ in batch mode auto splitkv_traits = get_ck_fmha_fwd_splitkv_traits(mask, q_dtype_str, head_size_8x, has_lse, alibi_slopes_.has_value()); auto splitkv_args = get_ck_fmha_fwd_splitkv_args( has_lse, mask, batch_size, seqlen_q, seqlen_k, num_heads, num_heads_k, head_size_8x, page_block_size, num_splits, softmax_scale, q_padded, kcache_padded, vcache_padded, append_seqlens_k, cache_batch_idx_, block_table_, alibi_slopes_, out, softmax_lse, softmax_lse_accum, out_accum); fmha_fwd_splitkv(splitkv_traits, splitkv_args, stream_config); if (head_size_og % 8 != 0) { out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); if (out_.has_value()) { out_.value().copy_(out); } if (k_.has_value()) { // It's expensive to copy the KV cache here for the case where head size is not divisible by 8, // but we don't expect to get this case in practice. This is just so that the code works for that case. kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)})); } } if (seqlenq_ngroups_swapped) { out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); } return {out, softmax_lse}; } ================================================ FILE: csrc/flash_attn_ck/mha_varlen_bwd.cpp ================================================ /****************************************************************************** * Copyright (c) 2024, Tri Dao.
******************************************************************************/ #include "flash_common.hpp" #include "fmha_bwd.hpp" #include "mask.hpp" fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask, std::string dtype, int seqlen_q, int seqlen_k, int batch, int max_seqlen_q, int max_seqlen_k, int head_size, int nhead_q, int nhead_k, bool has_dropout, bool enable_alibi, bool deterministic) { return fmha_bwd_traits{seqlen_q, seqlen_k, batch, max_seqlen_q, max_seqlen_k, head_size, // hdim_q head_size, // hdim_k nhead_q, nhead_k, dtype, true, // is_group_mode mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, false, // has_dbias has_dropout, false, // s_randval deterministic}; } fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, // sizes const int b, const int max_seqlen_q, const int max_seqlen_k, const int h, const int h_k, const int hdim, // device pointers const at::Tensor q, const at::Tensor k, const at::Tensor v, const at::Tensor seqlens_q, const at::Tensor seqlens_k, std::optional<at::Tensor> &alibi_slopes_, const at::Tensor out, const at::Tensor softmax_lse, const at::Tensor dout, at::Tensor dq_acc, at::Tensor d, at::Tensor dq, at::Tensor dk, at::Tensor dv, float softmax_scale, float p_dropout, std::pair<uint64_t*, uint64_t*> drop_seed_offset) { ck_tile::index_t total_q = q.size(0); ck_tile::index_t total_k = k.size(0); // q: (total_q, nheads, hdim) ck_tile::index_t batch_stride_q = 0; ck_tile::index_t stride_q = q.stride(0); ck_tile::index_t nhead_stride_q = q.stride(1); // k: (total_k, nheads_k, hdim) ck_tile::index_t batch_stride_k = 0; ck_tile::index_t stride_k = k.stride(0); ck_tile::index_t nhead_stride_k = k.stride(1); // v: (total_k, nheads_k, hdim) ck_tile::index_t batch_stride_v = 0; ck_tile::index_t stride_v = v.stride(0); ck_tile::index_t nhead_stride_v = v.stride(1); // o: (total_q, nheads, hdim) ck_tile::index_t batch_stride_o = 0; ck_tile::index_t stride_o = out.stride(0); ck_tile::index_t nhead_stride_o = out.stride(1); // lse:
(nheads, total_q) ck_tile::index_t batch_stride_lse = 0; ck_tile::index_t nhead_stride_lse = softmax_lse.stride(0); // do: (total_q, nheads, hdim) ck_tile::index_t batch_stride_do = 0; ck_tile::index_t stride_do = dout.stride(0); ck_tile::index_t nhead_stride_do = dout.stride(1); // d: (batch_size, nheads, max_seqlen_q) // CK assumes d shares the same stride as lse // dq: (total_q, nheads, hdim) ck_tile::index_t batch_stride_dq = 0; ck_tile::index_t stride_dq = dq.stride(0); ck_tile::index_t nhead_stride_dq = dq.stride(1); // dk_expanded: (total_k, nheads, hdim) ck_tile::index_t batch_stride_dk = 0; ck_tile::index_t stride_dk = dk.stride(0); ck_tile::index_t nhead_stride_dk = dk.stride(1); // dv_expanded: (total_k, nheads, hdim) ck_tile::index_t batch_stride_dv = 0; ck_tile::index_t stride_dv = dv.stride(0); ck_tile::index_t nhead_stride_dv = dv.stride(1); // dq_acc: (nheads, split, total_q, hdim) ck_tile::long_index_t batch_stride_dq_acc = 0; ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(0); ck_tile::index_t split_stride_dq_acc = dq_acc.stride(1); ck_tile::index_t stride_dq_acc = dq_acc.stride(2); float p_undrop = 1.0 - p_dropout; void *alibi_slopes_ptr = nullptr; ck_tile::index_t stride_alibi_slopes = 0; if (alibi_slopes_.has_value()) { auto alibi_slopes = alibi_slopes_.value(); CHECK_DEVICE(alibi_slopes); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h})); alibi_slopes_ptr = alibi_slopes.data_ptr(); // alibi_slopes:(batch_size, nheads) or (nhead) stride_alibi_slopes = alibi_slopes.dim() == 2 ?
alibi_slopes.stride(0) : 0; } return fmha_bwd_args{q.data_ptr(), k.data_ptr(), v.data_ptr(), alibi_slopes_ptr, // bias out.data_ptr(), softmax_lse.data_ptr(), dout.data_ptr(), d.data_ptr(), nullptr, // rand_val dq.data_ptr(), dk.data_ptr(), dv.data_ptr(), nullptr, // dbias dq_acc.data_ptr(), // dq_acc seqlens_q.data_ptr(), // seqstart_q_ptr seqlens_k.data_ptr(), // seqstart_k_ptr nullptr, // seqlen_q_ptr nullptr, // seqlen_k_ptr nullptr, // cu_seqlen_q_ptr nullptr, // cu_seqlen_k_ptr total_q, total_k, b, max_seqlen_q, // max_seqlen_q max_seqlen_k, // max_seqlen_k hdim, // hdim_q hdim, // hdim_v h, // nhead h_k, // nhead_k softmax_scale, stride_q, stride_k, stride_v, stride_alibi_slopes, stride_o, 0, // stride_randval stride_do, stride_dq_acc, stride_dq, stride_dk, stride_dv, 0, // stride_dbias, FA without bias nhead_stride_q, nhead_stride_k, nhead_stride_v, 0, // nhead_stride_bias, FA without bias nhead_stride_o, 0, // nhead_stride_randval nhead_stride_do, nhead_stride_lse, nhead_stride_dq_acc, nhead_stride_dq, nhead_stride_dk, nhead_stride_dv, 0, // nhead_stride_dbias, FA without dbias batch_stride_q, batch_stride_k, batch_stride_v, 0, // batch_stride_bias, FA without bias batch_stride_o, 0, // batch_stride_randval batch_stride_do, batch_stride_lse, batch_stride_dq_acc, batch_stride_dq, batch_stride_dk, batch_stride_dv, 0, // batch_stride_dbias, FA without dbias split_stride_dq_acc, mask.left, mask.right, static_cast<ck_tile::index_t>(mask.type), p_dropout, p_undrop, drop_seed_offset}; } std::vector<at::Tensor> mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &out, // total_q x num_heads x head_size const at::Tensor &softmax_lse, // b x h x s softmax logsumexp std::optional<at::Tensor> &dq_, // total_q
x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float p_dropout, // probability to drop const float softmax_scale, const bool zero_tensors, const bool is_causal, int window_size_left, int window_size_right, const float /*softcap*/, const bool deterministic, std::optional<at::Generator> gen_, std::optional<at::Tensor> &rng_state_) { #ifdef FLASHATTENTION_DISABLE_BACKWARD TORCH_CHECK(false, "This flash attention build does not support backward."); #endif if (is_causal) { window_size_right = 0; } const bool is_dropout = p_dropout > 0.0; auto stream = at::cuda::getCurrentCUDAStream().stream(); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashAttention only supports fp16 and bf16 data types"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); const std::string q_dtype_str = q_dtype == torch::kFloat16 ?
"fp16" : "bf16"; CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); CHECK_CONTIGUOUS(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_k); const auto sizes = q.sizes(); const int total_q = sizes[0]; const int batch_size = cu_seqlens_q.numel() - 1; const int num_heads = sizes[1]; const int head_size = sizes[2]; const int total_k = k.size(0); const int num_heads_k = k.size(1); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size <= 256, "CK FlashAttention backward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (window_size_left >= max_seqlen_k) { window_size_left = -1; } if (window_size_right >= max_seqlen_k) { window_size_right = -1; } mask_info mask; if (is_causal) { std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal } else if (window_size_left == -1 && window_size_right == -1) { mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask } else { // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local } // q, k, v, out have been padded in mha_fwd // dq_, dk_, dv_ are also padded tensors CHECK_SHAPE(q, total_q, num_heads, head_size); CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size); CHECK_SHAPE(out, total_q, num_heads, head_size); CHECK_SHAPE(dout, total_q, num_heads, head_size); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); at::Tensor dq, dk, dv; if (dq_.has_value()) { dq = dq_.value(); TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); CHECK_DEVICE(dq); TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); CHECK_SHAPE(dq, total_q, num_heads, head_size); } else { dq = torch::empty_like(q); } if (dk_.has_value()) { dk = dk_.value(); TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); CHECK_DEVICE(dk); TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); CHECK_SHAPE(dk, total_k, num_heads_k, head_size); } else { dk = torch::empty_like(k); } if (dv_.has_value()) { dv = dv_.value(); TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); CHECK_DEVICE(dv); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, total_k, num_heads_k, head_size); } else { dv = torch::empty_like(v); } const auto traits = get_ck_fmha_varlen_bwd_traits( mask, q_dtype_str, total_q, total_k, batch_size, max_seqlen_q, max_seqlen_k, head_size, num_heads, num_heads_k, is_dropout, alibi_slopes_.has_value(), deterministic); fmha_bwd_launcher launcher(traits); const ck_tile::index_t nsplits = launcher.dq_acc_splits; at::cuda::CUDAGuard device_guard{q.device()}; auto opts = q.options(); auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); at::Tensor
dq_accum = torch::zeros({num_heads, nsplits, total_q, head_size}, opts.dtype(at::kFloat)); at::Tensor dk_expanded, dv_expanded; if (num_heads_k != num_heads) { // MQA / GQA dk_expanded = torch::empty({total_k, num_heads, head_size}, opts); dv_expanded = torch::empty({total_k, num_heads, head_size}, opts); } else { dk_expanded = dk; dv_expanded = dv; } if(zero_tensors) { dq.zero_(); dk_expanded.zero_(); dv_expanded.zero_(); softmax_d.zero_(); } auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( gen_, at::cuda::detail::getDefaultCUDAGenerator()); int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); at::Tensor rng_state; if (rng_state_.has_value()) { rng_state = rng_state_.value(); } else if(is_dropout) { rng_state = torch::empty({2}, opts.dtype(torch::kInt64)); // See Note [Acquire lock when using random generators] std::lock_guard<std::mutex> lock(gen->mutex_); auto philox_args = gen->philox_cuda_state(counter_offset); hipLaunchKernelGGL( flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr())); } else { rng_state = torch::empty({2}, opts.dtype(torch::kInt64)); } if (max_seqlen_q > 0) { auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr()); auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); ck_tile::stream_config stream_config{stream}; auto args = get_ck_fmha_varlen_bwd_args( mask, batch_size, max_seqlen_q, max_seqlen_k, num_heads, num_heads_k, head_size, q, k, v, cu_seqlens_q, cu_seqlens_k, alibi_slopes_, out, softmax_lse, dout, dq_accum, softmax_d, dq, dk_expanded, dv_expanded, softmax_scale, p_dropout, drop_seed_offset); float t = fmha_bwd(traits, args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_(); dv_expanded.zero_(); softmax_d.zero_(); } // For MQA/GQA we need to sum dK and dV across the groups if (num_heads_k != num_heads) { at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2}); } return { dq, dk, dv, softmax_d }; } ================================================ FILE: csrc/flash_attn_ck/mha_varlen_fwd.cpp ================================================ /****************************************************************************** * Copyright (c) 2024, Tri Dao. ******************************************************************************/ #include "flash_common.hpp" #include "fmha_fwd.hpp" #include "mask.hpp" fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, std::string dtype, int head_size, bool has_dropout, bool has_lse, bool enable_alibi) { return fmha_fwd_traits{head_size, head_size, dtype, true, // is_group_mode true, // is_v_rowmajor false, // has_logits_soft_cap mask.type, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, has_dropout, quant_scale_enum::no_scale}; // qscale_type } fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &mask, std::string dtype, int head_size, bool has_lse, bool enable_alibi) { return fmha_fwd_splitkv_traits{head_size, head_size, dtype, true, // is_group_mode true, // is_v_rowmajor false, // has_logits_soft_cap mask.type, enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, has_lse, false, // do_fp8_static_quant false}; // has_sink } fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, bool has_dropout_randval, const mask_info &mask, // sizes const int b, const int max_seqlen_q, const int h, const int h_k, const int d, // device pointers const at::Tensor q, const at::Tensor k, const at::Tensor v, const at::Tensor seqlens_q, const at::Tensor seqlens_k, std::optional<at::Tensor> &alibi_slopes_, at::Tensor out, at::Tensor softmax_lse, at::Tensor dropout_randval, float softmax_scale, float p_dropout, std::pair<uint64_t*, uint64_t*> drop_seed_offset) { // q: (total_q, nheads, d) // k: (total_k, nheads_k, d) // v: (total_k, nheads_k, d) // o: (total_q, nheads, d) // alibi_slopes:(batch, nheads) or (nhead) // lse: (nheads, total_q) // randval: (nheads, total_q, max_seqlen_k) ck_tile::index_t total_q = q.size(0); ck_tile::index_t total_k = k.size(0); ck_tile::index_t stride_q = q.stride(0); ck_tile::index_t stride_k = k.stride(0); ck_tile::index_t stride_v = v.stride(0); ck_tile::index_t stride_o = out.stride(0); ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0; ck_tile::index_t nhead_stride_q = q.stride(1); ck_tile::index_t nhead_stride_k = k.stride(1); ck_tile::index_t nhead_stride_v = v.stride(1); ck_tile::index_t nhead_stride_o = out.stride(1); ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(0) : 0; ck_tile::index_t nhead_stride_randval = has_dropout_randval ?
dropout_randval.stride(0) : 0; ck_tile::index_t batch_stride_q = 0; ck_tile::index_t batch_stride_k = 0; ck_tile::index_t batch_stride_v = 0; ck_tile::index_t batch_stride_o = 0; ck_tile::index_t batch_stride_lse = 0; ck_tile::index_t batch_stride_randval = 0; void *alibi_slopes_ptr = nullptr; ck_tile::index_t stride_alibi_slopes = 0; if (alibi_slopes_.has_value()) { auto alibi_slopes = alibi_slopes_.value(); CHECK_DEVICE(alibi_slopes); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h})); alibi_slopes_ptr = alibi_slopes.data_ptr(); stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; } return fmha_fwd_args{q.data_ptr(), k.data_ptr(), v.data_ptr(), alibi_slopes_ptr, // bias nullptr, // q_descale_ptr nullptr, // k_descale_ptr nullptr, // v_descale_ptr has_dropout_randval ? dropout_randval.data_ptr() : nullptr, has_lse ? 
softmax_lse.data_ptr() : nullptr, out.data_ptr(), seqlens_q.data_ptr(), // seqstart_q_ptr seqlens_k.data_ptr(), // seqstart_k_ptr nullptr, // seqlen_q_ptr nullptr, // seqlen_k_ptr nullptr, // cu_seqlen_q_ptr nullptr, // cu_seqlen_kv_ptr nullptr, // block_scale_seqstart_q_ptr nullptr, // block_scale_seqstart_k_ptr nullptr, // seqstart_v_scale_ptr nullptr, // sink_ptr total_q, total_k, b, max_seqlen_q, d, // hdim_q d, // hdim_v h, // nhead h_k, // nhead_k softmax_scale, // scale_s 0.0f, // logits_soft_cap stride_q, stride_k, stride_v, stride_alibi_slopes, stride_randval, stride_o, 0, // stride_q_descale 0, // stride_k_descale 0, // stride_v_descale nhead_stride_q, nhead_stride_k, nhead_stride_v, 0, // nhead_stride_bias, FA without bias nhead_stride_randval, nhead_stride_lse, nhead_stride_o, 0, // nhead_stride_q_descale 0, // nhead_stride_k_descale 0, // nhead_stride_v_descale batch_stride_q, batch_stride_k, batch_stride_v, 0, // batch_stride_bias, FA without bias batch_stride_randval, batch_stride_lse, batch_stride_o, 0, // batch_stride_q_descale 0, // batch_stride_k_descale 0, // batch_stride_v_descale mask.left, mask.right, 0, // sink_size static_cast(mask.type), 0, // min_seqlen_q p_dropout, has_dropout_randval, drop_seed_offset, 0, // block_scale_size_q 0}; // block_scale_size_kv } fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, const mask_info &mask, const int b, const int max_seqlen_q, const int h, const int h_k, const int d, const int page_block_size, const int num_splits, float softmax_scale, // device pointers const at::Tensor q, const at::Tensor k, const at::Tensor v, const at::Tensor seqlens_q, const at::Tensor seqlens_k, std::optional &block_table_, std::optional &alibi_slopes_, at::Tensor out, at::Tensor lse, at::Tensor lse_acc, at::Tensor out_acc) { // q: (total_q, nheads, d) // k: (num_blocks, page_block_size, num_heads_k, d) // v: (num_blocks, page_block_size, num_heads_k, d) // o: (total_q, nheads, d) // 
alibi_slopes:(batch_size, nheads) or (nhead) // lse: (nheads, total_q) // lse_acc: (nheads, split, total_q) // o_acc: (nheads, split, total_q, d) // block_table: (batch_size, max_num_blocks_per_seq) fmha_fwd_splitkv_args args; args.q_ptr = q.data_ptr(); args.k_ptr = k.data_ptr(); args.v_ptr = v.data_ptr(); args.bias_ptr = nullptr; args.lse_acc_ptr = lse_acc.data_ptr(); args.o_acc_ptr = out_acc.data_ptr(); args.lse_ptr = nullptr; args.o_ptr = out.data_ptr(); args.sink_ptr = nullptr; if (block_table_.has_value()) { auto block_table = block_table_.value(); args.block_table_ptr = block_table.data_ptr(); args.batch_stride_block_table = block_table.stride(0); args.page_block_size = page_block_size; } else { args.block_table_ptr = nullptr; args.batch_stride_block_table = 0; args.page_block_size = 0; } args.is_gappy = false; args.cache_batch_idx = nullptr; args.seqstart_q_ptr = seqlens_q.data_ptr(); args.seqstart_k_ptr = seqlens_k.data_ptr(); args.seqlen_k_ptr = nullptr; args.batch = b; args.max_seqlen_q = max_seqlen_q; args.hdim_q = d; args.hdim_v = d; args.nhead_q = h; args.nhead_k = h_k; args.num_splits = num_splits; args.scale_s = softmax_scale; args.scale_p = 1; args.scale_o = 1; args.batch_stride_q = 0; args.stride_q = q.stride(0); args.nhead_stride_q = q.stride(1); args.batch_stride_k = k.stride(0); args.stride_k = k.stride(1); args.nhead_stride_k = k.stride(2); args.batch_stride_v = v.stride(0); args.stride_v = v.stride(1); args.nhead_stride_v = v.stride(2); args.batch_stride_o = 0; args.stride_o = out.stride(0); args.nhead_stride_o = out.stride(1); args.batch_stride_bias = 0; args.stride_bias = 0; args.nhead_stride_bias = 0; args.batch_stride_lse = 0; args.nhead_stride_lse = 0; args.batch_stride_lse_acc = 0; args.nhead_stride_lse_acc = lse_acc.stride(0); args.split_stride_lse_acc = lse_acc.stride(1); args.batch_stride_o_acc = 0; args.nhead_stride_o_acc = out_acc.stride(0); args.split_stride_o_acc = out_acc.stride(1); args.stride_o_acc = out_acc.stride(2); if 
(has_lse) { args.lse_ptr = lse.data_ptr(); args.batch_stride_lse = 0; args.nhead_stride_lse = lse.stride(0); } if (alibi_slopes_.has_value()) { auto alibi_slopes = alibi_slopes_.value(); CHECK_DEVICE(alibi_slopes); TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h})); args.bias_ptr = alibi_slopes.data_ptr(); args.stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; } args.window_size_left = mask.left; args.window_size_right = mask.right; args.sink_size = 0; args.mask_type = static_cast(mask.type); return args; } std::vector mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
std::optional &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 std::optional & /*seqused_k*/, std::optional &/*leftpad_k_*/, // batch_size std::optional &block_table_, // batch_size x max_num_blocks_per_seq std::optional &alibi_slopes_, // num_heads or b x num_heads int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, bool is_causal, int window_size_left, int window_size_right, const float /*softcap*/, const bool return_dropout_randval, std::optional gen_) { auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashAttention only supports fp16 and bf16 data types"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); std::string q_dtype_str = q_dtype == torch::kFloat16 ?
"fp16" : "bf16"; CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k); at::Tensor block_table; const bool paged_KV = block_table_.has_value(); if (paged_KV) { block_table = block_table_.value(); CHECK_DEVICE(block_table); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); } TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); CHECK_CONTIGUOUS(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_k); const auto sizes = q.sizes(); const int batch_size = cu_seqlens_q.numel() - 1; int num_heads = sizes[1]; const int head_size = sizes[2]; const int num_heads_k = paged_KV ? k.size(2) : k.size(1); const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); const int page_block_size = !paged_KV ? 
1 : k.size(1); TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128"); if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case // TODO // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int total_q = q.size(0); TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256"); TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (window_size_left >= max_seqlen_k) { window_size_left = -1; } if (window_size_right >= max_seqlen_k) { window_size_right = -1; } mask_info mask; if (is_causal) { // Causal is the special case where window_size_right == 0 and window_size_left < 0. window_size_right = 0; std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0"; mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal } else if (window_size_left == -1 && window_size_right == -1) { mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask } else { // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right); mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local } CHECK_SHAPE(q, total_q, num_heads, head_size); if (!paged_KV) { const int total_k = k.size(0); CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size); } else { CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size); CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size); CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); } CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); at::Tensor out; if (out_.has_value()) { out = out_.value(); TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, total_q, num_heads, head_size); } else { out = torch::empty_like(q); } // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; auto opts = q.options(); bool has_lse = true; bool has_dropout = p_dropout > 0.0f; if (has_dropout) TORCH_CHECK(!paged_KV, "Paged KV does not support dropout"); at::Tensor softmax_lse; // TODO - check gradient, only training requires lse softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(torch::kFloat32)); at::Tensor p; if (return_dropout_randval) { TORCH_CHECK(has_dropout, "return_dropout_randval requires p_dropout > 0"); p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8)); } else { p = torch::empty({ 0 }, opts); } if (zero_tensors) { out.zero_(); softmax_lse.fill_(-std::numeric_limits::infinity()); if (return_dropout_randval) {p.zero_();} } int num_splits = 0; num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits); TORCH_CHECK(num_splits > 0, "num_splits should be
greater than 0"); TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported"); auto softmax_lse_accum = torch::empty({num_heads, num_splits, total_q}, opts.dtype(at::kFloat)); auto out_accum = torch::empty({num_heads, num_splits, total_q, head_size}, opts.dtype(at::kFloat)); int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64)); auto rng_state_ptr = reinterpret_cast(rng_state.data_ptr()); if (p_dropout > 0.0) { auto gen = at::get_generator_or_default( gen_, at::cuda::detail::getDefaultCUDAGenerator()); // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); auto philox_args = gen->philox_cuda_state(counter_offset); hipLaunchKernelGGL( flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr); } if (max_seqlen_k > 0) { #ifdef HIPIFY_V2 auto stream = at::cuda::getCurrentCUDAStream().stream(); #else auto stream = at::cuda::getCurrentHIPStream().stream(); #endif ck_tile::stream_config stream_config{stream}; if (paged_KV) { auto traits = get_ck_fmha_varlen_fwd_splitkv_traits( mask, q_dtype_str, head_size, has_lse, alibi_slopes_.has_value()); auto args = get_ck_fmha_varlen_fwd_splitkv_args( has_lse, mask, batch_size, max_seqlen_q, num_heads, num_heads_k, head_size, page_block_size, num_splits, softmax_scale, q, k, v, cu_seqlens_q, cu_seqlens_k, block_table_, alibi_slopes_, out, softmax_lse, softmax_lse_accum, out_accum); float t = fmha_fwd_splitkv(traits, args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd_splitkv"); } else { auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); auto traits = get_ck_fmha_varlen_fwd_traits( mask, q_dtype_str, head_size, has_dropout, has_lse, alibi_slopes_.has_value()); auto args = get_ck_fmha_varlen_fwd_args( has_lse, return_dropout_randval, mask, batch_size, max_seqlen_q, num_heads, num_heads_k, head_size, q, k, v, cu_seqlens_q, 
cu_seqlens_k, alibi_slopes_, out, softmax_lse, p, softmax_scale, p_dropout, drop_seed_offset); float t = fmha_fwd(traits, args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd"); } } else { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. out.zero_(); softmax_lse.fill_(std::numeric_limits::infinity()); } return {out, softmax_lse, p, rng_state}; } ================================================ FILE: csrc/fused_dense_lib/README.md ================================================ This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu/relu (forward and backward), adapted from Apex's [FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We make it work for bfloat16. For best performance, you should use CUDA >= 11.8. cuBLAS versions before this don't have the best matmul + bias + gelu performance for bfloat16. It has only been tested on A100s. ```sh cd csrc/fused_dense_lib && pip install . ``` ================================================ FILE: csrc/fused_dense_lib/fused_dense.cpp ================================================ // Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp // We make it work for bfloat16 #include #include #include #include #include #include #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...)
\ switch (TYPE) { \ case at::ScalarType::Half: { \ using scalar_t = at::Half; \ __VA_ARGS__(); \ break; \ } \ case at::ScalarType::BFloat16: { \ using scalar_t = at::BFloat16; \ __VA_ARGS__(); \ break; \ } \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } template int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize); template int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize); template int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize); std::vector linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) { int64_t batch_size = input.size(0); int64_t in_features = input.size(1); int64_t out_features = d_output.size(1); TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16); TORCH_CHECK(input.dtype() == d_output.dtype()); TORCH_CHECK(input.is_cuda()); TORCH_CHECK(d_output.is_cuda()); TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(d_output.is_contiguous()); CHECK_SHAPE(input, batch_size, in_features); CHECK_SHAPE(d_output, batch_size, out_features); // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{input.device()}; // create output/workspace tensor auto opts = input.options(); auto d_weight = at::empty({out_features, in_features}, opts); at::Tensor d_bias; if (has_d_bias) { #if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600 d_bias = d_output.view({-1, out_features}).sum(0, false); #else d_bias = at::empty({out_features}, opts); #endif 
} // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M. // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91 size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4); auto lt_workspace = at::empty({static_cast(workspaceSize)}, opts.dtype(torch::kUInt8)); DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] { auto result = linear_bias_wgrad_cuda( input.data_ptr(), d_output.data_ptr(), in_features, batch_size, out_features, d_weight.data_ptr(), has_d_bias ? d_bias.data_ptr() : nullptr, (void*) (lt_workspace.data_ptr()), workspaceSize); TORCH_CHECK(result == 0, "linear_bias_wgrad failed."); }); return {d_weight, d_bias}; } std::vector linear_act_forward(at::Tensor input, at::Tensor weight, std::optional bias_, bool is_gelu, bool save_pre_act, int heuristic) { int64_t batch_size = input.size(0); int64_t in_features = input.size(1); int64_t out_features = weight.size(0); TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16); TORCH_CHECK(input.dtype() == weight.dtype()); TORCH_CHECK(input.is_cuda()); TORCH_CHECK(weight.is_cuda()); TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(weight.is_contiguous()); CHECK_SHAPE(input, batch_size, in_features); CHECK_SHAPE(weight, out_features, in_features); if (bias_.has_value()) { auto bias = bias_.value(); TORCH_CHECK(bias.dtype() == input.dtype()); TORCH_CHECK(bias.is_cuda()); TORCH_CHECK(bias.is_contiguous()); CHECK_SHAPE(bias, out_features); } // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{input.device()}; // create output/workspace tensor auto opts = input.options(); auto output = at::empty({batch_size, out_features}, opts); at::Tensor pre_act; // If ReLU, cuBlasLT stores a 
bit-mask (1 bit per element) if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8}, is_gelu ? opts : opts.dtype(torch::kUInt8)); } // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M. // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91 size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4); auto lt_workspace = at::empty({static_cast(workspaceSize)}, opts.dtype(torch::kUInt8)); DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] { auto result = linear_act_forward_cuda( input.data_ptr(), weight.data_ptr(), bias_.has_value()? bias_.value().data_ptr() : nullptr, in_features, batch_size, out_features, is_gelu, heuristic, output.data_ptr(), save_pre_act ? pre_act.data_ptr() : nullptr, (void*) (lt_workspace.data_ptr()), workspaceSize); TORCH_CHECK(result == 0, "linear_act_forward failed."); }); std::vector result = {output}; if (save_pre_act) { result.push_back(pre_act); }; return result; } std::vector bias_act_linear_dgrad_bgrad( at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic ) { int64_t batch_size = d_output.size(0); int64_t out_features = d_output.size(1); int64_t in_features = weight.size(1); TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16); TORCH_CHECK(weight.dtype() == d_output.dtype()); TORCH_CHECK(is_gelu ? 
(pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8)); TORCH_CHECK(weight.is_cuda()); TORCH_CHECK(d_output.is_cuda()); TORCH_CHECK(pre_act.is_cuda()); TORCH_CHECK(weight.is_contiguous()); TORCH_CHECK(d_output.is_contiguous()); TORCH_CHECK(pre_act.is_contiguous()); CHECK_SHAPE(weight, out_features, in_features); CHECK_SHAPE(d_output, batch_size, out_features); // If ReLU, cuBlasLT stores a bit-mask (1 bit per element) CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8); // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{weight.device()}; // create output/workspace tensor auto opts = weight.options(); auto d_bias = at::empty({in_features}, opts); auto d_input = at::empty({batch_size, in_features}, opts); // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M. // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91 size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 
32 : 4); auto lt_workspace = at::empty({static_cast(workspaceSize)}, opts.dtype(torch::kUInt8)); DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] { auto result = bias_act_linear_dgrad_bgrad_cuda( weight.data_ptr(), d_output.data_ptr(), pre_act.data_ptr(), in_features, batch_size, out_features, is_gelu, heuristic, d_input.data_ptr(), d_bias.data_ptr(), (void*) (lt_workspace.data_ptr()), workspaceSize); TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed."); }); return {d_input, d_bias}; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad"); m.def("linear_act_forward", &linear_act_forward, "linear gelu/relu forward"); m.def("bias_act_linear_dgrad_bgrad", &bias_act_linear_dgrad_bgrad, "bias gelu/relu linear dgrad bgrad"); } ================================================ FILE: csrc/fused_dense_lib/fused_dense_cuda.cu ================================================ // Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu #include #include #include #include #include #include #include /* Includes, cuda */ #include #include #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #include #endif // FP16 Tensor core wrapper around cublas GEMMEx cublasStatus_t gemm_bias( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, const float* alpha, const at::Half* A, int64_t lda, const at::Half* B, int64_t ldb, const float* beta, at::Half* C, int64_t ldc) { return cublasGemmEx( handle, transa, transb, m, n, k, alpha, A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb, beta, C, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); } // BF16 Tensor core wrapper around cublas GEMMEx cublasStatus_t gemm_bias( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, const float* alpha, const at::BFloat16* A, int64_t lda, const at::BFloat16* B, int64_t ldb, const 
float* beta, at::BFloat16* C, int64_t ldc) { return cublasGemmEx( handle, transa, transb, m, n, k, alpha, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb, beta, C, CUDA_R_16BF, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); } #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 template int gemm_bias_act_lt( cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, float alpha, const Dtype* A, int64_t lda, const Dtype* B, int64_t ldb, const Dtype* bias, Dtype* C, int64_t ldc, void* pre_act, bool is_gelu, int heuristic, void *lt_workspace, size_t workspaceSize ) { static_assert(std::is_same::value || std::is_same::value, "gemm_bias_act_lt only supports fp16 and bf16"); bool save_pre_act = pre_act != nullptr; float beta = 0.0; cudaDataType_t abcType = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; cublasLtHandle_t ltHandle = reinterpret_cast(at::cuda::getCurrentCUDABlasHandle()); cublasStatus_t status = CUBLAS_STATUS_SUCCESS; cublasLtMatmulDescOpaque_t operationDesc = {}; cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; cublasLtMatmulPreferenceOpaque_t preference = {}; int returnedResults = 0; constexpr int requestedAlgoCount = 5; cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0}; // constexpr int requestedAlgoCount = 1; // cublasLtMatmulHeuristicResult_t heuristicResult = {}; cublasLtEpilogue_t epilogue = is_gelu ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU) : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU); // Create operation descriptor; see cublasLtMatmulDescAttributes_t // for details about defaults; here we just set the transforms for // A and B. 
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; if (save_pre_act) { status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; } if (bias != nullptr) { status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } epilogue = is_gelu ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS) : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS : CUBLASLT_EPILOGUE_RELU_BIAS); } else { epilogue = is_gelu ? (save_pre_act ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU) : (save_pre_act ? CUBLASLT_EPILOGUE_RELU_AUX : CUBLASLT_EPILOGUE_RELU); } status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } // Create matrix descriptors. Not setting any extra attributes. status = cublasLtMatrixLayoutInit( &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatrixLayoutInit( &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ?
n : k, ldb); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; // Create preference handle; In general, extra attributes can be // used here to disable tensor ops or to make sure algo selected // will work with badly aligned A, B, C. However, for simplicity // here we assume A,B,C are always well aligned (e.g., directly // come from cudaMalloc) status = cublasLtMatmulPreferenceInit(&preference); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulPreferenceSetAttribute( &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; // Request up to requestedAlgoCount candidate algos and run the one selected by `heuristic`. // There is no guarantee that this will work. For example, if A is // badly aligned, you can request more (e.g. 32) algos and try to // run them one by one until something works. status = cublasLtMatmulAlgoGetHeuristic( ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults); // ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; if (returnedResults == 0 || heuristic < 0 || heuristic >= returnedResults) { status = CUBLAS_STATUS_NOT_SUPPORTED; goto CLEANUP; } status = cublasLtMatmul(ltHandle, &operationDesc, &alpha, A, &Adesc, B, &Bdesc, &beta, C, &Cdesc, C, &Cdesc, // &heuristicResult.algo, // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos &heuristicResult[heuristic].algo, // NULL, lt_workspace, workspaceSize, at::cuda::getCurrentCUDAStream()); CLEANUP: // Descriptors are no longer needed as all GPU work was already // enqueued. return status == CUBLAS_STATUS_SUCCESS ?
0 : 1; } template int gemm_bias_act_lt( cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, float alpha, const at::Half* A, int64_t lda, const at::Half* B, int64_t ldb, const at::Half* bias, at::Half* C, int64_t ldc, void* pre_act, bool is_gelu, int heuristic, void *lt_workspace, size_t workspaceSize); template int gemm_bias_act_lt( cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, float alpha, const at::BFloat16* A, int64_t lda, const at::BFloat16* B, int64_t ldb, const at::BFloat16* bias, at::BFloat16* C, int64_t ldc, void* pre_act, bool is_gelu, int heuristic, void *lt_workspace, size_t workspaceSize); template int gemm_bgradb_lt( cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, float alpha, const Dtype* A, int64_t lda, const Dtype* B, int64_t ldb, Dtype* C, int64_t ldc, Dtype* bgrad, void *lt_workspace, size_t workspaceSize) { static_assert(std::is_same::value || std::is_same::value, "gemm_bgradb_lt only supports fp16 and bf16"); float beta = 0.0; cudaDataType_t abcType = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; cublasLtHandle_t ltHandle = reinterpret_cast(at::cuda::getCurrentCUDABlasHandle()); cublasStatus_t status = CUBLAS_STATUS_SUCCESS; cublasLtMatmulDescOpaque_t operationDesc = {}; cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; cublasLtMatmulPreferenceOpaque_t preference = {}; int returnedResults = 0; cublasLtMatmulHeuristicResult_t heuristicResult = {}; cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; // Create operation descriptor; see cublasLtMatmulDescAttributes_t // for details about defaults; here we just set the transforms for // A and B. 
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; if (bgrad != nullptr) { status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } epilogue = CUBLASLT_EPILOGUE_BGRADB; } status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } // Create matrix descriptors. Not setting any extra attributes. status = cublasLtMatrixLayoutInit( &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatrixLayoutInit( &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; // Create preference handle; In general, extra attributes can be // used here to disable tensor ops or to make sure algo selected // will work with badly aligned A, B, C.
However, for simplicity // here we assume A,B,C are always well aligned (e.g., directly // come from cudaMalloc) status = cublasLtMatmulPreferenceInit(&preference); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulPreferenceSetAttribute( &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; // We just need the best available heuristic to try and run matmul. // There is no guarantee that this will work. For example, if A is // badly aligned, you can request more (e.g. 32) algos and try to // run them one by one until something works. status = cublasLtMatmulAlgoGetHeuristic( ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; if (returnedResults == 0) { status = CUBLAS_STATUS_NOT_SUPPORTED; goto CLEANUP; } status = cublasLtMatmul(ltHandle, &operationDesc, &alpha, A, &Adesc, B, &Bdesc, &beta, C, &Cdesc, C, &Cdesc, //&heuristicResult.algo, NULL, lt_workspace, workspaceSize, at::cuda::getCurrentCUDAStream()); CLEANUP: // Descriptors are no longer needed as all GPU work was already // enqueued. return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1; } template int gemm_bgradb_lt( cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, float alpha, const at::Half* A, int64_t lda, const at::Half* B, int64_t ldb, at::Half* C, int64_t ldc, at::Half* bgrad, void *lt_workspace, size_t workspaceSize); template int gemm_bgradb_lt( cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, float alpha, const at::BFloat16* A, int64_t lda, const at::BFloat16* B, int64_t ldb, at::BFloat16* C, int64_t ldc, at::BFloat16* bgrad, void *lt_workspace, size_t workspaceSize); template int gemm_dact_bgradb_lt( cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, float alpha, const Dtype* A, int64_t lda, const Dtype* B, int64_t ldb, const void* pre_act, Dtype* C, int64_t ldc, Dtype* bgrad, bool is_gelu, int heuristic, void *lt_workspace, size_t workspaceSize) { static_assert(std::is_same::value || std::is_same::value, "gemm_dact_bgradb_lt only supports fp16 and bf16"); float beta = 0.0; cudaDataType_t abcType = std::is_same::value ? CUDA_R_16F : CUDA_R_16BF; cublasLtHandle_t ltHandle = reinterpret_cast(at::cuda::getCurrentCUDABlasHandle()); cublasStatus_t status = CUBLAS_STATUS_SUCCESS; cublasLtMatmulDescOpaque_t operationDesc = {}; cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; cublasLtMatmulPreferenceOpaque_t preference = {}; int returnedResults = 0; constexpr int requestedAlgoCount = 5; cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0}; cublasLtEpilogue_t epilogue = is_gelu ? CUBLASLT_EPILOGUE_DGELU_BGRAD : CUBLASLT_EPILOGUE_DRELU_BGRAD; // Create operation descriptor; see cublasLtMatmulDescAttributes_t // for details about defaults; here we just set the transforms for // A and B. 
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_act, sizeof(pre_act)); if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); if (status != CUBLAS_STATUS_SUCCESS) { goto CLEANUP; } // Create matrix descriptors. Not setting any extra attributes. status = cublasLtMatrixLayoutInit( &Adesc, abcType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatrixLayoutInit( &Bdesc, abcType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatrixLayoutInit(&Cdesc, abcType, m, n, ldc); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; // Create preference handle; In general, extra attributes can be // used here to disable tensor ops or to make sure algo selected // will work with badly aligned A, B, C.
However, for simplicity // here we assume A,B,C are always well aligned (e.g., directly // come from cudaMalloc) status = cublasLtMatmulPreferenceInit(&preference); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; status = cublasLtMatmulPreferenceSetAttribute( &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; // We just need the best available heuristic to try and run matmul. // There is no guarantee that this will work. For example, if A is // badly aligned, you can request more (e.g. 32) algos and try to // run them one by one until something works. status = cublasLtMatmulAlgoGetHeuristic( ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults); if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; if (returnedResults == 0) { status = CUBLAS_STATUS_NOT_SUPPORTED; goto CLEANUP; } status = cublasLtMatmul(ltHandle, &operationDesc, &alpha, A, &Adesc, B, &Bdesc, &beta, C, &Cdesc, C, &Cdesc, //&heuristicResult.algo, &heuristicResult[heuristic].algo, // NULL, lt_workspace, workspaceSize, at::cuda::getCurrentCUDAStream()); CLEANUP: // Descriptors are no longer needed as all GPU work was already // enqueued. return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1; } template int gemm_dact_bgradb_lt( cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, float alpha, const at::Half* A, int64_t lda, const at::Half* B, int64_t ldb, const void* pre_act, at::Half* C, int64_t ldc, at::Half* bgrad, bool is_gelu, int heuristic, void *lt_workspace, size_t workspaceSize); template int gemm_dact_bgradb_lt( cublasOperation_t transa, cublasOperation_t transb, int64_t m, int64_t n, int64_t k, float alpha, const at::BFloat16* A, int64_t lda, const at::BFloat16* B, int64_t ldb, const void* pre_act, at::BFloat16* C, int64_t ldc, at::BFloat16* bgrad, bool is_gelu, int heuristic, void *lt_workspace, size_t workspaceSize); #endif template int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize) { const float alpha = 1.0; const float beta_zero = 0.0; int status = 1; #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 status = gemm_bgradb_lt( // (cublasLtHandle_t)handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, out_features, batch_size, alpha, input, in_features, d_output, out_features, d_weight, in_features, d_bias, lt_workspace, workspaceSize); #endif if (status != 0){ cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); status = gemm_bias( handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, out_features, batch_size, &alpha, input, in_features, d_output, out_features, &beta_zero, d_weight, in_features); // TD [2023-01-17]: I can't call Pytorch's gemm for now, due to linking error // https://discuss.pytorch.org/t/how-can-i-use-the-function-at-gemm-float/95341 // at::cuda::blas::gemm( // 'N', // 'T', // in_features, // out_features, // batch_size, // alpha, // input, // in_features, // d_output, // out_features, // beta_zero, // d_weight, // in_features); } return status; } template int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t 
in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize) { int status = 1; #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 status = gemm_bias_act_lt( CUBLAS_OP_T, CUBLAS_OP_N, out_features, batch_size, in_features, /*alpha=*/1.0, weight, in_features, input, in_features, bias, output, out_features, pre_act, is_gelu, heuristic, lt_workspace, workspaceSize); return status; #else return 1; #endif } template int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize) { const float alpha = 1.0; int status = 1; #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 status = gemm_dact_bgradb_lt( CUBLAS_OP_N, CUBLAS_OP_N, in_features, batch_size, out_features, alpha, weight, in_features, d_output, out_features, pre_act, d_input, in_features, d_bias, is_gelu, heuristic, lt_workspace, workspaceSize); #endif return status; } template int linear_bias_wgrad_cuda(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace, size_t workspaceSize); template int linear_bias_wgrad_cuda(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize); template int linear_act_forward_cuda(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act, void *lt_workspace, size_t workspaceSize); template int linear_act_forward_cuda(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t 
in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act, void *lt_workspace, size_t workspaceSize); template int bias_act_linear_dgrad_bgrad_cuda(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace, size_t workspaceSize); template int bias_act_linear_dgrad_bgrad_cuda(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize); ================================================ FILE: csrc/fused_dense_lib/setup.py ================================================ import os import subprocess from packaging.version import parse, Version import torch from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME def get_cuda_bare_metal_version(cuda_dir): raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) output = raw_output.split() release_idx = output.index("release") + 1 bare_metal_version = parse(output[release_idx].split(",")[0]) return raw_output, bare_metal_version def append_nvcc_threads(nvcc_extra_args): _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version >= Version("11.2"): nvcc_threads = os.getenv("NVCC_THREADS") or "4" return nvcc_extra_args + ["--threads", nvcc_threads] return nvcc_extra_args setup( name='fused_dense_lib', ext_modules=[ CUDAExtension( name='fused_dense_lib', sources=['fused_dense.cpp', 'fused_dense_cuda.cu'], extra_compile_args={ 'cxx': ['-O3',], 'nvcc': append_nvcc_threads(['-O3']) } ) ], cmdclass={ 'build_ext': BuildExtension }) ================================================ FILE: 
csrc/layer_norm/README.md ================================================
This CUDA extension implements fused dropout + residual + LayerNorm, building on Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).

Major changes:
- Add dropout and residual.
- Make it work for both pre-norm and post-norm architectures.
- Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
- Implement RMSNorm as an option.
- Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).

If you want to use it for dimensions larger than 8k, please file an issue.

This extension has only been tested on A100s.

```sh
cd csrc/layer_norm && pip install .
```

As of 2024-01-05, this extension is no longer used in the FlashAttention repo. We've instead switched to a Triton-based [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).

================================================ FILE: csrc/layer_norm/ln.h ================================================ #pragma once #include #include #include #ifdef OLD_GENERATOR_PATH #include #else #include #endif namespace layer_norm { //////////////////////////////////////////////////////////////////////////////////////////////////// template struct LaunchParams{ size_t elts_per_thread; size_t workspace_bytes; size_t barrier_size; cudaDeviceProp * props; cudaStream_t stream; Params params; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct ParamsBase { ParamsBase() : ctas_per_col(0) , rows(0) , cols(0) , x(nullptr) , mu(nullptr) , rs(nullptr) , gamma(nullptr) , gamma1(nullptr) , rowscale(nullptr) , colscale(nullptr) , dropout_keep_p(1.f) , dropout_scale(1.f) , is_rms_norm(false) , workspace(nullptr) , barrier(nullptr) { } // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. int ctas_per_col; // Input is interpreted as matrix.
We normalize across columns. int rows; int cols; // Common data pointers. void *x0; void *x1; void *residual; void *x; void *dmask; void *dmask1; void *mu; void *rs; void *gamma; void *gamma1; void *rowscale; void *colscale; void *x0_subset; void *z_subset; float inverse_cols; float dropout_keep_p; float dropout_scale; float rowscale_const; bool is_rms_norm; // Multi-CTA workspace in gmem. void *workspace; // Multi-CTA sync barriers in gmem. int *barrier; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct FwdParams : public ParamsBase { FwdParams() : ParamsBase() , z(nullptr) , z1(nullptr) , beta(nullptr) , beta1(nullptr) , epsilon(0.f) { } // Output of LN FWD. void *z; void *z1; void *beta; void *beta1; float epsilon; // Random state. at::PhiloxCudaState philox_args; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct BwdParams : public ParamsBase { BwdParams() : ParamsBase() , dz(nullptr) , dz1(nullptr) , dx(nullptr) , dbeta_part(nullptr) , dgamma_part(nullptr) , dbeta1_part(nullptr) , dgamma1_part(nullptr) , dcolscale_part(nullptr) , dx0(nullptr) , dx1(nullptr) , dresidual(nullptr) , dbeta(nullptr) , dgamma(nullptr) , dbeta1(nullptr) , dgamma1(nullptr) , dcolscale(nullptr) { } // Input: gradient wrt. LN FWD output. void *dz; void *dz1; // Input: gradient wrt residual. void *dx; // Workspace for Wgrad pre-reduction. void *dbeta_part; void *dgamma_part; void *dbeta1_part; void *dgamma1_part; void *dcolscale_part; // Output: Dgrad. void *dx0; void *dx1; void *dresidual; // Output: Wgrad. 
void *dbeta; void *dgamma; void *dbeta1; void *dgamma1; void *dcolscale; }; //////////////////////////////////////////////////////////////////////////////////////////////////// using FwdFunction = std::function&, const bool)>; using BwdFunction = std::function&, const bool)>; using FunctionKey = uint64_t; using FwdRegistry = std::unordered_map; using BwdRegistry = std::unordered_map; extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS; extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS; //////////////////////////////////////////////////////////////////////////////////////////////////// using fp32 = float; using fp16 = half; using bf16 = nv_bfloat16; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct TypeId{}; template<> struct TypeId{ constexpr static uint32_t Value = 0; }; template<> struct TypeId{ constexpr static uint32_t Value = 1; }; template<> struct TypeId{ constexpr static uint32_t Value = 2; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Type2Key{ constexpr static uint32_t Value = TypeId::Value << S; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct WeightType2Key : public Type2Key{}; template struct InputType2Key : public Type2Key{}; template struct ResidualType2Key : public Type2Key{}; template struct OutputType2Key : public Type2Key{}; template struct ComputeType2Key : public Type2Key{}; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Types2Key{ constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | ResidualType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; constexpr static inline uint64_t get(const uint64_t hidden_size){ constexpr uint64_t type_key = Value; return (type_key << 32) | hidden_size; } }; 
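// Editor-added illustrative sketch, not part of upstream ln.h.
// A minimal sketch, assuming the layout used by get_key() in ln_api.cpp: the five
// 2-bit TypeId values are packed as weight << 0, input << 2, residual << 4,
// output << 6 and compute << 8, and that 10-bit type key is shifted into the upper
// 32 bits of the 64-bit launcher key, with the hidden size in the lower 32 bits.
// Which dtype maps to which TypeId value (0, 1 or 2) is deliberately left abstract.
// The hidden size used for lookup is the bucketed value that dropout_add_ln_fwd/bwd
// compute with round_multiple (multiples of 256 up to 1536, 512 up to 3072, else 1024).
// The namespace key_layout_example and its functions are hypothetical helpers.
namespace key_layout_example {

// Packs five type ids (each 0..2, so 2 bits suffice) and a hidden size into one key.
constexpr unsigned long long pack_key(unsigned w, unsigned i, unsigned r, unsigned o,
                                      unsigned c, unsigned long long hidden_size) {
    return (static_cast<unsigned long long>(w | (i << 2) | (r << 4) | (o << 6) | (c << 8)) << 32)
           | hidden_size;
}

// Bucket width used when rounding the hidden size before the registry lookup.
constexpr unsigned long long bucket_multiple(unsigned long long h) {
    return h <= 1536 ? 256ULL : (h <= 3072 ? 512ULL : 1024ULL);
}

// Rounds h up to the next multiple of its bucket width.
constexpr unsigned long long bucketed_hidden_size(unsigned long long h) {
    return (h + bucket_multiple(h) - 1) / bucket_multiple(h) * bucket_multiple(h);
}

// Ids 1,1,1,1 in the low slots and 2 in the compute slot: 1 | 4 | 16 | 64 | 512 = 0x255.
static_assert(pack_key(1, 1, 1, 1, 2, 1024) == ((0x255ULL << 32) | 1024ULL),
              "type key occupies the upper 32 bits, hidden size the lower 32");
static_assert(bucketed_hidden_size(1000) == 1024, "256-wide bucket");
static_assert(bucketed_hidden_size(2000) == 2048, "512-wide bucket");
static_assert(bucketed_hidden_size(5000) == 5120, "1024-wide bucket");

}  // namespace key_layout_example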
//////////////////////////////////////////////////////////////////////////////////////////////////// template struct FwdRegistrar{ FwdRegistrar(FwdFunction f){ uint64_t key = Types2Key::get(HIDDEN_SIZE); FWD_FUNCS.insert({ key, f }); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct BwdRegistrar{ BwdRegistrar(BwdFunction f){ uint64_t key = Types2Key::get(HIDDEN_SIZE); BWD_FUNCS.insert({ key, f }); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct FwdParallelRegistrar{ FwdParallelRegistrar(FwdFunction f){ uint64_t key = Types2Key::get(HIDDEN_SIZE); PARALLEL_FWD_FUNCS.insert({ key, f }); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct BwdParallelRegistrar{ BwdParallelRegistrar(BwdFunction f){ uint64_t key = Types2Key::get(HIDDEN_SIZE); PARALLEL_BWD_FUNCS.insert({ key, f }); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace layer_norm ================================================ FILE: csrc/layer_norm/ln_api.cpp ================================================ #include #include "ATen/cuda/CUDAContext.h" #include #include "ln.h" /*

Supported Type combinations:

input   residual   compute   weights   output
============================================
fp32    fp32       fp32      fp32      fp32
fp16    fp32       fp32      fp32      fp16
fp16    fp16       fp32      fp32      fp16
bf16    fp32       fp32      fp32      bf16
bf16    bf16       fp32      fp32      bf16
fp16    fp16       fp32      fp16      fp16
bf16    bf16       fp32      bf16      bf16

Remarks:
Output type = Input type
Compute always in FP32

*/ namespace layer_norm { // Create registries and provide runtime versions of config hash functions.
FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS; BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS; //////////////////////////////////////////////////////////////////////////////////////////////////// uint32_t get_type_id(torch::Dtype dtype){ if( dtype == torch::kFloat16 ) { return TypeId::Value; } else if( dtype == torch::kBFloat16 ) { return TypeId::Value; } else if( dtype == torch::kFloat32 ) { return TypeId::Value; } else { TORCH_CHECK(false, "Type not supported: ", dtype); } } //////////////////////////////////////////////////////////////////////////////////////////////////// uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) { using namespace layer_norm; uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8); uint64_t launcher_key = (type_key << 32) | hidden_size; return launcher_key; } } // namespace layer_norm //////////////////////////////////////////////////////////////////////////////////////////////////// layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); if( iter != layer_norm::FWD_FUNCS.end() ) { return iter->second; } else { TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); } } //////////////////////////////////////////////////////////////////////////////////////////////////// layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); if( iter != layer_norm::BWD_FUNCS.end() ) { return 
iter->second; } else { TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); } } //////////////////////////////////////////////////////////////////////////////////////////////////// layer_norm::FwdFunction & get_parallel_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { auto iter = layer_norm::PARALLEL_FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); if( iter != layer_norm::PARALLEL_FWD_FUNCS.end() ) { return iter->second; } else { TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); } } //////////////////////////////////////////////////////////////////////////////////////////////////// layer_norm::BwdFunction & get_parallel_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { auto iter = layer_norm::PARALLEL_BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); if( iter != layer_norm::PARALLEL_BWD_FUNCS.end() ) { return iter->second; } else { TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); } } //////////////////////////////////////////////////////////////////////////////////////////////////// std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size std::optional &residual_, // Residual: BxSxhidden_size const at::Tensor &gamma, // hidden_size std::optional &beta_, // hidden_size std::optional &rowscale_, // BxS std::optional &colscale_, // hidden_size std::optional &x0_subset_, // BxS std::optional &z_subset_, // BxS const float dropout_p, const float epsilon, const float rowscale_const, const int64_t z_numrows, std::optional gen_, bool residual_in_fp32=false, bool is_rms_norm=false ) { auto itype = x0.scalar_type(); auto rtype = 
residual_.has_value() ? residual_.value().scalar_type() : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type()); auto wtype = gamma.scalar_type(); auto otype = itype; auto ctype = torch::kFloat32; auto mtype = torch::kUInt8; TORCH_CHECK(x0.is_cuda()); TORCH_CHECK(gamma.is_cuda()); TORCH_CHECK(x0.is_contiguous()); // c10::IntArrayRef does not own the storage, so we need to construct a vector. // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because // blah is then deallocated. std::vector sizes_vec {!x0_subset_.has_value() ? x0.size(0) : x0_subset_.value().size(0), x0.size(1)}; auto sizes = c10::IntArrayRef(sizes_vec); TORCH_CHECK(x0.dim() == 2); TORCH_CHECK(sizes.size() == 2); const int rows = sizes[0]; const int cols = sizes[1]; auto hidden_size = gamma.numel(); TORCH_CHECK(hidden_size == cols); if (beta_.has_value()) { auto beta = beta_.value(); TORCH_CHECK(beta.dtype() == wtype); TORCH_CHECK(beta.is_cuda()); TORCH_CHECK(beta.is_contiguous()); TORCH_CHECK(beta.sizes() == gamma.sizes()); } if (residual_.has_value()) { auto residual = residual_.value(); TORCH_CHECK(residual.is_cuda()); TORCH_CHECK(residual.is_contiguous()); TORCH_CHECK(residual.sizes() == sizes); } if (rowscale_.has_value()) { auto rowscale = rowscale_.value(); TORCH_CHECK(rowscale.is_cuda()); TORCH_CHECK(rowscale.is_contiguous()); TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows}); TORCH_CHECK(rowscale.dtype() == itype); } if (colscale_.has_value()) { auto colscale = colscale_.value(); TORCH_CHECK(colscale.is_cuda()); TORCH_CHECK(colscale.is_contiguous()); TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols}); TORCH_CHECK(colscale.dtype() == wtype); } if (x0_subset_.has_value()) { auto x0_subset = x0_subset_.value(); TORCH_CHECK(x0_subset.is_cuda()); TORCH_CHECK(x0_subset.is_contiguous()); TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows}); TORCH_CHECK(x0_subset.dtype() == torch::kInt32); TORCH_CHECK(z_subset_.has_value()); auto z_subset = 
z_subset_.value(); TORCH_CHECK(z_subset.is_cuda()); TORCH_CHECK(z_subset.is_contiguous()); TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows}); TORCH_CHECK(z_subset.dtype() == torch::kInt32); } TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192)); TORCH_CHECK(epsilon >= 0.f); // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{x0.device()}; auto opts = x0.options(); bool save_x = residual_.has_value() || (dropout_p > 0.f) || rowscale_.has_value() || colscale_.has_value() || x0_subset_.has_value() || (itype != rtype); at::Tensor x; if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); } at::Tensor dmask; if (dropout_p > 0.f) { dmask = torch::empty(x0.sizes(), opts.dtype(mtype)); }; auto z = torch::empty(z_subset_.has_value() ? c10::IntArrayRef{z_numrows, cols} : sizes, opts.dtype(otype)); auto mu = torch::empty({ rows }, opts.dtype(ctype)); auto rsigma = torch::empty({ rows }, opts.dtype(ctype)); layer_norm::LaunchParams launch_params; launch_params.props = at::cuda::getCurrentDeviceProperties(); launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(dropout_p < 1.f); launch_params.params.dropout_keep_p = 1.f - dropout_p; launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr; launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr; launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr; launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr; auto gen = at::get_generator_or_default( gen_, at::cuda::detail::getDefaultCUDAGenerator()); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024); // Request the kernel launcher. 
auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple)); // Set the kernel runtime parameters. layer_norm::FwdParams &params = launch_params.params; params.rows = rows; params.cols = cols; params.x0 = x0.data_ptr(); params.x = save_x ? x.data_ptr() : nullptr; params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr; params.mu = mu.data_ptr(); params.rs = rsigma.data_ptr(); params.gamma = gamma.data_ptr(); params.beta = beta_.has_value() ? beta_.value().data_ptr() : nullptr; params.z = z.data_ptr(); params.epsilon = epsilon; params.dropout_scale = 1.f / (1.f - dropout_p); params.inverse_cols = 1.f / float(params.cols); params.rowscale_const = rowscale_const; params.is_rms_norm = is_rms_norm; // Query the kernel-specific launch parameters. launcher(launch_params, true); at::Tensor workspace, barrier; if (dropout_p > 0.f) { // number of times random will be generated per thread, to offset philox counter in thc random // state int64_t counter_offset = launch_params.elts_per_thread; // See Note [Acquire lock when using random generators] { std::lock_guard lock(gen->mutex_); params.philox_args = gen->philox_cuda_state(counter_offset); } } if( launch_params.barrier_size > 0 ) { auto options = x0.options(); barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); params.workspace = workspace.data_ptr(); params.barrier = barrier.data_ptr(); } // Launch the kernel. launcher(launch_params, false); return { z, x, dmask, mu, rsigma }; } //////////////////////////////////////////////////////////////////////////////////////////////////// std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size std::optional &dx_, // BxSxhidden_size const at::Tensor &x, // BxSxhidden_size std::optional &x0_, // BxSxhidden_size std::optional &dmask_, // BxSxhidden_size const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32! const at::Tensor &gamma, // hidden_size std::optional &rowscale_, // BxS std::optional &colscale_, // hidden_size std::optional &x0_subset_, // BxS std::optional &z_subset_, // BxS const float dropout_p, const float rowscale_const, const int64_t x0_numrows, const bool has_residual, bool is_rms_norm=false ) { auto itype = dz.scalar_type(); auto rtype = x.scalar_type(); auto wtype = gamma.scalar_type(); auto otype = itype; auto ctype = torch::kFloat32; auto mtype = torch::kUInt8; if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); } TORCH_CHECK(dz.dtype() == otype); TORCH_CHECK(mu.dtype() == ctype); TORCH_CHECK(rsigma.dtype() == ctype); TORCH_CHECK(x.is_cuda()); TORCH_CHECK(dz.is_cuda()); TORCH_CHECK(mu.is_cuda()); TORCH_CHECK(rsigma.is_cuda()); TORCH_CHECK(gamma.is_cuda()); TORCH_CHECK(x.is_contiguous()); TORCH_CHECK(dz.is_contiguous()); auto sizes = x.sizes(); TORCH_CHECK(sizes.size() == 2); auto rows = sizes[0]; auto cols = sizes[1]; TORCH_CHECK(dz.dim() == 2); TORCH_CHECK(dz.size(1) == cols); auto hidden_size = gamma.numel(); TORCH_CHECK(hidden_size == cols); // c10::IntArrayRef does not own the storage, so we need to construct a vector. // Otherwise just constructing IntArrayRef({blah}) will cause uninitialized memory because // blah is then deallocated. std::vector x0_sizes_vec {!x0_subset_.has_value() ? 
rows : x0_numrows, cols}; auto x0_sizes = c10::IntArrayRef(x0_sizes_vec); if (dx_.has_value()) { auto dx = dx_.value(); TORCH_CHECK(dx.dtype() == rtype); TORCH_CHECK(dx.is_cuda()); TORCH_CHECK(dx.is_contiguous()); TORCH_CHECK(dx.sizes() == sizes); } if (dmask_.has_value()) { auto dmask = dmask_.value(); TORCH_CHECK(dmask.dtype() == mtype); TORCH_CHECK(dmask.is_cuda()); TORCH_CHECK(dmask.is_contiguous()); TORCH_CHECK(dmask.sizes() == x0_sizes); } if (rowscale_.has_value()) { auto rowscale = rowscale_.value(); TORCH_CHECK(rowscale.is_cuda()); TORCH_CHECK(rowscale.is_contiguous()); TORCH_CHECK(rowscale.sizes() == c10::IntArrayRef{rows}); TORCH_CHECK(rowscale.dtype() == itype); } if (colscale_.has_value()) { auto colscale = colscale_.value(); TORCH_CHECK(colscale.is_cuda()); TORCH_CHECK(colscale.is_contiguous()); TORCH_CHECK(colscale.sizes() == c10::IntArrayRef{cols}); TORCH_CHECK(colscale.dtype() == wtype); TORCH_CHECK(x0_.has_value()); auto x0 = x0_.value(); TORCH_CHECK(x0.is_cuda()); TORCH_CHECK(x0.is_contiguous()); TORCH_CHECK(x0.sizes() == x0_sizes); TORCH_CHECK(x0.dtype() == itype); } if (x0_subset_.has_value()) { auto x0_subset = x0_subset_.value(); TORCH_CHECK(x0_subset.is_cuda()); TORCH_CHECK(x0_subset.is_contiguous()); TORCH_CHECK(x0_subset.sizes() == c10::IntArrayRef{rows}); TORCH_CHECK(x0_subset.dtype() == torch::kInt32); TORCH_CHECK(z_subset_.has_value()); auto z_subset = z_subset_.value(); TORCH_CHECK(z_subset.is_cuda()); TORCH_CHECK(z_subset.is_contiguous()); TORCH_CHECK(z_subset.sizes() == c10::IntArrayRef{rows}); TORCH_CHECK(z_subset.dtype() == torch::kInt32); } TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192)); TORCH_CHECK(mu.numel() == rows); TORCH_CHECK(mu.sizes() == rsigma.sizes()); TORCH_CHECK(gamma.numel() == cols); // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{dz.device()}; auto opts = x.options(); auto dx0 = torch::empty(x0_sizes, opts.dtype(itype)); at::Tensor dresidual; if 
    (has_residual) {
        dresidual = torch::empty_like(x, opts.dtype(rtype));
    }
    auto dgamma = torch::empty_like(gamma);
    auto dbeta = torch::empty_like(gamma);
    at::Tensor dcolscale;
    if (colscale_.has_value()) {
        dcolscale = torch::empty_like(colscale_.value());
    }

    layer_norm::LaunchParams launch_params;
    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
    launch_params.props = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dropout_p < 1.f);
    launch_params.params.dropout_keep_p = 1.f - dropout_p;
    launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;
    launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
    launch_params.params.colscale = colscale_.has_value() ? colscale_.value().data_ptr() : nullptr;
    launch_params.params.x0_subset = x0_subset_.has_value() ? x0_subset_.value().data_ptr() : nullptr;
    launch_params.params.z_subset = z_subset_.has_value() ? z_subset_.value().data_ptr() : nullptr;

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
    auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));

    launcher(launch_params, true);

    auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    at::Tensor dcolscale_part;
    if (colscale_.has_value()) {
        dcolscale_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    }
    at::Tensor workspace, barrier;

    layer_norm::BwdParams &params = launch_params.params;
    params.rows = rows;
    params.cols = cols;
    params.x = x.data_ptr();
    params.x0 = x0_.has_value() ? x0_.value().data_ptr() : nullptr;
    params.dmask = dropout_p > 0.f ?
        dmask_.value().data_ptr() : nullptr;
    params.mu = mu.data_ptr();
    params.rs = rsigma.data_ptr();
    params.gamma = gamma.data_ptr();
    params.dz = dz.data_ptr();
    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
    params.dx0 = dx0.data_ptr();
    params.dbeta = dbeta.data_ptr();
    params.dgamma = dgamma.data_ptr();
    params.dcolscale = colscale_.has_value() ? dcolscale.data_ptr() : nullptr;
    params.dbeta_part = dbeta_part.data_ptr();
    params.dgamma_part = dgamma_part.data_ptr();
    params.dcolscale_part = colscale_.has_value() ? dcolscale_part.data_ptr() : nullptr;
    params.dropout_scale = 1.f / (1.f - dropout_p);
    params.inverse_cols = 1.f / float(params.cols);
    params.rowscale_const = rowscale_const;
    params.is_rms_norm = is_rms_norm;

    if( launch_params.barrier_size > 0 ) {
        // TODO Any way to avoid this?
        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
        params.workspace = workspace.data_ptr();
        params.barrier = barrier.data_ptr();
    }

    launcher(launch_params, false);

    std::vector result = { dx0, dresidual, dgamma, dbeta, dgamma_part, dbeta_part };
    if (colscale_.has_value()) {
        result.push_back(dcolscale);
        result.push_back(dcolscale_part);
    }
    return result;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector dropout_add_ln_parallel_residual_fwd(
    const at::Tensor &x0,          // Input: BxSxhidden_size
    std::optional &x1_,            // Input: BxSxhidden_size
    std::optional &residual_,      // Residual: BxSxhidden_size
    const at::Tensor &gamma0,      // hidden_size
    std::optional &beta0_,         // hidden_size
    std::optional &gamma1_,        // hidden_size
    std::optional &beta1_,         // hidden_size
    const float dropout_p,
    const float epsilon,
    std::optional gen_,
    bool residual_in_fp32=false,
    bool is_rms_norm=false
) {
    auto itype = x0.scalar_type();
    auto rtype = residual_.has_value()
        ? residual_.value().scalar_type()
        : (residual_in_fp32 ?
            torch::kFloat32 : x0.scalar_type());
    auto wtype = gamma0.scalar_type();
    auto otype = itype;
    auto ctype = torch::kFloat32;
    auto mtype = torch::kUInt8;

    TORCH_CHECK(x0.is_cuda());
    TORCH_CHECK(gamma0.is_cuda());

    TORCH_CHECK(x0.is_contiguous());
    const auto sizes = x0.sizes();
    TORCH_CHECK(x0.dim() == 2);

    const int rows = sizes[0];
    const int cols = sizes[1];
    auto hidden_size = gamma0.numel();
    TORCH_CHECK(hidden_size == cols);

    if (x1_.has_value()) {
        auto x1 = x1_.value();
        TORCH_CHECK(x1.is_cuda());
        TORCH_CHECK(x1.is_contiguous());
        TORCH_CHECK(x1.sizes() == sizes);
    }

    if (residual_.has_value()) {
        auto residual = residual_.value();
        TORCH_CHECK(residual.is_cuda());
        TORCH_CHECK(residual.is_contiguous());
        TORCH_CHECK(residual.sizes() == sizes);
    }

    if (beta0_.has_value()) {
        auto beta0 = beta0_.value();
        TORCH_CHECK(beta0.dtype() == wtype);
        TORCH_CHECK(beta0.is_cuda());
        TORCH_CHECK(beta0.is_contiguous());
        TORCH_CHECK(beta0.sizes() == gamma0.sizes());
    }

    if (gamma1_.has_value()) {
        auto gamma1 = gamma1_.value();
        TORCH_CHECK(gamma1.dtype() == wtype);
        TORCH_CHECK(gamma1.is_cuda());
        TORCH_CHECK(gamma1.is_contiguous());
        TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
    }

    if (beta1_.has_value()) {
        auto beta1 = beta1_.value();
        TORCH_CHECK(beta1.dtype() == wtype);
        TORCH_CHECK(beta1.is_cuda());
        TORCH_CHECK(beta1.is_contiguous());
        TORCH_CHECK(beta1.sizes() == gamma0.sizes());
    }

    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));
    TORCH_CHECK(epsilon >= 0.f);

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{x0.device()};

    auto opts = x0.options();

    bool save_x = residual_.has_value() || x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
    at::Tensor x;
    if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
    at::Tensor dmask0, dmask1;
    if (dropout_p > 0.f) {
        dmask0 = torch::empty(x0.sizes(), opts.dtype(mtype));
        if (x1_.has_value()) { dmask1 = torch::empty(x0.sizes(), opts.dtype(mtype)); }
    };
    auto z0 = torch::empty(sizes, opts.dtype(otype));
    at::Tensor z1;
    if (gamma1_.has_value()) { z1 = torch::empty(sizes, opts.dtype(otype)); }

    auto mu = torch::empty({ rows }, opts.dtype(ctype));
    auto rsigma = torch::empty({ rows }, opts.dtype(ctype));

    layer_norm::LaunchParams launch_params;

    launch_params.props = at::cuda::getCurrentDeviceProperties();
    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
    TORCH_CHECK(dropout_p < 1.f);
    launch_params.params.dropout_keep_p = 1.f - dropout_p;
    launch_params.params.residual = residual_.has_value() ? residual_.value().data_ptr() : nullptr;

    auto gen = at::get_generator_or_default(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
    // Request the kernel launcher.
    auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));

    // Set the kernel runtime parameters.
    layer_norm::FwdParams &params = launch_params.params;
    params.rows = rows;
    params.cols = cols;
    params.x0 = x0.data_ptr();
    params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
    params.x = save_x ? x.data_ptr() : nullptr;
    params.dmask = dropout_p > 0.f ? dmask0.data_ptr() : nullptr;
    params.dmask1 = (dropout_p > 0.f && x1_.has_value()) ? dmask1.data_ptr() : nullptr;
    params.mu = mu.data_ptr();
    params.rs = rsigma.data_ptr();
    params.gamma = gamma0.data_ptr();
    params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
    params.beta = beta0_.has_value() ? beta0_.value().data_ptr() : nullptr;
    params.beta1 = beta1_.has_value() ? beta1_.value().data_ptr() : nullptr;
    params.z = z0.data_ptr();
    params.z1 = gamma1_.has_value() ? z1.data_ptr() : nullptr;
    params.epsilon = epsilon;
    params.dropout_scale = 1.f / (1.f - dropout_p);
    params.inverse_cols = 1.f / float(params.cols);
    params.is_rms_norm = is_rms_norm;

    // Query the kernel-specific launch parameters.
    launcher(launch_params, true);

    at::Tensor workspace, barrier;

    if (dropout_p > 0.f) {
        // number of times random will be generated per thread, to offset philox counter in thc random
        // state
        int64_t counter_offset = 2 * launch_params.elts_per_thread;

        // See Note [Acquire lock when using random generators]
        {
            std::lock_guard lock(gen->mutex_);
            params.philox_args = gen->philox_cuda_state(counter_offset);
        }
    }

    if( launch_params.barrier_size > 0 ) {
        auto options = x0.options();
        barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
        workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
        params.workspace = workspace.data_ptr();
        params.barrier = barrier.data_ptr();
    }

    // Launch the kernel.
    launcher(launch_params, false);

    return { z0, z1, x, dmask0, dmask1, mu, rsigma };
}

////////////////////////////////////////////////////////////////////////////////////////////////////

std::vector dropout_add_ln_parallel_residual_bwd(
    const at::Tensor &dz0,         // BxSxhidden_size
    std::optional &dz1_,           // BxSxhidden_size
    std::optional &dx_,            // BxSxhidden_size
    const at::Tensor &x,           // BxSxhidden_size
    std::optional &dmask0_,        // BxSxhidden_size
    std::optional &dmask1_,        // BxSxhidden_size
    const at::Tensor &mu,          // BxS, FP32!
    const at::Tensor &rsigma,      // BxS, FP32!
    const at::Tensor &gamma0,      // hidden_size
    std::optional &gamma1_,        // hidden_size
    const float dropout_p,
    const bool has_x1,
    const bool has_residual,
    bool is_rms_norm=false
) {

    auto itype = dz0.scalar_type();
    auto rtype = x.scalar_type();
    auto wtype = gamma0.scalar_type();
    auto otype = itype;
    auto ctype = torch::kFloat32;
    auto mtype = torch::kUInt8;

    if (dropout_p > 0.f) { TORCH_CHECK(dmask0_.has_value()); }

    TORCH_CHECK(dz0.dtype() == otype);
    TORCH_CHECK(dz0.dtype() == otype);
    TORCH_CHECK(mu.dtype() == ctype);
    TORCH_CHECK(rsigma.dtype() == ctype);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(dz0.is_cuda());
    TORCH_CHECK(mu.is_cuda());
    TORCH_CHECK(rsigma.is_cuda());
    TORCH_CHECK(gamma0.is_cuda());

    TORCH_CHECK(x.is_contiguous());
    TORCH_CHECK(dz0.is_contiguous());

    auto sizes = x.sizes();
    TORCH_CHECK(sizes.size() == 2);
    auto rows = sizes[0];
    auto cols = sizes[1];
    TORCH_CHECK(dz0.dim() == 2);
    TORCH_CHECK(dz0.size(1) == cols);
    auto hidden_size = gamma0.numel();
    TORCH_CHECK(hidden_size == cols);

    if (dz1_.has_value()) {
        auto dz1 = dz1_.value();
        TORCH_CHECK(dz1.dtype() == otype);
        TORCH_CHECK(dz1.is_cuda());
        TORCH_CHECK(dz1.is_contiguous());
        TORCH_CHECK(dz1.sizes() == sizes);

        TORCH_CHECK(gamma1_.has_value());
        auto gamma1 = gamma1_.value();
        TORCH_CHECK(gamma1.dtype() == wtype);
        TORCH_CHECK(gamma1.is_cuda());
        TORCH_CHECK(gamma1.is_contiguous());
        TORCH_CHECK(gamma1.sizes() == gamma0.sizes());
    }

    if (dx_.has_value()) {
        auto dx = dx_.value();
        TORCH_CHECK(dx.dtype() == rtype);
        TORCH_CHECK(dx.is_cuda());
        TORCH_CHECK(dx.is_contiguous());
        TORCH_CHECK(dx.sizes() == sizes);
    }

    if (dmask0_.has_value()) {
        auto dmask0 = dmask0_.value();
        TORCH_CHECK(dmask0.dtype() == mtype);
        TORCH_CHECK(dmask0.is_cuda());
        TORCH_CHECK(dmask0.is_contiguous());
        TORCH_CHECK(dmask0.sizes() == sizes);

        if (has_x1) {
            TORCH_CHECK(dmask1_.has_value());
            auto dmask1 = dmask1_.value();
            TORCH_CHECK(dmask1.dtype() == mtype);
            TORCH_CHECK(dmask1.is_cuda());
            TORCH_CHECK(dmask1.is_contiguous());
            TORCH_CHECK(dmask1.sizes() == sizes);
        }
    }
    TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 8192));

    TORCH_CHECK(mu.numel() == rows);
    TORCH_CHECK(mu.sizes() == rsigma.sizes());

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{dz0.device()};

    auto opts = x.options();

    auto dx0 = torch::empty(sizes, opts.dtype(itype));
    at::Tensor dx1;
    if (has_x1) { dx1 = torch::empty(sizes, opts.dtype(itype)); }
    at::Tensor dresidual;
    if (has_residual) { dresidual = torch::empty_like(x, opts.dtype(rtype)); }
    auto dgamma0 = torch::empty_like(gamma0);
    auto dbeta0 = torch::empty_like(gamma0);
    at::Tensor dgamma1, dbeta1;
    if (gamma1_.has_value()) {
        dgamma1 = torch::empty_like(gamma0);
        dbeta1 = torch::empty_like(gamma0);
    }

    layer_norm::LaunchParams launch_params;
    launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
    launch_params.props = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dropout_p < 1.f);
    launch_params.params.dropout_keep_p = 1.f - dropout_p;
    launch_params.params.dresidual = has_residual ? dresidual.data_ptr() : nullptr;

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
    auto launcher = get_parallel_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));

    launcher(launch_params, true);

    auto dgamma0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    auto dbeta0_part = torch::zeros({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
    at::Tensor dgamma1_part, dbeta1_part;
    if (gamma1_.has_value()) {
        dgamma1_part = torch::zeros_like(dgamma0_part);
        dbeta1_part = torch::zeros_like(dbeta0_part);
    }
    at::Tensor workspace, barrier;

    layer_norm::BwdParams &params = launch_params.params;
    params.rows = rows;
    params.cols = cols;
    params.x = x.data_ptr();
    params.dmask = dropout_p > 0.f ? dmask0_.value().data_ptr() : nullptr;
    params.dmask1 = (dropout_p > 0.f && has_x1) ?
        dmask1_.value().data_ptr() : nullptr;
    params.mu = mu.data_ptr();
    params.rs = rsigma.data_ptr();
    params.gamma = gamma0.data_ptr();
    params.gamma1 = gamma1_.has_value() ? gamma1_.value().data_ptr() : nullptr;
    params.dz = dz0.data_ptr();
    params.dz1 = dz1_.has_value() ? dz1_.value().data_ptr() : nullptr;
    params.dx = dx_.has_value() ? dx_.value().data_ptr() : nullptr;
    params.dx0 = dx0.data_ptr();
    params.dx1 = has_x1 ? dx1.data_ptr() : nullptr;
    params.dbeta = dbeta0.data_ptr();
    params.dgamma = dgamma0.data_ptr();
    params.dbeta1 = gamma1_.has_value() ? dbeta1.data_ptr() : nullptr;
    params.dgamma1 = gamma1_.has_value() ? dgamma1.data_ptr() : nullptr;
    params.dbeta_part = dbeta0_part.data_ptr();
    params.dgamma_part = dgamma0_part.data_ptr();
    params.dbeta1_part = gamma1_.has_value() ? dbeta1_part.data_ptr() : nullptr;
    params.dgamma1_part = gamma1_.has_value() ? dgamma1_part.data_ptr() : nullptr;
    params.dropout_scale = 1.f / (1.f - dropout_p);
    params.inverse_cols = 1.f / float(params.cols);
    params.is_rms_norm = is_rms_norm;

    if( launch_params.barrier_size > 0 ) {
        // TODO Any way to avoid this?
        barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
        workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
        params.workspace = workspace.data_ptr();
        params.barrier = barrier.data_ptr();
    }

    launcher(launch_params, false);

    std::vector result = { dx0, dx1, dresidual, dgamma0, dbeta0, dgamma1, dbeta1, dgamma0_part, dbeta0_part, dgamma1_part, dbeta1_part };
    return result;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "CUDA DropoutAddLayerNorm";
    m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel",
          py::arg("x0"), py::arg("residual"), py::arg("gamma"), py::arg("beta_"),
          py::arg("rowscale_"), py::arg("colscale_"), py::arg("x0_subset_"), py::arg("z_subset_"),
          py::arg("dropout_p"), py::arg("epsilon"), py::arg("rowscale_const"), py::arg("z_numrows"),
          py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
    m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel",
          py::arg("dz"), py::arg("dx_"), py::arg("x"), py::arg("x0_"), py::arg("dmask_"), py::arg("mu"),
          py::arg("rsigma"), py::arg("gamma"), py::arg("rowscale_"), py::arg("colscale_"),
          py::arg("x0_subset_"), py::arg("z_subset_"), py::arg("dropout_p"), py::arg("rowscale_const"),
          py::arg("x0_numrows"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
    m.def("dropout_add_ln_parallel_residual_fwd", &dropout_add_ln_parallel_residual_fwd,
          "Run Dropout + Add + LayerNorm parallel residual forward kernel",
          py::arg("x0"), py::arg("x1_"), py::arg("residual"), py::arg("gamma0"), py::arg("beta0_"),
          py::arg("gamma1_"), py::arg("beta1_"), py::arg("dropout_p"), py::arg("epsilon"),
          py::arg("gen_"), py::arg("residual_in_fp32")=false, py::arg("is_rms_norm")=false);
    m.def("dropout_add_ln_parallel_residual_bwd", &dropout_add_ln_parallel_residual_bwd,
          "Run Dropout + Add + LayerNorm parallel residual backward kernel",
          py::arg("dz0"), py::arg("dz1_"), py::arg("dx_"), py::arg("x"), py::arg("dmask0_"),
          py::arg("dmask1_"), py::arg("mu"), py::arg("rsigma"), py::arg("gamma0"), py::arg("gamma1_"),
          py::arg("dropout_p"), py::arg("has_x1"), py::arg("has_residual"), py::arg("is_rms_norm")=false);
}

================================================
FILE: csrc/layer_norm/ln_bwd_1024.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_1280.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
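The following explains how `hidden_size` reaches one of the `ln_bwd_<N>.cu` registrations. Each entry point in `ln_api.cpp` shown above first requires `hidden_size % 8 == 0 && hidden_size <= 8192`. It then rounds `hidden_size` up with its local `round_multiple` lambda, using a step of 256 up to 1536, 512 up to 3072, and 1024 above that. Finally it asks `get_bwd_launcher`, `get_parallel_bwd_launcher`, or `get_parallel_fwd_launcher` for a kernel registered at that rounded width.

The standalone C++ sketch below reproduces only that arithmetic. `registered_hidden_size` and `is_supported_hidden_size` are hypothetical names, not functions in the repository, and the sketch does not perform the real registry lookup.

```cpp
#include <cassert>

// Same arithmetic as the round_multiple lambda in ln_api.cpp:
// round x up to the next multiple of m (x, m > 0).
constexpr int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

// Hypothetical helper (not in the repository): the registered width that the
// launcher lookup would request for a given hidden_size.
constexpr int registered_hidden_size(int hidden_size) {
    const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
    return round_multiple(hidden_size, multiple);
}

// Hypothetical helper: mirrors the TORCH_CHECK each entry point applies first.
constexpr bool is_supported_hidden_size(int hidden_size) {
    return hidden_size % 8 == 0 && hidden_size <= 8192;
}

// Compile-time spot checks against the ln_bwd_<N>.cu files in this directory.
static_assert(registered_hidden_size(1000) == 1024, "256-step region");
static_assert(registered_hidden_size(1536) == 1536, "upper edge of the 256-step region");
static_assert(registered_hidden_size(1600) == 2048, "512-step region");
static_assert(registered_hidden_size(3100) == 4096, "1024-step region");
static_assert(registered_hidden_size(8192) == 8192, "largest supported size");
```

Every rounded value lands on a size that has a backward file in this dump: 256 through 1536 in steps of 256, then 2048, 2560 and 3072, then 4096 through 8192 in steps of 1024. When `cols` is smaller than the registered width, the kernel appears to handle the partially filled tail through the `Is_even_cols || (it < num_valid_ldgs)` guard in the `gamma` load loop of `ln_bwd_kernel`.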
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_1536.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_2048.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
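In `ln_bwd_1536.cu`, `BYTES_PER_LDG = 16` is used only for the two registrations whose `ITYPE` is `fp32`. The 16-bit input types use 8. One plausible reading comes from the `static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW)` in `ln_bwd_kernel`: the row width must split evenly into per-thread vector loads.

The sketch below checks that divisibility for a given macro configuration. It assumes `THREADS_PER_ROW = WARPS_N * 32` and `ELTS_PER_LDG = BYTES_PER_LDG / sizeof(ITYPE)`. Neither definition appears in this excerpt, which shows only the alias `NUM_ELTS = Ktraits::ELTS_PER_LDG`, but both are consistent with every registration listed here. `BwdConfig` and `ldgs_per_thread` are hypothetical names.

```cpp
#include <cassert>

// Assumed warp width on NVIDIA GPUs.
constexpr int kThreadsPerWarp = 32;

// Hypothetical mirror of the REGISTER_BWD_LAUNCHER arguments that matter here.
struct BwdConfig {
    int hidden_size;    // HIDDEN_SIZE (COLS in the kernel)
    int itype_bytes;    // sizeof(ITYPE): 4 for fp32, 2 for fp16/bf16
    int ctas_per_row;   // CTAS_PER_ROW
    int warps_n;        // WARPS_N
    int bytes_per_ldg;  // BYTES_PER_LDG
};

// Returns LDGS such that
//   COLS == THREADS_PER_ROW * LDGS * ELTS_PER_LDG * CTAS_PER_ROW,
// or -1 when no integer LDGS exists (under the stated assumptions the kernel's
// static_assert could not hold for that configuration).
constexpr int ldgs_per_thread(BwdConfig c) {
    const int elts_per_ldg = c.bytes_per_ldg / c.itype_bytes;           // assumption
    const int threads_per_row = c.warps_n * kThreadsPerWarp;            // assumption
    const int cols_per_pass = threads_per_row * elts_per_ldg * c.ctas_per_row;
    return c.hidden_size % cols_per_pass == 0 ? c.hidden_size / cols_per_pass : -1;
}

// ln_bwd_1536.cu: fp32 inputs use 16-byte loads, 16-bit inputs use 8-byte loads.
static_assert(ldgs_per_thread({1536, 4, 1, 4, 16}) == 3, "fp32 ITYPE, 16-byte loads");
static_assert(ldgs_per_thread({1536, 2, 1, 4, 8}) == 3, "fp16/bf16 ITYPE, 8-byte loads");
// A 16-byte load with a 16-bit ITYPE would cover 1024 columns per pass,
// which does not divide 1536.
static_assert(ldgs_per_thread({1536, 2, 1, 4, 16}) == -1, "would not tile evenly");
```

Under the same assumptions, this check also accounts for the 8-byte entries in `ln_bwd_2560.cu` and `ln_bwd_7168.cu`. With 16-byte loads and a 16-bit input, one pass would cover 1024 columns (WARPS_N = 4) or 2048 columns (WARPS_N = 8). Neither 2560 nor 7168 is a multiple of the corresponding pass width.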
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_256.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_2560.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_3072.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_4096.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_512.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_5120.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_6144.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_7168.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 8, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_768.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_8192.cu
================================================
#include "ln_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

================================================
FILE: csrc/layer_norm/ln_bwd_kernels.cuh
================================================
#pragma once

#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"

namespace layer_norm {

template
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {

    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
    enum { WARPS_M = Ktraits::WARPS_M };
    enum { WARPS_N = Ktraits::WARPS_N };
    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
    enum { COLS = Ktraits::COLS };
    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
    enum { LDGS = Ktraits::LDGS };
    enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
    enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

    using input_t = typename Ktraits::input_t;
    using compute_t = typename Ktraits::compute_t;
    using index_t = typename Ktraits::index_t;
    using mask_t = typename Ktraits::mask_t;
    using Ivec = typename Ktraits::Ivec;
    using Rvec = typename Ktraits::Rvec;
    using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec; using Cvec = typename Ktraits::Cvec; using Mvec = typename Ktraits::Mvec; using Reducer = typename Ktraits::Reducer; using reduce_t = typename Reducer::Type; extern __shared__ char smem_[]; const bool has_residual = params.dresidual != nullptr; const bool prenorm = params.dx != nullptr; const index_t tidx = threadIdx.x; const index_t bidn = blockIdx.x % CTAS_PER_ROW; const index_t bidm = blockIdx.x / CTAS_PER_ROW; const index_t lane = tidx % THREADS_PER_WARP; const index_t warp = tidx / THREADS_PER_WARP; const index_t warp_m = warp / Ktraits::WARPS_N; const index_t warp_n = warp % Ktraits::WARPS_N; const index_t tid_r = warp_n * THREADS_PER_WARP + lane; const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); const input_t *rowscale = static_cast(params.rowscale); const index_t *x0_subset = static_cast(params.x0_subset); const index_t *z_subset = static_cast(params.z_subset); Cvec dzy_sum[LDGS]; Cvec dz_sum[LDGS]; Cvec dcolscale_sum[LDGS]; memset(dzy_sum, 0, sizeof(dzy_sum)); memset(dz_sum, 0, sizeof(dz_sum)); if (Has_colscale) { memset(dcolscale_sum, 0, sizeof(dcolscale_sum)); } compute_t * smem_wgrad = reinterpret_cast(smem_); char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); Sum sum; const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG; Wvec gamma[LDGS]; Wvec colscale[LDGS]; index_t idx = c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { gamma[it].load_from(params.gamma, idx); if (Has_colscale) { colscale[it].load_from(params.colscale, idx); } idx += Ktraits::VEC_COLS_PER_LDG; } } // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the // last blocks with 
syncthreads! // grid stride over rows #pragma unroll 1 for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { const compute_t mu_r = static_cast(params.mu)[row]; const compute_t rs_r = static_cast(params.rs)[row]; const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const; const int row_z = !Has_subset ? row + 1 : z_subset[row]; const int row_x0 = !Has_subset ? row + 1 : x0_subset[row]; const bool load_dz = !Has_subset || row_z > 0; const bool save_dx0 = !Has_subset || row_x0 > 0; Mvec dmask[LDGS]; Rvec dx[LDGS]; compute_t dy[LDGS * NUM_ELTS]; compute_t y[LDGS * NUM_ELTS]; compute_t mdy_local = 0.f; compute_t mdyy_local = 0.f; // If dz is not loaded, then dy should be 0 and we don't care about the value of y. if (load_dz) { index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; index_t idx_z = !Has_subset ? idx_x : (load_dz ? (row_z - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { Rvec x; Ovec dz; dz.load_from(params.dz, !Has_subset ? idx_x : idx_z); if (prenorm) { dx[it].load_from(params.dx, idx_x); } x.load_from(params.x, idx_x); if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); } idx_x += Ktraits::VEC_COLS_PER_LDG; idx_z += Ktraits::VEC_COLS_PER_LDG; idx_x0 += Ktraits::VEC_COLS_PER_LDG; #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { compute_t x_tmp = x.data.elt[jt]; compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? 
mu_r : 0.f)); compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]); compute_t dz_tmp = dz.data.elt[jt]; mdy_local += dy_tmp; mdyy_local += dy_tmp * y_tmp; dy[it * NUM_ELTS + jt] = dy_tmp; y[it * NUM_ELTS + jt] = y_tmp; dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; dz_sum[it].data.elt[jt] += dz_tmp; } } } } else { index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { if (prenorm) { dx[it].load_from(params.dx, idx_x); } if (Is_dropout) { dmask[it].load_from(params.dmask, !Has_subset ? idx_x : idx_x0); } idx_x += Ktraits::VEC_COLS_PER_LDG; idx_x0 += Ktraits::VEC_COLS_PER_LDG; } } } reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); mdy_local = layer_norm::Get<0>::of(result) * params.inverse_cols; mdyy_local = layer_norm::Get<1>::of(result) * params.inverse_cols; index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; index_t idx_x0 = !Has_subset ? idx_x : (save_dx0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { Ivec dx0; Rvec dresidual; Ivec x0; if (Has_colscale && save_dx0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); } #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { compute_t dx_tmp_res; if (load_dz) { compute_t dy_tmp = dy[it * NUM_ELTS + jt]; compute_t y_tmp = y[it * NUM_ELTS + jt]; compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f))); dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; } else { dx_tmp_res = prenorm ? 
compute_t(dx[it].data.elt[jt]) : 0.f; } if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; } if (save_dx0) { compute_t dx0_tmp_res = dx_tmp_res * rowscale_val; if (Is_dropout) { dx0_tmp_res *= params.dropout_scale; if (Has_colscale) { dcolscale_sum[it].data.elt[jt] += dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(x0.data.elt[jt]) : 0.f; dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * compute_t(colscale[it].data.elt[jt]) : 0.f; } else { dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res : 0.f; } } else { if (Has_colscale) { dcolscale_sum[it].data.elt[jt] += dx0_tmp_res * compute_t(x0.data.elt[jt]); dx0.data.elt[jt] = dx0_tmp_res * compute_t(colscale[it].data.elt[jt]); } else { dx0.data.elt[jt] = dx0_tmp_res; } } } } if (has_residual) { dresidual.store_to(params.dresidual, idx_x); } if (save_dx0) { dx0.store_to(params.dx0, !Has_subset ? idx_x : idx_x0); } idx_x += Ktraits::VEC_COLS_PER_LDG; idx_x0 += Ktraits::VEC_COLS_PER_LDG; } } } // end: grid stride loop if( WARPS_M == 1 ) { idx = r * params.cols / Ktraits::ELTS_PER_LDG + c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { dz_sum[it].store_to(params.dbeta_part, idx); dzy_sum[it].store_to(params.dgamma_part, idx); if (Has_colscale) { dcolscale_sum[it].store_to(params.dcolscale_part, idx); } idx += Ktraits::VEC_COLS_PER_LDG; } } } else { static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); // Finalize reduction of part dgamma and dbeta for this CTA // by reducing over the rows held across the WARPS_M warps // Assumption: blockSize divides hidden size. 
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll for( int it = 0; it < LDGS; it++ ) { dz_sum[it].store_to(smem_wgrad, idx); idx += THREADS_PER_ROW; } __syncthreads(); compute_t cta_dz_sum[NUM_RES]; memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); for( int it = 0; it < ROWS_PER_CTA; it++ ) { for( int jt = 0; jt < NUM_RES; jt++ ) { cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; } } __syncthreads(); idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll for( int it = 0; it < LDGS; it++ ) { dzy_sum[it].store_to(smem_wgrad, idx); idx += THREADS_PER_ROW; } __syncthreads(); compute_t cta_dzy_sum[NUM_RES]; memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); for( int it = 0; it < ROWS_PER_CTA; it++ ) { for( int jt = 0; jt < NUM_RES; jt++ ) { cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; } } compute_t cta_dcolscale_sum[NUM_RES]; if (Has_colscale) { __syncthreads(); idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll for( int it = 0; it < LDGS; it++ ) { dcolscale_sum[it].store_to(smem_wgrad, idx); idx += THREADS_PER_ROW; } __syncthreads(); memset(cta_dcolscale_sum, 0, sizeof(compute_t) * NUM_RES); for( int it = 0; it < ROWS_PER_CTA; it++ ) { for( int jt = 0; jt < NUM_RES; jt++ ) { cta_dcolscale_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; } } } const index_t num_valid_writes = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA; compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * params.cols + tidx; compute_t *dbeta_part = static_cast(params.dbeta_part) + bidm * params.cols + tidx; compute_t *dcolscale_part = Has_colscale ? 
static_cast(params.dcolscale_part) + bidm * params.cols + tidx : nullptr; for( int jt = 0; jt < NUM_RES; jt++ ) { if (Is_even_cols || (jt < num_valid_writes)) { *dgamma_part = cta_dzy_sum[jt]; dgamma_part += Ktraits::THREADS_PER_CTA; *dbeta_part = cta_dz_sum[jt]; dbeta_part += Ktraits::THREADS_PER_CTA; if (Has_colscale) { *dcolscale_part = cta_dcolscale_sum[jt]; dcolscale_part += Ktraits::THREADS_PER_CTA; } } } } } template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_kernel(BwdParams params) { using compute_t = typename Kernel_traits::compute_t; using weight_t = typename Kernel_traits::weight_t; using index_t = typename Kernel_traits::index_t; using Reducer = typename Kernel_traits::Reducer; using reduce_t = typename Reducer::Type; Sum sum; enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; constexpr uint32_t bidm = 0; const uint32_t bidn = blockIdx.x; const uint32_t tidx = threadIdx.x; const uint32_t warp = tidx / THREADS_PER_WARP; const uint32_t lane = tidx % THREADS_PER_WARP; Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); const uint32_t c = bidn * THREADS_PER_WARP + lane; const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { // Each thread sums over NUM_ELT columns. 
Vec dbeta_local, dgamma_local, dcolscale_local; memset(&dgamma_local, 0, sizeof(dgamma_local)); memset(&dbeta_local, 0, sizeof(dbeta_local)); if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); } if (Is_even_cols || col < params.cols) { for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { index_t idx = row * params.cols + col; Vec dbeta_part, dgamma_part, dcolscale_part; dbeta_part.load_from(params.dbeta_part, idx); dgamma_part.load_from(params.dgamma_part, idx); if (Has_colscale) { dcolscale_part.load_from(params.dcolscale_part, idx); } #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; if (Has_colscale) { dcolscale_local.data.elt[it] += dcolscale_part.data.elt[it]; } } } } void * smem_gamma = smem_; void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; void * smem_colscale = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; const int write_row = warp; const int write_col = lane ^ write_row; const int write_idx = write_row * THREADS_PER_WARP + write_col; dgamma_local.store_to(smem_gamma, write_idx); dbeta_local.store_to(smem_beta, write_idx); if (Has_colscale) { dcolscale_local.store_to(smem_colscale, write_idx); } __syncthreads(); // It would probably be safe to reuse the first row of smem_beta and smem_gamma void * smem_gamma_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE]; void * smem_beta_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; void * smem_colscale_out = &smem_[Kernel_traits::NUM_FACTORS * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT]; // More than one iter iff ROWS_PER_CTA < 32.
for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { const int read_row = lane; const int read_col = w ^ read_row; const int read_idx = read_row * THREADS_PER_WARP + read_col; memset(&dbeta_local, 0, sizeof(dbeta_local)); memset(&dgamma_local, 0, sizeof(dgamma_local)); if (Has_colscale) { memset(&dcolscale_local, 0, sizeof(dcolscale_local)); } // Load beta and gamma transposed if(read_row < Kernel_traits::ROWS_PER_CTA){ dbeta_local.load_from(smem_beta, read_idx); dgamma_local.load_from(smem_gamma, read_idx); if (Has_colscale) { dcolscale_local.load_from(smem_colscale, read_idx); } } // Call reducer on the loaded value(s) and convert. #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { compute_t b_i = dbeta_local.data.elt[it]; compute_t g_i = dgamma_local.data.elt[it]; b_i = reducer.allreduce(b_i, sum); g_i = reducer.allreduce(g_i, sum); dgamma_local.data.elt[it] = g_i; dbeta_local.data.elt[it] = b_i; if (Has_colscale) { compute_t cs_i = dcolscale_local.data.elt[it]; cs_i = reducer.allreduce(cs_i, sum); dcolscale_local.data.elt[it] = cs_i; } } // Leader stores the result at the current column. if(lane == 0){ dgamma_local.store_to(smem_gamma_out, w); dbeta_local.store_to(smem_beta_out, w); if (Has_colscale) { dcolscale_local.store_to(smem_colscale_out, w); } } } // All writes done. __syncthreads(); // Pack and store: 2-wide stores with half the threads. 
if (Is_even_cols || col_out * 2 < params.cols) { if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { using src_t = typename TypeToVec2::Type; using dst_t = typename TypeToVec2::Type; Vec dbeta_vec2, dgamma_vec2, dcolscale_vec2; Vec dbeta_out2, dgamma_out2, dcolscale_out2; dgamma_vec2.load_from(smem_gamma_out, lane); dbeta_vec2.load_from(smem_beta_out, lane); if (Has_colscale) { dcolscale_vec2.load_from(smem_colscale_out, lane); } #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); dbeta_out2.data.elt[it] = Converter::convert(dbeta_vec2.data.elt[it]); if (Has_colscale) { dcolscale_out2.data.elt[it] = Converter::convert(dcolscale_vec2.data.elt[it]); } } dgamma_out2.store_to(params.dgamma, col_out); dbeta_out2.store_to(params.dbeta, col_out); if (Has_colscale) { dcolscale_out2.store_to(params.dcolscale, col_out); } } } } } } // namespace layer_norm using namespace layer_norm; template< typename weight_t, typename input_t, typename residual_t, typename output_t, typename compute_t, typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL > void launch_(LaunchParams &launch_params, const bool configure_params){ using Kernel_traits = Kernel_traits; bool is_dropout = launch_params.params.dropout_keep_p < 1.f; bool has_colscale = launch_params.params.colscale != nullptr; bool has_subset = launch_params.params.x0_subset != nullptr; bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { BOOL_SWITCH(has_subset, HasSubsetConst, [&] { BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { auto kernel = &ln_bwd_kernel; if( configure_params ) { int ctas_per_sm; CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); 
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; launch_params.barrier_size = 0; launch_params.workspace_bytes = 0; if(Kernel_traits::CTAS_PER_ROW > 1) { launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::reduce_t) * 2; } return; } if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); } auto stream = launch_params.stream; auto ctas_per_col = launch_params.params.ctas_per_col; if( Kernel_traits::CTAS_PER_ROW == 1 ) { kernel<<>>(launch_params.params); } else { dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = (void *)&launch_params.params; cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream); } using Kernel_traits_f = layer_norm::Kernel_traits_finalize; auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; kernel_f<<>>(launch_params.params); }); }); }); }); } ================================================ FILE: csrc/layer_norm/ln_fwd_1024.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_1280.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_1536.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_2048.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_256.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_2560.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_3072.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_4096.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_512.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_5120.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_6144.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_7168.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_768.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_8192.cu ================================================ #include "ln_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); ================================================ FILE: csrc/layer_norm/ln_fwd_kernels.cuh ================================================ #pragma once #ifdef OLD_GENERATOR_PATH #include #else #include #endif #include // For at::cuda::philox::unpack #include #include "ln.h" #include "ln_utils.cuh" #include "ln_kernel_traits.h" #include "static_switch.h" namespace layer_norm { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_kernel(FwdParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_M = Ktraits::WARPS_M }; enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; using input_t = typename Ktraits::input_t; using residual_t = typename Ktraits::residual_t; using output_t = typename Ktraits::output_t; using index_t = typename Ktraits::index_t; using compute_t = typename Ktraits::compute_t; using mask_t = typename Ktraits::mask_t; using Ivec
        = typename Ktraits::Ivec;
    using Rvec = typename Ktraits::Rvec;
    using Ovec = typename Ktraits::Ovec;
    using Wvec = typename Ktraits::Wvec;
    using Cvec = typename Ktraits::Cvec;
    using Mvec = typename Ktraits::Mvec;

    using Stats = typename Ktraits::Stats;
    using stats_t = typename Stats::stats_t;

    const bool has_residual = params.residual != nullptr;
    const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same::value);

    extern __shared__ char smem_[];

    const index_t tidx = threadIdx.x;
    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
    const index_t lane = tidx % THREADS_PER_WARP;
    const index_t warp = tidx / THREADS_PER_WARP;
    const index_t warp_m = warp / WARPS_N;
    const index_t warp_n = warp % WARPS_N;

    const index_t r = bidm * ROWS_PER_CTA + warp_m;
    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

    Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);

    compute_t *mu_ptr = static_cast(params.mu);
    compute_t *rs_ptr = static_cast(params.rs);

    const input_t *rowscale = static_cast(params.rowscale);
    const index_t *x0_subset = static_cast(params.x0_subset);
    const index_t *z_subset = static_cast(params.z_subset);

    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
    curandStatePhilox4_32_10_t state;
    if (Is_dropout) {
        auto seeds = at::cuda::philox::unpack(params.philox_args);
        const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
        curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
    }

    const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;

    Wvec gamma[LDGS];
    Wvec beta[LDGS];
    Wvec colscale[LDGS];
    index_t idx = c;
    #pragma unroll
    for( int it = 0; it < LDGS; it++ ) {
        if (Is_even_cols || (it < num_valid_ldgs)) {
            gamma[it].load_from(params.gamma, idx);
            if (params.beta != nullptr) {
                beta[it].load_from(params.beta, idx);
            }
            else {
                beta[it].zero_();
            }
            if (Has_colscale) { colscale[it].load_from(params.colscale, idx); }
            idx += VEC_COLS_PER_LDG;
        }
    }

    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
        const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const;
        const int row_x0 = !Has_subset ? row + 1 : x0_subset[row];
        const int row_z = !Has_subset ? row + 1 : z_subset[row];
        const bool load_x0 = !Has_subset || row_x0 > 0;
        index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c;
        index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0);
        compute_t xf[LDGS * NUM_ELTS];
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            if (Is_even_cols || (it < num_valid_ldgs)) {
                Ivec x0;
                Rvec residual;
                Rvec x;
                Mvec dmask;
                if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); }
                if (has_residual) { residual.load_from(params.residual, idx_x); }
                #pragma unroll
                for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                    // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
                    // the more efficient curand_uniform4.
                    compute_t x_ij;
                    if (load_x0) {
                        mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
                        if (Is_dropout) { dmask.data.elt[jt] = keep; }
                        compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
                        x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
                        if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); }
                        x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
                    } else {
                        x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f;
                    }
                    if (save_x) { x.data.elt[jt] = x_ij; }
                    xf[it * NUM_ELTS + jt] = x_ij;
                }
                if (save_x) { x.store_to(params.x, idx_x); }
                if (Is_dropout && load_x0) {
                    dmask.store_to(params.dmask, !Has_subset ?
                                   idx_x : idx_x0);
                }
                idx_x += VEC_COLS_PER_LDG;
                idx_x0 += VEC_COLS_PER_LDG;
            }
        }

        static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
        const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
        const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
        const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
        auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
            // Need to convert to int, otherwise the subtraction will wrap around.
            const index_t valid_partial_vecs_in_warp =
                std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
                         int(THREADS_PER_WARP));
            return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
        };
        stats_t s = stats.template compute(
            xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS
        );

        compute_t mu = layer_norm::Get<0>::of(s);
        compute_t m2 = layer_norm::Get<1>::of(s);

        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            mu_ptr[row] = mu;
        }

        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));

        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            rs_ptr[row] = rs;
        }

        const bool save_z = !Has_subset || row_z > 0;
        if (save_z) {
            index_t idx_z = (!Has_subset ? row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c;
            #pragma unroll
            for( int it = 0; it < LDGS; it++ ) {
                if (Is_even_cols || (it < num_valid_ldgs)) {
                    Ovec z;
                    #pragma unroll
                    for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                        compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ?
                            mu : 0.f)));
                        compute_t g_ij = gamma[it].data.elt[jt];
                        compute_t b_ij = beta[it].data.elt[jt];
                        z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
                    }
                    z.store_to(params.z, idx_z);
                    idx_z += VEC_COLS_PER_LDG;
                }
            }
        }
    }
}

} // namespace layer_norm

using namespace layer_norm;

template<
    typename weight_t,
    typename input_t,
    typename residual_t,
    typename output_t,
    typename compute_t,
    typename index_t,
    int HIDDEN_SIZE,
    int CTAS_PER_ROW,
    int WARPS_M,
    int WARPS_N,
    int BYTES_PER_LDG
>
void launch_(LaunchParams &launch_params, const bool configure_params){

    using Kernel_traits = Kernel_traits;
    bool has_colscale = launch_params.params.colscale != nullptr;
    bool has_subset = launch_params.params.x0_subset != nullptr;
    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
    BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
        BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
            BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
                BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
                    auto kernel = &ln_fwd_kernel;
                    if( configure_params ) {
                        int ctas_per_sm;
                        CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                        const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
                        launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
                        launch_params.barrier_size = 0;
                        launch_params.workspace_bytes = 0;
                        if(Kernel_traits::CTAS_PER_ROW > 1) {
                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                                          * Kernel_traits::WARPS_M
                                                          * Kernel_traits::CTAS_PER_ROW
                                                          * sizeof(typename Kernel_traits::Stats::stats_t)
                                                          * 2;
                        }
                        return;
                    }

                    if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
                    }
                    auto stream = launch_params.stream;
                    auto ctas_per_col = launch_params.params.ctas_per_col;

                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
                        kernel<<>>(launch_params.params);
                    } else {
                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                        dim3 block(Kernel_traits::THREADS_PER_CTA);
                        void *params_ = (void *)&launch_params.params;
                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
                    }
                });
            });
        });
    });
}

================================================
FILE: csrc/layer_norm/ln_kernel_traits.h
================================================
#pragma once

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace layer_norm {

template<
    uint32_t HIDDEN_SIZE_,
    typename weight_t_,
    typename input_t_,
    typename residual_t_,
    typename output_t_,
    typename compute_t_,
    typename index_t_,
    uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {

    using weight_t = weight_t_;
    using input_t = input_t_;
    using residual_t = residual_t_;
    using output_t = output_t_;
    using compute_t = compute_t_;
    using index_t = index_t_;

    enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
    enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
    enum { THREADS_PER_WARP = 32 };

};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    uint32_t HIDDEN_SIZE_,
    typename weight_t_,
    typename input_t_,
    typename residual_t_,
    typename output_t_,
    typename compute_t_,
    typename index_t_,
    bool Has_colscale,
    uint32_t THREADS_PER_CTA_,
    uint32_t BYTES_PER_LDG_,
    typename Base = Kernel_traits_base
>
struct Kernel_traits_finalize : public Base {
    enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
    static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
    // Bytes per global load from the input.
    enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
    // Number of elements fetched by a global load.
    enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
    // Bytes per global store of the weights.
    enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
    static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
    static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
    // The total number of BYTES_PER_LDG-wide words in a hidden vector.
    enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
    static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));

    // Shared memory size to transpose the CTA result.
    enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
    // Shared memory size to coalesce the CTA result.
    enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
    // Shared memory requirement per CTA.
    static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2;
    enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT };

    // The type of the reducer.
    using Reducer = layer_norm::Reducer;

    // Condition for the whole CTA to participate in syncthreads.
    static_assert(COLS % Base::THREADS_PER_WARP == 0);
    enum { CTAS = COLS / Base::THREADS_PER_WARP };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    typename weight_t_,
    typename input_t_,
    typename residual_t_,
    typename output_t_,
    typename compute_t_,
    typename index_t_,
    uint32_t HIDDEN_SIZE_,
    uint32_t CTAS_PER_ROW_,
    uint32_t WARPS_M_,
    uint32_t WARPS_N_,
    uint32_t BYTES_PER_LDG_ = 16,
    typename Base = Kernel_traits_base<
        HIDDEN_SIZE_,
        weight_t_,
        input_t_,
        residual_t_,
        output_t_,
        compute_t_,
        index_t_,
        WARPS_M_*WARPS_N_*THREADS_PER_WARP
        >
>
struct Kernel_traits : public Base {

    using input_t = typename Base::input_t;
    using residual_t = typename Base::residual_t;
    using weight_t = typename Base::weight_t;
    using compute_t = typename Base::compute_t;
    using output_t = typename Base::output_t;
    using index_t = typename Base::index_t;
    // using mask_t = unsigned char;
    using mask_t = bool;

    enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
    enum { WARPS_M = WARPS_M_ };
    enum { WARPS_N = WARPS_N_ };
    enum { COLS = HIDDEN_SIZE_ };
    enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
    enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
    enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };

    enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
    enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
    enum { ROWS_PER_CTA = WARPS_M };

    enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
    enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
    // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
    enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ?
        0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
    static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);

    using reduce_t = typename layer_norm::TypeToVec2::Type;
    using Reducer = layer_norm::Reducer;

    enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
    enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };

    using Ivec = layer_norm::Vec;
    using Rvec = layer_norm::Vec;
    using Ovec = layer_norm::Vec;
    using Wvec = layer_norm::Vec;
    using Cvec = layer_norm::Vec;
    using Mvec = layer_norm::Vec;
    enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };

    // Assume that each thread can handle the same number of elements in the output and weights as in the input.
    static_assert(sizeof(input_t) == sizeof(output_t));
    static_assert(sizeof(input_t) <= sizeof(residual_t));
    // The number of columns fetched per load from input: one per thread.
    enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
    // The total number of vectorized loads/stores per hidden vector.
    enum { VEC_COLS = COLS / ELTS_PER_LDG };
    // The number of loads per thread for the input.
    enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
    static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
    //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");

    using Stats = layer_norm::Stats;
    enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace layer_norm

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_1024.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_1280.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_1536.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_2048.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_256.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_2560.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4,  8, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_3072.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_4096.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

// Use 8 warps otherwise there's a lot of register spilling

REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_512.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_5120.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

// Use 8 warps otherwise there's a lot of register spilling

REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_6144.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_7168.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 8,  8, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 8,  8, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_768.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_bwd_8192.cu
================================================
#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);

================================================
FILE: csrc/layer_norm/ln_parallel_fwd_1024.cu
================================================
#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register.
// Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);

================================================
FILE: csrc/layer_norm/ln_parallel_fwd_1280.cu
================================================
#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_1536.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_2048.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_256.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_2560.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_3072.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_4096.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_512.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_5120.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_6144.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_7168.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_768.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_fwd_8192.cu ================================================ #include "ln_parallel_residual_fwd_kernels.cuh" // Create forward launch function and register.
Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16); REGISTER_PARALLEL_FWD_LAUNCHER( 8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16); ================================================ FILE: csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh ================================================ #pragma once #include "ln.h" #include "ln_utils.cuh" #include "ln_kernel_traits.h" #include "static_switch.h" #include "ln_bwd_kernels.cuh" namespace layer_norm { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_parallel_residual_bwd_kernel(layer_norm::BwdParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_N = Ktraits::WARPS_N }; enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; enum { COLS = Ktraits::COLS }; enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; using input_t = typename Ktraits::input_t; using compute_t = typename Ktraits::compute_t; using index_t = typename Ktraits::index_t; using mask_t = typename Ktraits::mask_t; using
Ivec = typename Ktraits::Ivec; using Rvec = typename Ktraits::Rvec; using Ovec = typename Ktraits::Ovec; using Wvec = typename Ktraits::Wvec; using Cvec = typename Ktraits::Cvec; using Mvec = typename Ktraits::Mvec; using Reducer = typename Ktraits::Reducer; using reduce_t = typename Reducer::Type; extern __shared__ char smem_[]; const bool has_residual = params.dresidual != nullptr; const bool has_x1 = params.dx1 != nullptr; const bool prenorm = params.dx != nullptr; const index_t tidx = threadIdx.x; const index_t bidn = blockIdx.x % CTAS_PER_ROW; const index_t bidm = blockIdx.x / CTAS_PER_ROW; const index_t lane = tidx % THREADS_PER_WARP; const index_t warp = tidx / THREADS_PER_WARP; const index_t warp_m = warp / Ktraits::WARPS_N; const index_t warp_n = warp % Ktraits::WARPS_N; const index_t tid_r = warp_n * THREADS_PER_WARP + lane; const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); Cvec dz0y_sum[LDGS]; Cvec dz0_sum[LDGS]; Cvec dz1y_sum[LDGS]; Cvec dz1_sum[LDGS]; memset(dz0y_sum, 0, sizeof(dz0y_sum)); memset(dz0_sum, 0, sizeof(dz0_sum)); if (!Tied_norm) { memset(dz1y_sum, 0, sizeof(dz1y_sum)); memset(dz1_sum, 0, sizeof(dz1_sum)); } compute_t * smem_wgrad = reinterpret_cast(smem_); char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); Sum sum; const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG; Wvec gamma0[LDGS]; Wvec gamma1[LDGS]; index_t idx = c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { gamma0[it].load_from(params.gamma, idx); if (!Tied_norm) { gamma1[it].load_from(params.gamma1, idx); } idx += Ktraits::VEC_COLS_PER_LDG; } } // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the 
// last blocks with syncthreads! // grid stride over rows #pragma unroll 1 for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { const compute_t mu_r = static_cast(params.mu)[row]; const compute_t rs_r = static_cast(params.rs)[row]; Mvec dmask0[LDGS], dmask1[LDGS]; Rvec dx[LDGS]; compute_t dy[LDGS * NUM_ELTS]; compute_t y[LDGS * NUM_ELTS]; compute_t mdy_local = 0.f; compute_t mdyy_local = 0.f; index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { Rvec x; Ovec dz0, dz1; dz0.load_from(params.dz, idx); if (!Tied_norm) { dz1.load_from(params.dz1, idx); } if (prenorm) { dx[it].load_from(params.dx, idx); } x.load_from(params.x, idx); if (Is_dropout) { dmask0[it].load_from(params.dmask, idx); if (has_x1) { dmask1[it].load_from(params.dmask1, idx); } } idx += Ktraits::VEC_COLS_PER_LDG; #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { compute_t x_tmp = x.data.elt[jt]; compute_t y_tmp = rs_r * (x_tmp - (!params.is_rms_norm ? 
mu_r : 0.f)); compute_t dy_tmp = compute_t(gamma0[it].data.elt[jt]) * compute_t(dz0.data.elt[jt]); if (!Tied_norm) { dy_tmp += compute_t(gamma1[it].data.elt[jt]) * compute_t(dz1.data.elt[jt]); } compute_t dz0_tmp = dz0.data.elt[jt]; compute_t dz1_tmp; if (!Tied_norm) { dz1_tmp = dz1.data.elt[jt]; } mdy_local += dy_tmp; mdyy_local += dy_tmp * y_tmp; dy[it * NUM_ELTS + jt] = dy_tmp; y[it * NUM_ELTS + jt] = y_tmp; dz0y_sum[it].data.elt[jt] += dz0_tmp * y_tmp; dz0_sum[it].data.elt[jt] += dz0_tmp; if (!Tied_norm) { dz1y_sum[it].data.elt[jt] += dz1_tmp * y_tmp; dz1_sum[it].data.elt[jt] += dz1_tmp; } } } } reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); mdy_local = layer_norm::Get<0>::of(result) * params.inverse_cols; mdyy_local = layer_norm::Get<1>::of(result) * params.inverse_cols; idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { Ivec dx0, dx1; Rvec dresidual; #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { compute_t dx_tmp_res; compute_t dy_tmp = dy[it * NUM_ELTS + jt]; compute_t y_tmp = y[it * NUM_ELTS + jt]; compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + (!params.is_rms_norm ? mdy_local : 0.f))); dx_tmp_res = prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; if (has_residual) { dresidual.data.elt[jt] = dx_tmp_res; } if (Is_dropout) { dx0.data.elt[jt] = dmask0[it].data.elt[jt] ? dx_tmp_res * params.dropout_scale : 0.f; if (has_x1) { dx1.data.elt[jt] = dmask1[it].data.elt[jt] ? 
dx_tmp_res * params.dropout_scale : 0.f; } } else { dx0.data.elt[jt] = dx_tmp_res; if (has_x1) { dx1.data.elt[jt] = dx_tmp_res; } } } if (has_residual) { dresidual.store_to(params.dresidual, idx); } dx0.store_to(params.dx0, idx); if (has_x1) { dx1.store_to(params.dx1, idx); } idx += Ktraits::VEC_COLS_PER_LDG; } } } // end: grid stride loop if( WARPS_M == 1 ) { idx = r * params.cols / Ktraits::ELTS_PER_LDG + c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { dz0_sum[it].store_to(params.dbeta_part, idx); dz0y_sum[it].store_to(params.dgamma_part, idx); if (!Tied_norm) { dz1_sum[it].store_to(params.dbeta1_part, idx); dz1y_sum[it].store_to(params.dgamma1_part, idx); } idx += Ktraits::VEC_COLS_PER_LDG; } } } else { static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); // Finalize reduction of part dgamma and dbeta for this CTA // by reducing over the rows held across the WARPS_M warps // Assumption: blockSize divides hidden size. 
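// Worked example of the per-CTA reduction that follows (illustrative values only; this is
// not taken from a registered launcher config). Assume COLS = 1024 and THREADS_PER_CTA = 128
// (e.g. WARPS_M = 4, WARPS_N = 1, 32 threads per warp). Then NUM_RES = 1024 / 128 = 8.
// Each warp_m first stages its row of partial sums in smem_wgrad at element offset
// warp_m * COLS; thread tidx then accumulates the 8 columns tidx, tidx + 128, ..., tidx + 896
// over all ROWS_PER_CTA staged rows. This assumes ROWS_PER_CTA == WARPS_M, which the smem
// indexing (it * COLS for it < ROWS_PER_CTA) implies.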
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll for( int it = 0; it < LDGS; it++ ) { dz0_sum[it].store_to(smem_wgrad, idx); idx += THREADS_PER_ROW; } __syncthreads(); compute_t cta_dz0_sum[NUM_RES]; memset(cta_dz0_sum, 0, sizeof(compute_t) * NUM_RES); for( int it = 0; it < ROWS_PER_CTA; it++ ) { for( int jt = 0; jt < NUM_RES; jt++ ) { cta_dz0_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; } } __syncthreads(); idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll for( int it = 0; it < LDGS; it++ ) { dz0y_sum[it].store_to(smem_wgrad, idx); idx += THREADS_PER_ROW; } __syncthreads(); compute_t cta_dz0y_sum[NUM_RES]; memset(cta_dz0y_sum, 0, sizeof(compute_t) * NUM_RES); for( int it = 0; it < ROWS_PER_CTA; it++ ) { for( int jt = 0; jt < NUM_RES; jt++ ) { cta_dz0y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; } } compute_t cta_dz1_sum[NUM_RES], cta_dz1y_sum[NUM_RES]; if (!Tied_norm) { __syncthreads(); idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll for( int it = 0; it < LDGS; it++ ) { dz1_sum[it].store_to(smem_wgrad, idx); idx += THREADS_PER_ROW; } __syncthreads(); memset(cta_dz1_sum, 0, sizeof(compute_t) * NUM_RES); for( int it = 0; it < ROWS_PER_CTA; it++ ) { for( int jt = 0; jt < NUM_RES; jt++ ) { cta_dz1_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; } } __syncthreads(); idx = warp_m * Ktraits::VEC_COLS + tid_r; #pragma unroll for( int it = 0; it < LDGS; it++ ) { dz1y_sum[it].store_to(smem_wgrad, idx); idx += THREADS_PER_ROW; } __syncthreads(); memset(cta_dz1y_sum, 0, sizeof(compute_t) * NUM_RES); for( int it = 0; it < ROWS_PER_CTA; it++ ) { for( int jt = 0; jt < NUM_RES; jt++ ) { cta_dz1y_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; } } } const index_t num_valid_writes = (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / 
Ktraits::THREADS_PER_CTA; compute_t *dgamma0_part = static_cast(params.dgamma_part) + bidm * params.cols + tidx; compute_t *dbeta0_part = static_cast(params.dbeta_part) + bidm * params.cols + tidx; compute_t *dgamma1_part = !Tied_norm ? static_cast(params.dgamma1_part) + bidm * params.cols + tidx : nullptr; compute_t *dbeta1_part = !Tied_norm ? static_cast(params.dbeta1_part) + bidm * params.cols + tidx : nullptr; for( int jt = 0; jt < NUM_RES; jt++ ) { if (Is_even_cols || (jt < num_valid_writes)) { *dgamma0_part = cta_dz0y_sum[jt]; dgamma0_part += Ktraits::THREADS_PER_CTA; *dbeta0_part = cta_dz0_sum[jt]; dbeta0_part += Ktraits::THREADS_PER_CTA; if (!Tied_norm) { *dgamma1_part = cta_dz1y_sum[jt]; dgamma1_part += Ktraits::THREADS_PER_CTA; *dbeta1_part = cta_dz1_sum[jt]; dbeta1_part += Ktraits::THREADS_PER_CTA; } } } } } template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_parallel_residual_bwd_finalize_kernel(BwdParams params) { using compute_t = typename Kernel_traits::compute_t; using weight_t = typename Kernel_traits::weight_t; using index_t = typename Kernel_traits::index_t; using Reducer = typename Kernel_traits::Reducer; using reduce_t = typename Reducer::Type; Sum sum; enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; // Multiplying by 2 since we have both gamma0 and gamma1 __shared__ char smem_[2 * Kernel_traits::SMEM_BYTES_PER_CTA]; constexpr uint32_t bidm = 0; const uint32_t bidn = blockIdx.x; const uint32_t tidx = threadIdx.x; const uint32_t warp = tidx / THREADS_PER_WARP; const uint32_t lane = tidx % THREADS_PER_WARP; Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); const uint32_t c = bidn * THREADS_PER_WARP + lane; const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { // 
Each thread sums over NUM_ELT columns. Vec dbeta0_local, dgamma0_local, dbeta1_local, dgamma1_local; memset(&dgamma0_local, 0, sizeof(dgamma0_local)); memset(&dbeta0_local, 0, sizeof(dbeta0_local)); memset(&dgamma1_local, 0, sizeof(dgamma1_local)); memset(&dbeta1_local, 0, sizeof(dbeta1_local)); if (Is_even_cols || col < params.cols) { for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { index_t idx = row * params.cols + col; Vec dbeta0_part, dgamma0_part, dbeta1_part, dgamma1_part; dbeta0_part.load_from(params.dbeta_part, idx); dgamma0_part.load_from(params.dgamma_part, idx); dbeta1_part.load_from(params.dbeta1_part, idx); dgamma1_part.load_from(params.dgamma1_part, idx); #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { dgamma0_local.data.elt[it] += dgamma0_part.data.elt[it]; dbeta0_local.data.elt[it] += dbeta0_part.data.elt[it]; dgamma1_local.data.elt[it] += dgamma1_part.data.elt[it]; dbeta1_local.data.elt[it] += dbeta1_part.data.elt[it]; } } } void * smem_gamma0 = smem_; void * smem_beta0 = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; void * smem_gamma1 = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; void * smem_beta1 = &smem_[3 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; const int write_row = warp; const int write_col = lane ^ write_row; const int write_idx = write_row * THREADS_PER_WARP + write_col; dgamma0_local.store_to(smem_gamma0, write_idx); dbeta0_local.store_to(smem_beta0, write_idx); dgamma1_local.store_to(smem_gamma1, write_idx); dbeta1_local.store_to(smem_beta1, write_idx); __syncthreads(); // It would probably be safe to reuse the first row of smem_beta0 and smem_gamma0 void * smem_gamma0_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; void * smem_beta0_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; void * smem_gamma1_out = &smem_[4 * Kernel_traits::SMEM_BYTES_TRANSPOSE + 2 * Kernel_traits::SMEM_BYTES_OUTPUT]; void * smem_beta1_out = &smem_[4 *
Kernel_traits::SMEM_BYTES_TRANSPOSE + 3 * Kernel_traits::SMEM_BYTES_OUTPUT]; // More than one iter iff ROWS_PER_CTA < 32. for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { const int read_row = lane; const int read_col = w ^ read_row; const int read_idx = read_row * THREADS_PER_WARP + read_col; memset(&dbeta0_local, 0, sizeof(dbeta0_local)); memset(&dgamma0_local, 0, sizeof(dgamma0_local)); memset(&dbeta1_local, 0, sizeof(dbeta1_local)); memset(&dgamma1_local, 0, sizeof(dgamma1_local)); // Load beta and gamma transposed if(read_row < Kernel_traits::ROWS_PER_CTA){ dbeta0_local.load_from(smem_beta0, read_idx); dgamma0_local.load_from(smem_gamma0, read_idx); dbeta1_local.load_from(smem_beta1, read_idx); dgamma1_local.load_from(smem_gamma1, read_idx); } // Call reducer on the loaded value(s) and convert. #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { compute_t b0_i = dbeta0_local.data.elt[it]; compute_t g0_i = dgamma0_local.data.elt[it]; compute_t b1_i = dbeta1_local.data.elt[it]; compute_t g1_i = dgamma1_local.data.elt[it]; b0_i = reducer.allreduce(b0_i, sum); g0_i = reducer.allreduce(g0_i, sum); b1_i = reducer.allreduce(b1_i, sum); g1_i = reducer.allreduce(g1_i, sum); dgamma0_local.data.elt[it] = g0_i; dbeta0_local.data.elt[it] = b0_i; dgamma1_local.data.elt[it] = g1_i; dbeta1_local.data.elt[it] = b1_i; } // Leader stores the result at the current column. if(lane == 0){ dgamma0_local.store_to(smem_gamma0_out, w); dbeta0_local.store_to(smem_beta0_out, w); dgamma1_local.store_to(smem_gamma1_out, w); dbeta1_local.store_to(smem_beta1_out, w); } } // All writes done. __syncthreads(); // Pack and store: 2-wide stores with half the threads. 
if (Is_even_cols || col_out * 2 < params.cols) { if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { using src_t = typename TypeToVec2::Type; using dst_t = typename TypeToVec2::Type; Vec dbeta0_vec2, dgamma0_vec2, dbeta1_vec2, dgamma1_vec2; Vec dbeta0_out2, dgamma0_out2, dbeta1_out2, dgamma1_out2; dgamma0_vec2.load_from(smem_gamma0_out, lane); dbeta0_vec2.load_from(smem_beta0_out, lane); dgamma1_vec2.load_from(smem_gamma1_out, lane); dbeta1_vec2.load_from(smem_beta1_out, lane); #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { dgamma0_out2.data.elt[it] = Converter::convert(dgamma0_vec2.data.elt[it]); dbeta0_out2.data.elt[it] = Converter::convert(dbeta0_vec2.data.elt[it]); dgamma1_out2.data.elt[it] = Converter::convert(dgamma1_vec2.data.elt[it]); dbeta1_out2.data.elt[it] = Converter::convert(dbeta1_vec2.data.elt[it]); } dgamma0_out2.store_to(params.dgamma, col_out); dbeta0_out2.store_to(params.dbeta, col_out); dgamma1_out2.store_to(params.dgamma1, col_out); dbeta1_out2.store_to(params.dbeta1, col_out); } } } } } // namespace layer_norm using namespace layer_norm; template< typename weight_t, typename input_t, typename residual_t, typename output_t, typename compute_t, typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL > void launch_parallel_residual_(LaunchParams &launch_params, const bool configure_params){ using Kernel_traits = Kernel_traits; bool is_dropout = launch_params.params.dropout_keep_p < 1.f; bool tied_norm = launch_params.params.gamma1 == nullptr; bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { BOOL_SWITCH(tied_norm, TiedNormConst, [&] { BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { auto kernel = &ln_parallel_residual_bwd_kernel; if( configure_params ) { int ctas_per_sm; CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 
Kernel_traits::SMEM_BYTES)); launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; launch_params.barrier_size = 0; launch_params.workspace_bytes = 0; if(Kernel_traits::CTAS_PER_ROW > 1) { launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::reduce_t) * 2; } return; } if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); } auto stream = launch_params.stream; auto ctas_per_col = launch_params.params.ctas_per_col; if( Kernel_traits::CTAS_PER_ROW == 1 ) { kernel<<>>(launch_params.params); } else { dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = (void *)&launch_params.params; cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream); } using Kernel_traits_f = layer_norm::Kernel_traits_finalize; auto kernel_f = !TiedNormConst ? 
&layer_norm::ln_parallel_residual_bwd_finalize_kernel : &layer_norm::ln_bwd_finalize_kernel; kernel_f<<>>(launch_params.params); }); }); }); } ================================================ FILE: csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh ================================================ #pragma once #ifdef OLD_GENERATOR_PATH #include #else #include #endif #include // For at::cuda::philox::unpack #include #include "ln.h" #include "ln_utils.cuh" #include "ln_kernel_traits.h" #include "static_switch.h" namespace layer_norm { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_parallel_residual_fwd_kernel(FwdParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_M = Ktraits::WARPS_M }; enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; using input_t = typename Ktraits::input_t; using residual_t = typename Ktraits::residual_t; using output_t = typename Ktraits::output_t; using index_t = typename Ktraits::index_t; using compute_t = typename Ktraits::compute_t; using mask_t = typename Ktraits::mask_t; using Ivec = typename Ktraits::Ivec; using Rvec = typename Ktraits::Rvec; using Ovec = typename Ktraits::Ovec; using Wvec = typename Ktraits::Wvec; using Cvec = typename Ktraits::Cvec; using Mvec = typename Ktraits::Mvec; using Stats = typename Ktraits::Stats; using stats_t = typename Stats::stats_t; const bool has_residual = params.residual != nullptr; const bool has_x1 = params.x1 != nullptr; const bool save_x = has_residual || has_x1 || Is_dropout || !(std::is_same::value); extern __shared__ char smem_[]; const index_t tidx = threadIdx.x; const index_t bidn = blockIdx.x % CTAS_PER_ROW; const index_t bidm = blockIdx.x / CTAS_PER_ROW; const index_t lane = tidx % 
THREADS_PER_WARP; const index_t warp = tidx / THREADS_PER_WARP; const index_t warp_m = warp / WARPS_N; const index_t warp_n = warp % WARPS_N; const index_t r = bidm * ROWS_PER_CTA + warp_m; const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); compute_t *mu_ptr = static_cast(params.mu); compute_t *rs_ptr = static_cast(params.rs); // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu curandStatePhilox4_32_10_t state; if (Is_dropout) { auto seeds = at::cuda::philox::unpack(params.philox_args); const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x; curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state); } const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG; Wvec gamma0[LDGS]; Wvec beta0[LDGS]; Wvec gamma1[LDGS]; Wvec beta1[LDGS]; index_t idx = c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { gamma0[it].load_from(params.gamma, idx); if (params.beta != nullptr) { beta0[it].load_from(params.beta, idx); } else { beta0[it].zero_(); } if (!Tied_norm) { gamma1[it].load_from(params.gamma1, idx); if (params.beta1 != nullptr) { beta1[it].load_from(params.beta1, idx); } else { beta1[it].zero_(); } } idx += VEC_COLS_PER_LDG; } } for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; compute_t xf[LDGS * NUM_ELTS]; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { Ivec x0; Ivec x1; Rvec residual; Rvec x; Mvec dmask0; Mvec dmask1; x0.load_from(params.x0, idx); if (has_x1) { x1.load_from(params.x1, idx); } if (has_residual) { residual.load_from(params.residual, idx); } #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { // TD [2022-04-22]: We're memory bound, not compute bound, so we 
don't need to use // the more efficient curand_uniform4. compute_t x_ij; mask_t keep0 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; if (Is_dropout) { dmask0.data.elt[jt] = keep0; } compute_t x0_ij = compute_t(x0.data.elt[jt]); x0_ij = keep0 ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; if (has_x1) { mask_t keep1 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p; if (Is_dropout) { dmask1.data.elt[jt] = keep1; } compute_t x1_ij = compute_t(x1.data.elt[jt]); x1_ij = keep1 ? (Is_dropout ? x1_ij * params.dropout_scale : x1_ij) : 0.0f; x_ij = has_residual ? x0_ij + x1_ij + compute_t(residual.data.elt[jt]) : x0_ij + x1_ij; } else { x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij; } if (save_x) { x.data.elt[jt] = x_ij; } xf[it * NUM_ELTS + jt] = x_ij; } if (save_x) { x.store_to(params.x, idx); } if (Is_dropout) { dmask0.store_to(params.dmask, idx); if (has_x1) { dmask1.store_to(params.dmask1, idx); } } idx += VEC_COLS_PER_LDG; } } static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now"); const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG; const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG; const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG; auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int { // Need to convert to int, otherwise the subtraction will wrap around. 
const index_t valid_partial_vecs_in_warp = std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)), int(THREADS_PER_WARP)); return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS; }; stats_t s = stats.template compute( xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS ); compute_t mu = layer_norm::Get<0>::of(s); compute_t m2 = layer_norm::Get<1>::of(s); if( bidn == 0 && warp_n == 0 && lane == 0 ) { mu_ptr[row] = mu; } compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu)); if( bidn == 0 && warp_n == 0 && lane == 0 ) { rs_ptr[row] = rs; } idx = row * params.cols / Ktraits::ELTS_PER_LDG + c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { if (Is_even_cols || (it < num_valid_ldgs)) { Ovec z0; Ovec z1; #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f))); compute_t g0_ij = gamma0[it].data.elt[jt]; compute_t b0_ij = beta0[it].data.elt[jt]; z0.data.elt[jt] = output_t(g0_ij * y_ij + b0_ij); if (!Tied_norm) { compute_t g1_ij = gamma1[it].data.elt[jt]; compute_t b1_ij = beta1[it].data.elt[jt]; z1.data.elt[jt] = output_t(g1_ij * y_ij + b1_ij); } } z0.store_to(params.z, idx); if (!Tied_norm) { z1.store_to(params.z1, idx); } idx += VEC_COLS_PER_LDG; } } } } } // namespace layer_norm using namespace layer_norm; template< typename weight_t, typename input_t, typename residual_t, typename output_t, typename compute_t, typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N, int BYTES_PER_LDG > void launch_parallel_residual_(LaunchParams &launch_params, const bool configure_params){ using Kernel_traits = Kernel_traits; bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; bool tied_norm = launch_params.params.gamma1 == nullptr; BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { BOOL_SWITCH(tied_norm, 
TiedNormConst, [&] { BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { auto kernel = &ln_parallel_residual_fwd_kernel; if( configure_params ) { int ctas_per_sm; CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; launch_params.barrier_size = 0; launch_params.workspace_bytes = 0; if(Kernel_traits::CTAS_PER_ROW > 1) { launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW * sizeof(typename Kernel_traits::Stats::stats_t) * 2; } return; } if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); } auto stream = launch_params.stream; auto ctas_per_col = launch_params.params.ctas_per_col; if( Kernel_traits::CTAS_PER_ROW == 1 ) { kernel<<>>(launch_params.params); } else { dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = (void *)&launch_params.params; cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream); } }); }); }); } ================================================ FILE: csrc/layer_norm/ln_utils.cuh ================================================ #pragma once #include #include #include #include "ln.h" //////////////////////////////////////////////////////////////////////////////////////////////////// constexpr uint32_t THREADS_PER_WARP = 32;
//////////////////////////////////////////////////////////////////////////////////////////////////// inline void check_cuda_(cudaError_t status, const char *file, int line) { if( status != cudaSuccess ) { fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); exit(status); } } //////////////////////////////////////////////////////////////////////////////////////////////////// #define CHECK_CUDA(ans) \ { check_cuda_((ans), __FILE__, __LINE__); } //////////////////////////////////////////////////////////////////////////////////////////////////// #define DIVUP(x, y) (((x) + ((y)-1)) / (y)) //////////////////////////////////////////////////////////////////////////////////////////////////// #define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ const bool configure_params) { \ launch_( \ launch_params, configure_params); \ } \ static FwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// #define REGISTER_BWD_LAUNCHER( \ HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ const bool configure_params) { \ launch_(launch_params, configure_params); \ } \ static BwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// #define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, 
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ const bool configure_params) { \ launch_parallel_residual_( \ launch_params, configure_params); \ } \ static FwdParallelRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// #define REGISTER_PARALLEL_BWD_LAUNCHER( \ HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ const bool configure_params) { \ launch_parallel_residual_(launch_params, configure_params); \ } \ static BwdParallelRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float2 operator+(const float2 & a, const float2 & b){ return {a.x + b.x, a.y + b.y}; } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void operator+=(float2 & a, const float2 & b){ a.x += b.x; a.y += b.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Sum { inline __device__ Sum(){} inline __device__ T operator()(const T &a, const T &b){ return a + b; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ return __shfl_xor_sync(uint32_t(-1), x, idx); } template<> 
inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx){ return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; } template inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ return __shfl_down_sync(uint32_t(-1), x, idx); } template<> inline __device__ float2 warp_shuffle_down(const float2 & x, uint32_t idx){ return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; } //////////////////////////////////////////////////////////////////////////////////////////////////// namespace layer_norm { //////////////////////////////////////////////////////////////////////////////////////////////////// struct uint16 { uint4 u; uint4 v; uint4 s; uint4 t; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct uint8 { uint4 u; uint4 v; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct BytesToType {}; template<> struct BytesToType<64> { using Type = uint16; static_assert(sizeof(Type) == 64); }; template<> struct BytesToType<32> { using Type = uint8; static_assert(sizeof(Type) == 32); }; template<> struct BytesToType<16> { using Type = uint4; static_assert(sizeof(Type) == 16); }; template<> struct BytesToType<8> { using Type = uint64_t; static_assert(sizeof(Type) == 8); }; template<> struct BytesToType<4> { using Type = uint32_t; static_assert(sizeof(Type) == 4); }; template<> struct BytesToType<2> { using Type = uint16_t; static_assert(sizeof(Type) == 2); }; template<> struct BytesToType<1> { using Type = uint8_t; static_assert(sizeof(Type) == 1); }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct TypeToVec2 {}; template<> struct TypeToVec2 { using Type = float2; }; template<> struct TypeToVec2 { using Type = half2; }; template<> struct TypeToVec2 { using Type = nv_bfloat162; }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// template struct Get { template static inline __device__ R of(const T &vec); }; template<> template inline __device__ R Get<0>::of(const T &vec) { return vec.x; } template<> template inline __device__ R Get<1>::of(const T &vec) { return vec.y; } template<> template inline __device__ R Get<2>::of(const T &vec) { return vec.z; } template<> template inline __device__ R Get<3>::of(const T &vec) { return vec.w; } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Converter{ static inline __device__ Dst convert(const Src &from) { return Dst(from); } }; template<> struct Converter{ static inline __device__ half2 convert(const float2 &x) { return __float22half2_rn(x); } }; template<> struct Converter{ static inline __device__ nv_bfloat162 convert(const float2 &x) { #if __CUDA_ARCH__ >= 800 return __float22bfloat162_rn(x); #else /* x and y are wrapped in a struct so that they occupy separate halves of raw instead of aliasing the same storage. */ union { nv_bfloat162 raw; struct { nv_bfloat16 x; nv_bfloat16 y; } halves; } tmp; tmp.halves.x = __float2bfloat16_rn(x.x); tmp.halves.y = __float2bfloat16_rn(x.y); return tmp.raw; #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Zeros{ static inline __device__ T get() { return T(0.f); } }; template<> struct Zeros{ static inline __device__ float2 get() { return make_float2(0.f, 0.f); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Vec { enum { BYTES = NUM_ELT * sizeof(Elt_type) }; using Vec_type = typename BytesToType::Type; using Alias_type = union { Vec_type vec; Elt_type elt[NUM_ELT]; }; Alias_type data; template inline __device__ void to(Vec &other) { #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { other.data.elt[it] = S(this->data.elt[it]); } } template inline __device__ void assign(const Op &op) { #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it); } } inline __device__ void zero_() { #pragma unroll for( int it = 0; it < NUM_ELT; it++ ) { this->data.elt[it] = Elt_type(0.f); } } inline __device__ void load_from(const void *base_ptr, const size_t idx) { this->data.vec = static_cast(base_ptr)[idx]; } inline __device__ void store_to(void *base_ptr, const size_t idx) { static_cast(base_ptr)[idx] = this->data.vec; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct InterCTASync { template inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn) : phase_counter_(0) , b0_(params.barrier + bidm) // The barrier for this group of CTAs. , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. { // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! } inline __device__ void spin_wait_(int *barrier, int step, int expected) { asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); for( int found = -1; found != expected; ) { asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); } } inline __device__ void sync(){ // ALL THREADS MUST ENTER! // We switch barrier every iteration. int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; // We decrement every other iteration. bool dec = phase_counter_ & 0x2; int step = dec ? -1 : 1; int expected = dec ? 0 : CTAS_PER_ROW; // There are only 4 phases: up/down for b0/b1. 
phase_counter_ = (phase_counter_ + 1) & 0x3; if( threadIdx.x == 0 ) { spin_wait_(barrier, step, expected); } // CTA waits for thread 0 __syncthreads(); } int phase_counter_; int * b0_; int * b1_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Reducer : public Reducer { using InterCTASync = InterCTASync; using Base = Reducer; using Type = typename Base::Type; enum { SMEM_BYTES = Base::SMEM_BYTES }; enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES }; template inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) , inter_cta_(params, bidm, bidn) , bidn_(bidn) // CTA id within the group. , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) { } template inline __device__ T allreduce(T data, Op &op) { data = Base::reduce(data, op); // We switch workspace every iteration. T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; // Warp leaders 0 hold the CTA-local results. 
if( this->warp_n_ == 0 && this->lane_ == 0 ) { workspace[bidn_] = data; } inter_cta_.sync(); static_assert(CTAS_PER_ROW <= 32); T total = Zeros::get(); if(this->lane_ < CTAS_PER_ROW){ total = workspace[this->lane_]; } total = Reducer::allreduce_(total, op); return total; } InterCTASync inter_cta_; T *w0_; T *w1_; int bidn_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Reducer { using Type = T; enum { SMEM_BYTES = 0 }; enum { WORKSPACE_BYTES_PER_GROUP = 0 }; enum { THREADS_PER_WARP = 32 }; template inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) : warp_n_(warp_n) , lane_(lane) { } template static inline __device__ T allreduce_(T data, Op &op) { #pragma unroll for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) { data = op(data, warp_shuffle_xor(data, it)); } return data; } template inline __device__ T allreduce(T data, Op &op) { return allreduce_(data, op); } template inline __device__ T reduce(T data, Op &op){ // only lane 0 holds the result! #pragma unroll for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { data = op(data, warp_shuffle_down(data, it)); } return data; } int warp_n_; int lane_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Reducer : public Reducer { using Base = Reducer; using Type = T; enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; enum { WORKSPACE_BYTES_PER_GROUP = 0 }; enum { THREADS_PER_WARP = 32 }; template inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) , use0_(true) { smem0_ = &static_cast(smem)[warp_m * WARPS_N]; smem1_ = smem0_ + WARPS_M * WARPS_N; } template inline __device__ T allreduce(T data, Op & op) { T * smem = use0_ ? 
smem0_ : smem1_; use0_ = !use0_; data = Base::reduce(data, op); if( this->lane_ == 0 ) { smem[this->warp_n_] = data; } __syncthreads(); T out = Zeros::get(); #pragma unroll for( int it = 0; it < WARPS_N; it++ ) { out = op(out, smem[it]); } return out; } template inline __device__ T reduce(T data, Op &op) { T * smem = use0_ ? smem0_ : smem1_; use0_ = !use0_; // only intra-CTA group leader holds the result! data = Base::reduce(data, op); if( this->lane_ == 0 ) { smem[this->warp_n_] = data; } __syncthreads(); T out = Zeros::get(); if( this->warp_n_ == 0 && this->lane_ == 0 ) { #pragma unroll for( int it = 0; it < WARPS_N; it++ ) { out = op(out, smem[it]); } } return out; } T * smem0_; T * smem1_; bool use0_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){ //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); #pragma unroll for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { // Exchange int_t n_b = warp_shuffle_down(n_a, step); T m_b = warp_shuffle_down(m_a, step); T m2_b = warp_shuffle_down(m2_a, step); // Update const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both. const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( const T delta = m_a - m_b; const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; n_a = n_ab; m_a = m_ab; m2_a = m2_ab; } // Intra-warp broadcast (only lane 0 has valid stats). 
m_a = __shfl_sync(uint32_t(-1), m_a, 0); m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Stats { // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields. using InterCTASync = InterCTASync; using BlockStats = Stats; using stats_t = typename BlockStats::stats_t; enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; template inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) : inter_cta_(params, bidm, bidn) , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) , bidn_(bidn) // CTA id within the group. , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) , warp_n_(warp_n) , lane_(lane) { } template inline __device__ stats_t compute(const T (&elts)[N], const T rn) { constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; // TODO rn is not really needed here.. constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); stats_t block_stats = block_stats_.compute(elts, block_rn); stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; if( warp_n_ == 0 && lane_ == 0 ) { workspace[bidn_] = block_stats; } // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. inter_cta_.sync(); T n = Zeros::get(); T m = Zeros::get(); T m2 = Zeros::get(); // Assume CTA group size in N less than 32, such that we can finalize with a single warp. static_assert(CTAS_PER_ROW <= 32); // Every warp does the final reduction locally. 
if( lane_ < CTAS_PER_ROW ) { stats_t result = workspace[lane_]; n = ELTS_PER_ROW_PER_CTA; m = layer_norm::Get<0>::of(result); m2 = layer_norm::Get<1>::of(result); } warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); return { m, m2 }; } InterCTASync inter_cta_; BlockStats block_stats_; stats_t *w0_; stats_t *w1_; int bidn_; int warp_n_; int lane_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Stats { using WarpStats = Stats; using stats_t = typename WarpStats::stats_t; enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; template inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) , use0_(true) { smem0_ = static_cast(smem) + warp_m * WARPS_N; smem1_ = smem0_ + WARPS_M * WARPS_N; } template inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { stats_t * smem = use0_ ? smem0_ : smem1_; use0_ = !use0_; // Compute warp local for all WARPS_N const auto warp_n = warp_stats_.reducer_.warp_n_; const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n)); stats_t warp_stats = warp_stats_.template compute( elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts ); // Each warp leader stores its stats const auto lane = warp_stats_.reducer_.lane_; if( lane == 0 ) { smem[warp_n] = warp_stats; } __syncthreads(); int n = 0; T m = Zeros::get(); T m2 = Zeros::get(); // Assume that there are less than 32 warps, such that we can finalize with a single warp static_assert(WARPS_N <= 32); if(lane < WARPS_N){ stats_t result = smem[lane]; n = Is_even_cols ?
N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane); m = layer_norm::Get<0>::of(result); m2 = layer_norm::Get<1>::of(result); } warp_chan_upd_dynamic(m, m2, n, WARPS_N); return { m, m2 }; } WarpStats warp_stats_; stats_t * smem0_; stats_t * smem1_; bool use0_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Stats { using stats_t = typename TypeToVec2::Type; // The simple Warp reducer. using Reducer = Reducer; enum { SMEM_BYTES = 0 }; template inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) { } template inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, // const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) { function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { auto sum = Sum(); T m = Zeros::get(); #pragma unroll for( int it = 0; it < N; it++ ) { if (Is_even_cols || (it < num_valid_elts)) { m += elts[it]; } } m = reducer_.allreduce(m, sum) * row_norm_factor; T m2 = Zeros::get(); #pragma unroll for( int it = 0; it < N; it++ ) { if (Is_even_cols || (it < num_valid_elts)) { T diff = (elts[it] - m); m2 += diff * diff; } } m2 = reducer_.allreduce(m2, sum); return {m, m2}; } Reducer reducer_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace layer_norm ================================================ FILE: csrc/layer_norm/setup.py ================================================ # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py import sys import warnings import os from packaging.version import parse, Version import torch from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME from setuptools import setup, find_packages import subprocess # ninja build does not work unless include_dirs 
are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) def get_cuda_bare_metal_version(cuda_dir): raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) output = raw_output.split() release_idx = output.index("release") + 1 bare_metal_version = parse(output[release_idx].split(",")[0]) return raw_output, bare_metal_version def check_cuda_torch_binary_vs_bare_metal(cuda_dir): raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) torch_binary_version = parse(torch.version.cuda) print("\nCompiling cuda extensions with") print(raw_output + "from " + cuda_dir + "/bin\n") if (bare_metal_version != torch_binary_version): raise RuntimeError( "Cuda extensions are being compiled with a version of Cuda that does " "not match the version used to compile Pytorch binaries. " "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + "In some cases, a minor-version mismatch will not cause later errors: " "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " "You can try commenting out this check (at your own risk)." ) def raise_if_cuda_home_none(global_option: str) -> None: if CUDA_HOME is not None: return raise RuntimeError( f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " "only images whose names contain 'devel' will provide nvcc." 
) def append_nvcc_threads(nvcc_extra_args): _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version >= Version("11.2"): nvcc_threads = os.getenv("NVCC_THREADS") or "4" return nvcc_extra_args + ["--threads", nvcc_threads] return nvcc_extra_args if not torch.cuda.is_available(): # https://github.com/NVIDIA/apex/issues/486 # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). print( "\nWarning: Torch did not find available GPUs on this system.\n", "If your intention is to cross-compile, this is not an error.\n" "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" "If you wish to cross-compile for a single specific architecture,\n" 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', ) if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if bare_metal_version >= Version("11.8"): os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" elif bare_metal_version >= Version("11.1"): os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" elif bare_metal_version == Version("11.0"): os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" else: os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) TORCH_MAJOR = int(torch.__version__.split(".")[0]) TORCH_MINOR = int(torch.__version__.split(".")[1]) cmdclass = {} ext_modules = [] # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h # See 
# https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
    generator_flag = ["-DOLD_GENERATOR_PATH"]

raise_if_cuda_home_none("--fast_layer_norm")
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version < Version("11.0"):
    raise RuntimeError("dropout_layer_norm is only supported on CUDA 11 and above")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if bare_metal_version >= Version("11.8"):
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_90,code=sm_90")

ext_modules.append(
    CUDAExtension(
        name="dropout_layer_norm",
        sources=[
            "ln_api.cpp",
            "ln_fwd_256.cu", "ln_bwd_256.cu",
            "ln_fwd_512.cu", "ln_bwd_512.cu",
            "ln_fwd_768.cu", "ln_bwd_768.cu",
            "ln_fwd_1024.cu", "ln_bwd_1024.cu",
            "ln_fwd_1280.cu", "ln_bwd_1280.cu",
            "ln_fwd_1536.cu", "ln_bwd_1536.cu",
            "ln_fwd_2048.cu", "ln_bwd_2048.cu",
            "ln_fwd_2560.cu", "ln_bwd_2560.cu",
            "ln_fwd_3072.cu", "ln_bwd_3072.cu",
            "ln_fwd_4096.cu", "ln_bwd_4096.cu",
            "ln_fwd_5120.cu", "ln_bwd_5120.cu",
            "ln_fwd_6144.cu", "ln_bwd_6144.cu",
            "ln_fwd_7168.cu", "ln_bwd_7168.cu",
            "ln_fwd_8192.cu", "ln_bwd_8192.cu",
            "ln_parallel_fwd_256.cu", "ln_parallel_bwd_256.cu",
            "ln_parallel_fwd_512.cu", "ln_parallel_bwd_512.cu",
            "ln_parallel_fwd_768.cu", "ln_parallel_bwd_768.cu",
            "ln_parallel_fwd_1024.cu", "ln_parallel_bwd_1024.cu",
            "ln_parallel_fwd_1280.cu", "ln_parallel_bwd_1280.cu",
            "ln_parallel_fwd_1536.cu", "ln_parallel_bwd_1536.cu",
            "ln_parallel_fwd_2048.cu", "ln_parallel_bwd_2048.cu",
            "ln_parallel_fwd_2560.cu", "ln_parallel_bwd_2560.cu",
            "ln_parallel_fwd_3072.cu", "ln_parallel_bwd_3072.cu",
            "ln_parallel_fwd_4096.cu", "ln_parallel_bwd_4096.cu",
            "ln_parallel_fwd_5120.cu", "ln_parallel_bwd_5120.cu",
            "ln_parallel_fwd_6144.cu",
"ln_parallel_bwd_6144.cu", "ln_parallel_fwd_7168.cu", "ln_parallel_bwd_7168.cu", "ln_parallel_fwd_8192.cu", "ln_parallel_bwd_8192.cu", ], extra_compile_args={ "cxx": ["-O3"] + generator_flag, "nvcc": append_nvcc_threads( [ "-O3", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "-U__CUDA_NO_BFLOAT162_OPERATORS__", "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", ] + generator_flag + cc_flag ), }, include_dirs=[this_dir], ) ) setup( name="dropout_layer_norm", version="0.1", description="Fused dropout + add + layer norm", ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension} if ext_modules else {}, ) ================================================ FILE: csrc/layer_norm/static_switch.h ================================================ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h #pragma once /// @param COND - a boolean expression to switch by /// @param CONST_NAME - a name given for the constexpr bool variable. /// @param ... - code to execute for true and false /// /// Usage: /// ``` /// BOOL_SWITCH(flag, BoolConst, [&] { /// some_function(...); /// }); /// ``` #define BOOL_SWITCH(COND, CONST_NAME, ...) \ [&] { \ if (COND) { \ constexpr bool CONST_NAME = true; \ return __VA_ARGS__(); \ } else { \ constexpr bool CONST_NAME = false; \ return __VA_ARGS__(); \ } \ }() ================================================ FILE: examples/inference/README.md ================================================ # Example of LLM inference using FlashAttention Example script of using FlashAttention for inference coming soon. 
================================================
FILE: flash_attn/__init__.py
================================================
from pkgutil import extend_path

# look for every subdir with flash_attn base name such that fa2 and fa4 can be co-installed
__path__ = extend_path(__path__, __name__)

__version__ = "2.8.4"

from flash_attn.flash_attn_interface import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
    flash_attn_with_kvcache,
)

================================================
FILE: flash_attn/bert_padding.py
================================================
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


class IndexPutFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


class IndexFirstAxisResidual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        output = input[indices]
        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
        # memory format to channel_first. In other words, input might not be contiguous.
        # If we don't detach, Pytorch complains about output being a view and is being modified inplace
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        grad_input = grad_residual
        # grad_input[indices] += grad_output
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply


def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of the selected tokens in the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), the number of tokens selected in attention_mask only (unused_mask is not counted).
    """
    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The index tensor is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )


def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
    """
    Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It enables efficient training on variable-length samples (e.g., the supervised fine-tuning task in large language models).
    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).

    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
        ```
        [
          [2, 3, 0, 0, 0, 0],
          [3, 2, 0, 0, 0, 0],
          [6, 0, 0, 0, 0, 0]
        ]
        ```
    , which refers to the 3D-attention mask:
        ```
        [
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
          ]
        ]
        ```.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means the length of a concatenated sequence in the b-th batch element, and 0 means none.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The index tensor is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[-1]
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)

================================================
FILE: flash_attn/cute/.flake8
================================================
[flake8]
max-line-length = 100
# W503: line break before binary operator
ignore = E731, E741, F841, W503

================================================
FILE: flash_attn/cute/AUTHORS
================================================
Tri Dao
Jay Shah
Ted Zadouri
Markus Hoehnerbach
Vijay Thakkar
Timmy Liu
Driss Guessous
Reuben Stern

================================================
FILE: flash_attn/cute/LICENSE
================================================
BSD 3-Clause License

Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

================================================
FILE: flash_attn/cute/MANIFEST.in
================================================
global-exclude *.egg-info/*
prune flash_attn_4.egg-info
prune flash_attn.egg-info
prune build
prune dist

================================================
FILE: flash_attn/cute/README.md
================================================
# FlashAttention-4 (CuTeDSL)

FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper and Blackwell GPUs.
## Installation

```sh
pip install flash-attn-4
```

## Usage

```python
from flash_attn.cute import flash_attn_func, flash_attn_varlen_func

out = flash_attn_func(q, k, v, causal=True)
```

## Development

```sh
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install -e "flash_attn/cute[dev]"
pytest tests/cute/
```

================================================
FILE: flash_attn/cute/__init__.py
================================================
"""Flash Attention CUTE (CUDA Template Engine) implementation."""

from importlib.metadata import PackageNotFoundError, version

try:
    __version__ = version("fa4")
except PackageNotFoundError:
    __version__ = "0.0.0"

import cutlass.cute as cute

from .interface import (
    flash_attn_func,
    flash_attn_varlen_func,
)

from flash_attn.cute.cute_dsl_utils import cute_compile_patched

# Patch cute.compile to optionally dump SASS
cute.compile = cute_compile_patched


__all__ = [
    "flash_attn_func",
    "flash_attn_varlen_func",
]

================================================
FILE: flash_attn/cute/ampere_helpers.py
================================================
# Copyright (c) 2025, Tri Dao.
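`get_smem_layout_atom`, defined next in `ampere_helpers.py`, derives its swizzle parameters from plain integer arithmetic on the element width and `k_dim`: it picks the largest of 128, 64 or 32 bytes that divides a row (falling back to 16), converts that to elements, and maps the result to a swizzle bit count. The sketch below mirrors only that arithmetic in ordinary Python so the chosen values can be inspected without CUTLASS; it does not build a CuTe layout, and `smem_atom_params` is a hypothetical name.

```python
from typing import Tuple


def smem_atom_params(dtype_bits: int, k_dim: int) -> Tuple[int, int, int, int]:
    """Return (smem_k_block_size, swizzle_bits, swizzle_base, atom_rows).

    Hypothetical mirror of the integer logic in get_smem_layout_atom;
    smem_k_block_size is measured in elements, not bytes.
    """
    dtype_byte = dtype_bits // 8
    bytes_per_row = k_dim * dtype_byte
    # Largest of 128/64/32 bytes that evenly divides a row; otherwise 16 bytes.
    for row_bytes in (128, 64, 32, 16):
        if row_bytes == 16 or bytes_per_row % row_bytes == 0:
            break
    smem_k_block_size = row_bytes // dtype_byte
    swizzle_bits = {128: 4, 64: 3, 32: 2}.get(smem_k_block_size, 1)
    swizzle_base = {4: 2, 2: 3}.get(dtype_byte, 4)
    atom_rows = 8 if k_dim % 32 == 0 else 16
    return smem_k_block_size, swizzle_bits, swizzle_base, atom_rows
```

For fp16 with `k_dim=128`, a 128-byte slice of a row holds 64 elements, so the atom is 8 x 64 with 3 swizzle bits and a swizzle base of 3.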
from typing import Type, Callable, Optional

import cutlass
import cutlass.cute as cute


def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
    dtype_byte = cutlass.const_expr(dtype.width // 8)
    bytes_per_row = cutlass.const_expr(k_dim * dtype_byte)
    smem_k_block_size = (
        cutlass.const_expr(
            128
            if bytes_per_row % 128 == 0
            else (64 if bytes_per_row % 64 == 0 else (32 if bytes_per_row % 32 == 0 else 16))
        )
        // dtype_byte
    )
    swizzle_bits = (
        4
        if smem_k_block_size == 128
        else (3 if smem_k_block_size == 64 else (2 if smem_k_block_size == 32 else 1))
    )
    swizzle_base = 2 if dtype_byte == 4 else (3 if dtype_byte == 2 else 4)
    return cute.make_composed_layout(
        cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base),
        0,
        cute.make_ordered_layout(
            (8 if cutlass.const_expr(k_dim % 32 == 0) else 16, smem_k_block_size), order=(1, 0)
        ),
    )


@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsA: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_A: cute.TiledCopy,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
    A_in_regs: cutlass.Constexpr[bool] = False,
    B_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    if cutlass.const_expr(swap_AB):
        gemm(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            tCsB,
            tCsA,
            smem_thr_copy_B,
            smem_thr_copy_A,
            hook_fn,
            A_in_regs=B_in_regs,
            B_in_regs=A_in_regs,
            swap_AB=False,
        )
    else:
        tCrA_copy_view = smem_thr_copy_A.retile(tCrA)
        tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
        if cutlass.const_expr(not A_in_regs):
            cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0])
        if cutlass.const_expr(not B_in_regs):
            cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
        for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])):
            if k < cute.size(tCsA.shape[2]) - 1:
                if cutlass.const_expr(not A_in_regs):
                    cute.copy(
                        smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]
                    )
                if cutlass.const_expr(not B_in_regs):
                    cute.copy(
                        smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]
                    )
            cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            if cutlass.const_expr(k == 0 and hook_fn is not None):
                hook_fn()


@cute.jit
def gemm_rs(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
) -> None:
    tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
    cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1):
            cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1])
        cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
        if cutlass.const_expr(k == 0 and hook_fn is not None):
            hook_fn()

================================================
FILE: flash_attn/cute/barrier.py
================================================
import cutlass
import cutlass.cute as cute
from cutlass import Int32
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import llvm


@dsl_user_op
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    state = llvm.inline_asm(
        T.i32(),
        [lock_ptr_i64],
        "ld.global.acquire.gpu.b32 $0, [$1];",
        "=r,l",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(state)


@dsl_user_op
def red_relaxed(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        "red.relaxed.gpu.global.add.s32 [$0], $1;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def red_release(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        "red.release.gpu.global.add.s32 [$0], $1;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@cute.jit
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        read_val = Int32(0)
        while read_val != val:
            read_val = ld_acquire(flag_ptr)


@cute.jit
def arrive_inc(
    lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
) -> None:
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        red_release(flag_ptr, val)
        # red_relaxed(flag_ptr, val)

================================================
FILE: flash_attn/cute/bench_utils.py
================================================
"""Shared benchmark utilities: attention_ref, cuDNN helpers, flops calculation."""

import math

import torch

try:
    import cudnn
except ImportError:
    cudnn = None


# ── FLOPS calculation ────────────────────────────────────────────────────────


def flops(
    batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)
):
    if causal:
        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
    else:
        if window_size == (None, None):
            avg_seqlen = seqlen_k
        else:
            row_idx = torch.arange(seqlen_q, device="cuda")
            col_left = (
                torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
                if window_size[0] is not None
                else torch.zeros_like(row_idx)
            )
            col_right = (
                torch.minimum(
                    row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1)
                )
                if window_size[1] is not None
                else torch.full_like(row_idx, seqlen_k - 1)
            )
            avg_seqlen = (col_right - col_left + 1).float().mean().item()
    return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)


# ── Reference attention ─────────────────────────────────────────────────────

_attention_ref_mask_cache = {}


def attention_ref(q, k, v, causal=False):
    """Standard attention reference implementation.

    Args:
        q, k, v: (batch, seqlen, nheads, headdim) tensors.
        causal: whether to apply causal mask.
    """
    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    if causal:
        # Key on both sequence lengths and the device: the mask has shape (seqlen_q, seqlen_k).
        key = (scores.shape[-2], scores.shape[-1], scores.device)
        if key not in _attention_ref_mask_cache:
            _attention_ref_mask_cache[key] = torch.tril(
                torch.ones(scores.shape[-2:], device=scores.device, dtype=torch.bool), diagonal=0
            )
        mask = _attention_ref_mask_cache[key]
        # mask is True where attention is allowed, so fill its complement with -inf.
        scores = scores.masked_fill(~mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhts,bshd->bthd", attn, v)


# ── cuDNN graph helpers ─────────────────────────────────────────────────────

_TORCH_TO_CUDNN_DTYPE = {
    torch.float16: "HALF",
    torch.bfloat16: "BFLOAT16",
    torch.float32: "FLOAT",
    torch.int32: "INT32",
    torch.int64: "INT64",
}


def _build_cudnn_graph(io_dtype, tensors, build_fn):
    """Build a cuDNN graph. Returns (graph, variant_pack, workspace)."""
    assert cudnn is not None, "cuDNN is not available"
    cudnn_dtype = getattr(cudnn.data_type, _TORCH_TO_CUDNN_DTYPE[io_dtype])
    graph = cudnn.pygraph(
        io_data_type=cudnn_dtype,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    graph_tensors = {name: graph.tensor_like(t.detach()) for name, t in tensors.items()}
    variant_pack = build_fn(graph, graph_tensors)
    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()
    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)
    return graph, variant_pack, workspace


def cudnn_fwd_setup(q, k, v, causal=False, window_size_left=None):
    """Build a cuDNN forward SDPA graph.

    Args:
        q, k, v: (batch, nheads, seqlen, headdim) tensors (cuDNN layout).
        causal: whether to apply causal mask.
        window_size_left: sliding window size (None for no window).

    Returns:
        (fwd_fn, o_gpu, stats_gpu) where fwd_fn is a zero-arg callable.
    """
    b, nheads, seqlen_q, headdim = q.shape
    headdim_v = v.shape[-1]
    o_gpu = torch.empty(b, nheads, seqlen_q, headdim_v, dtype=q.dtype, device=q.device)
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)

    def build(graph, gt):
        o, stats = graph.sdpa(
            name="sdpa",
            q=gt["q"],
            k=gt["k"],
            v=gt["v"],
            is_inference=False,
            attn_scale=1.0 / math.sqrt(headdim),
            use_causal_mask=causal or window_size_left is not None,
            sliding_window_length=window_size_left if window_size_left is not None and not causal else None,
        )
        o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
        stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
        return {gt["q"]: q, gt["k"]: k, gt["v"]: v, o: o_gpu, stats: stats_gpu}

    graph, variant_pack, workspace = _build_cudnn_graph(q.dtype, {"q": q, "k": k, "v": v}, build)

    def fwd_fn():
        graph.execute(variant_pack, workspace)
        return o_gpu

    return fwd_fn, o_gpu, stats_gpu


def cudnn_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=None):
    """Build a cuDNN backward SDPA graph.

    Args:
        q, k, v, o, g, lse: (batch, nheads, seqlen, dim) tensors (cuDNN layout).
        causal: whether to apply causal mask.
        window_size_left: sliding window size (None for no window).

    Returns:
        bwd_fn: zero-arg callable that returns (dq, dk, dv).
""" headdim = q.shape[-1] dq_gpu, dk_gpu, dv_gpu = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) def build(graph, gt): dq, dk, dv = graph.sdpa_backward( name="sdpa_backward", q=gt["q"], k=gt["k"], v=gt["v"], o=gt["o"], dO=gt["g"], stats=gt["lse"], attn_scale=1.0 / math.sqrt(headdim), use_causal_mask=causal or window_size_left is not None, sliding_window_length=window_size_left if window_size_left is not None and not causal else None, use_deterministic_algorithm=False, ) dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride()) dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride()) dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride()) return { gt["q"]: q, gt["k"]: k, gt["v"]: v, gt["o"]: o, gt["g"]: g, gt["lse"]: lse, dq: dq_gpu, dk: dk_gpu, dv: dv_gpu, } graph, variant_pack, workspace = _build_cudnn_graph( q.dtype, {"q": q, "k": k, "v": v, "o": o, "g": g, "lse": lse}, build, ) def bwd_fn(): graph.execute(variant_pack, workspace) return dq_gpu, dk_gpu, dv_gpu return bwd_fn ================================================ FILE: flash_attn/cute/benchmark.py ================================================ # Copyright (c) 2023, Tri Dao. 
"""Useful functions for writing test code.""" import torch import torch.utils.benchmark as benchmark def benchmark_forward( fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs ): """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" if verbose: print(desc, "- Forward pass") def amp_wrapper(*inputs, **kwinputs): with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): fn(*inputs, **kwinputs) t = benchmark.Timer( stmt="fn_amp(*inputs, **kwinputs)", globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, num_threads=torch.get_num_threads(), ) m = t.timeit(repeats) if verbose: print(m) return t, m def benchmark_backward( fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs, ): """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" if verbose: print(desc, "- Backward pass") with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): y = fn(*inputs, **kwinputs) if type(y) is tuple: y = y[0] if grad is None: grad = torch.randn_like(y) else: if grad.shape != y.shape: raise RuntimeError("Grad shape does not match output shape") def f(*inputs, y, grad): # Set .grad to None to avoid extra operation of gradient accumulation for x in inputs: if isinstance(x, torch.Tensor): x.grad = None y.backward(grad, retain_graph=True) t = benchmark.Timer( stmt="f(*inputs, y=y, grad=grad)", globals={"f": f, "inputs": inputs, "y": y, "grad": grad}, num_threads=torch.get_num_threads(), ) m = t.timeit(repeats) if verbose: print(m) return t, m def benchmark_combined( fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs, ): """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" if verbose: print(desc, "- Forward + Backward pass") with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): y = fn(*inputs, **kwinputs) if type(y) is 
tuple: y = y[0] if grad is None: grad = torch.randn_like(y) else: if grad.shape != y.shape: raise RuntimeError("Grad shape does not match output shape") def f(grad, *inputs, **kwinputs): for x in inputs: if isinstance(x, torch.Tensor): x.grad = None with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): y = fn(*inputs, **kwinputs) if type(y) is tuple: y = y[0] y.backward(grad, retain_graph=True) t = benchmark.Timer( stmt="f(grad, *inputs, **kwinputs)", globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs}, num_threads=torch.get_num_threads(), ) m = t.timeit(repeats) if verbose: print(m) return t, m def benchmark_fwd_bwd( fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs, ): """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" return ( benchmark_forward( fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs, ), benchmark_backward( fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs, ), ) def benchmark_all( fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs, ): """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" return ( benchmark_forward( fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs, ), benchmark_backward( fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs, ), benchmark_combined( fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs, ), ) def pytorch_profiler( fn, *inputs, trace_filename=None, backward=False, amp=False, amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs, ): """Wrap benchmark functions in Pytorch profiler to see CUDA information.""" if backward: with 
torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): out = fn(*inputs, **kwinputs) if type(out) is tuple: out = out[0] g = torch.randn_like(out) for _ in range(30): # Warm up if backward: for x in inputs: if isinstance(x, torch.Tensor): x.grad = None with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): out = fn(*inputs, **kwinputs) if type(out) is tuple: out = out[0] # Backward should be done outside autocast if backward: out.backward(g, retain_graph=True) activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [ torch.profiler.ProfilerActivity.CUDA ] with torch.profiler.profile( activities=activities, record_shapes=True, # profile_memory=True, with_stack=True, ) as prof: if backward: for x in inputs: if isinstance(x, torch.Tensor): x.grad = None with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): out = fn(*inputs, **kwinputs) if type(out) is tuple: out = out[0] if backward: out.backward(g, retain_graph=True) if verbose: # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) print(prof.key_averages().table(row_limit=50)) if trace_filename is not None: prof.export_chrome_trace(trace_filename) def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() fn(*inputs, **kwinputs) torch.cuda.synchronize() mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000) if verbose: print(f"{desc} max memory: {mem}GB") torch.cuda.empty_cache() return mem ================================================ FILE: flash_attn/cute/blackwell_helpers.py ================================================ # Copyright (c) 2025, Tri Dao. 
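`gemm_ptx` in `blackwell_helpers.py`, which follows, builds 64-bit shared-memory descriptors at trace time and splits them with `i64_to_i32x2` into a low and a high 32-bit word, presumably so each half can be handled as a 32-bit value when the descriptor is assembled for the PTX MMA. The sketch below checks that split in plain Python; `i64_to_i32x2` is copied from the source, while the inverse `i32x2_to_i64` is a hypothetical helper added only to demonstrate the round trip.

```python
from typing import Tuple


def i64_to_i32x2(i: int) -> Tuple[int, int]:
    """Split a 64-bit integer into (low, high) 32-bit words, as in blackwell_helpers."""
    return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF


def i32x2_to_i64(lo: int, hi: int) -> int:
    """Hypothetical inverse (not in the source): reassemble the 64-bit value."""
    return ((hi & 0xFFFF_FFFF) << 32) | (lo & 0xFFFF_FFFF)
```

Masking the high word matters in Python because integers are unbounded: without the mask, any bits above bit 63 would end up in `hi`.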
from typing import Optional, Tuple

import cutlass
import cutlass.cute as cute
from cutlass import Int32, Boolean, const_expr
from cutlass.cute.nvgpu import tcgen05
from cutlass._mlir.dialects import llvm

import flash_attn.cute.mma_sm100_desc as sm100_desc


@cute.jit
def gemm_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    swap_AB: bool = False,
    num_unroll_groups: int = 1,
) -> None:
    if const_expr(swap_AB):
        return gemm_w_idx(
            tiled_mma, acc, tCrB, tCrA, B_idx, A_idx, zero_init=zero_init, swap_AB=False
        )
    else:
        rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
        rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
        mma_atom = cute.make_mma_atom(tiled_mma.op)
        for k in cutlass.range(
            cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups
        ):
            mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
            cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc)


@cute.jit
def gemm_ptx_w_idx(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    A_idx: Optional[Int32] = None,
    B_idx: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
    **kwargs,
) -> None:
    rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
    rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
    sA_cur = None
    if const_expr(sA is not None):
        sA_cur = sA if const_expr(A_idx is None) else sA[None, None, None, A_idx]
    sB_cur = sB if const_expr(B_idx is None) else sB[None, None, None, B_idx]
    mma_atom = cute.make_mma_atom(tiled_mma.op)
    acc_tmem_addr = acc.iterator.toint()
    gemm_ptx_partial(
        mma_atom.op,
        acc_tmem_addr,
        rA,
        rB,
        sA_cur,
        sB_cur,
        zero_init=zero_init,
        cta_group=cta_group,
        **kwargs,
    )


@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    mma_atom = cute.make_mma_atom(tiled_mma.op)
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
        cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc)


def i64_to_i32x2(i: int) -> Tuple[int, int]:
    """Convert a 64-bit integer to a tuple of two 32-bit integers."""
    return i & 0xFFFF_FFFF, (i >> 32) & 0xFFFF_FFFF


@cute.jit
def gemm_ptx(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else None
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = sA.iterator.type.swizzle_type
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo) | sm100_desc.make_smem_desc_start_addr(
            sA[None, None, 0].iterator
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo) | sm100_desc.make_smem_desc_start_addr(
        sB[None, None, 0].iterator
    )
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        if const_expr(not is_ts):
            smem_desc_a_lo = smem_desc_start_a_lo + (
                (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            )
        smem_desc_b_lo = smem_desc_start_b_lo + (
            (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        )
        # with cute.arch.elect_one():
        #     cute.printf("smem_desc_a_lo = {}, smem_desc_b_lo = {}", smem_desc_a_lo, smem_desc_b_lo)
        #     cute.printf("smem_desc_a_lo_correct = {}, smem_desc_b_lo_correct = {}", smem_desc_a_lo_correct, smem_desc_b_lo_correct)
        with cute.arch.elect_one():
            if const_expr(not is_ts):
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        smem_desc_a_lo.ir_value(),
                        smem_desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
                    ".reg .b32 idesc;\n\t"
                    f"mov.b32 idesc, {hex(idesc)};\n\t"
                    f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )
            else:
                llvm.inline_asm(
                    None,
                    [
                        acc.iterator.toint().ir_value(),
                        tCrA[None, None, k].iterator.toint().ir_value(),
                        smem_desc_b_lo.ir_value(),
                        Int32(not zero_init or k != 0).ir_value(),
                    ],
                    "{\n\t"
                    ".reg .pred p;\n\t"
                    ".reg .b64 smem_desc_b;\n\t"
                    f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t"
                    "setp.ne.b32 p, $3, 0;\n\t"
                    f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t"
                    "}\n",
                    "r,r,r,r",
                    has_side_effects=True,
                    is_align_stack=False,
                    asm_dialect=llvm.AsmDialect.AD_ATT,
                )


@cute.jit
def gemm_ptx_loop(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    zero_init: bool | Boolean = False,
) -> None:
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = sA.iterator.type.swizzle_type
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    if const_expr(not is_ts):
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * sA.element_type.width // 8) >> 4
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    else:
        offset_a = [
            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
            for k in cutlass.range_constexpr(cute.size(tCrA.shape[2]))
        ]
    offset_a_diff = [
        offset_a[k] - offset_a[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
    ]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * sB.element_type.width // 8) >> 4
        for k in cutlass.range_constexpr(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [
        offset_b[k] - offset_b[k - 1] for k in cutlass.range_constexpr(1, cute.size(tCrB.shape[2]))
    ]

    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"

    if const_expr(not is_ts):
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 smem_desc_a_lo, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        llvm.inline_asm(
            None,
            [
                acc.iterator.toint().ir_value(),
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(smem_desc_start_b_lo).ir_value(),
                Int32(not zero_init).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            "mov.b32 tmem_a, $1;\n\t"
            "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )


@cute.jit
def gemm_ptx_partial(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: Int32,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA: Optional[cute.Tensor],
    sB: cute.Tensor,
    mbar_ptr: Optional[cutlass.Pointer] = None,
    mbar_phase: Optional[Int32] = None,
    split_arrive: Optional[int] = None,
    zero_init: bool | Boolean = False,
    # sA_offset: Int32 = 0,
    # acc_offset: Int32 = 0,
    tA_addr: Optional[Int32] = None,
    cta_group: int = 1,
) -> None:
    # acc_tmem_addr += acc_offset
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA is not None, "sA must be provided when a_src is not TMEM"
    sA_layout = sA.layout if sA is not None else tCrA.layout
    sB_layout = sB.layout
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        sA_swizzle = sA.iterator.type.swizzle_type
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    sB_swizzle = sB.iterator.type.swizzle_type
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)

    tCrA_layout = (
        tCrA.layout
        if const_expr(not is_ts)
        else cute.recast_layout(32, tCrA.element_type.width, tCrA.layout)
    )
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(cute.size(tCrA.shape[2]))]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
    offset_b = [cute.crd2idx((0, 0, k), tCrB.layout) for k in range(cute.size(tCrB.shape[2]))]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]

    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(
            smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator)
        )
        # ) + sA_offset
    else:
        smem_desc_start_a_lo = None
    smem_desc_start_b_lo = Int32(
        smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator)
    )
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"

    if const_expr(not is_ts):
        assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 smem_desc_a_lo_start, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            # "r,r,r",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        # For TS gemm, somehow tCrA.iterator.toint() returns 0 no matter what, so we need to
        # explicitly pass in the tA_addr for correctness.
        tA_addr = tCrA[None, None, 0].iterator.toint() if tA_addr is None else tA_addr
        input_args = [
            # Int32(cute.arch.make_warp_uniform(tCrA[None, None, 0].iterator.toint())).ir_value(),
            Int32(cute.arch.make_warp_uniform(tA_addr)).ir_value(),
            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
            Int32(not zero_init).ir_value(),
            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
        ]
        if const_expr(mbar_ptr is not None):
            assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
            assert split_arrive is not None, (
                "split_arrive must be provided when mbar_ptr is not None"
            )
            split_arrive_idx = split_arrive // op.shape_mnk[2]
            input_args.append(mbar_ptr.toint().ir_value())
            input_args.append(Int32(mbar_phase).ir_value())
            mbar_wait_str = (
                ".reg .pred P1; \n\t"
                "LAB_WAIT: \n\t"
                "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
                "@P1 bra DONE; \n\t"
                "bra LAB_WAIT; \n\t"
                "DONE: \n\t"
            )
        else:
            mbar_wait_str = ""
        llvm.inline_asm(
            None,
            # [
            #     # acc.iterator.toint().ir_value(),
            #     Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
            #     Int32(smem_desc_start_b_lo).ir_value(),
            #     Int32(not zero_init).ir_value(),
            # ],
            input_args,
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            f"mov.b32 tmem_a, $0;\n\t"
            f"mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(
                    1,
                    cute.size(tCrA.shape[2]) if const_expr(mbar_ptr is None) else split_arrive_idx,
                )
            )
            + mbar_wait_str
            + (
                "".join(
                    (
                        f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                        f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                        f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                    )
                    for k in range(split_arrive_idx, cute.size(tCrA.shape[2]))
                )
                if const_expr(mbar_ptr is not None)
                else ""
            )
            + "}\n",
            "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )


@cute.jit
def gemm_ptx_partial1(
    op: cute.nvgpu.tcgen05.mma.MmaOp,
    acc_tmem_addr: cutlass.Constexpr[int],
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    sA_base_addr_for_desc: Int32,
    sA_addr_offset_for_desc: cutlass.Constexpr[int],
    sA_stage: Int32,
    sB_base_addr_for_desc: Int32,
    sB_addr_offset_for_desc: cutlass.Constexpr[int],
    sB_stage: Int32,
    sA_layout: Optional[cute.Layout],
    sB_layout: Optional[cute.Layout],
    sA_swizzle: Optional[cute.Swizzle],
    sB_swizzle: cute.Swizzle,
    zero_init: bool | Boolean = False,
) -> None:
    is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
    if const_expr(not is_ts):
        assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM"
        assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM"
    idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op))
    if const_expr(not is_ts):
        smem_desc_base_a: int = const_expr(
            sm100_desc.make_smem_desc_base(
                cute.recast_layout(128, op.a_dtype.width, sA_layout[0]),
                sA_swizzle,
                sm100_desc.Major.K
                if const_expr(op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
                else sm100_desc.Major.MN,
            )
        )
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
        smem_desc_base_a_lo = const_expr(smem_desc_base_a_lo)
        smem_desc_a_hi = const_expr(smem_desc_a_hi)
    else:
        smem_desc_base_a = None
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    smem_desc_base_b: int = const_expr(
        sm100_desc.make_smem_desc_base(
            cute.recast_layout(128, op.b_dtype.width, sB_layout[0]),
            sB_swizzle,
            sm100_desc.Major.K
            if const_expr(op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K)
            else sm100_desc.Major.MN,
        )
    )
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    smem_desc_base_b_lo = const_expr(smem_desc_base_b_lo)
    smem_desc_b_hi = const_expr(smem_desc_b_hi)
    mask = [Int32(0)] * 4

    if const_expr(not is_ts):
        offset_a = [
            (cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 8) >> 4
            for k in range(cute.size(tCrA.shape[2]))
        ]
    else:
        offset_a = [
            cute.crd2idx((0, 0, k), sA_layout) * op.a_dtype.width // 32
            for k in range(cute.size(tCrA.shape[2]))
        ]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, cute.size(tCrA.shape[2]))]
    offset_b = [
        (cute.crd2idx((0, 0, k), sB_layout) * op.b_dtype.width // 8) >> 4
        for k in range(cute.size(tCrB.shape[2]))
    ]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, cute.size(tCrB.shape[2]))]

    if const_expr(not is_ts):
        # smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | sm100_desc.make_smem_desc_start_addr(sA[None, None, 0].iterator))
        smem_desc_start_a_lo = const_expr(smem_desc_base_a_lo)
    else:
        smem_desc_start_a_lo = None
    # smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | sm100_desc.make_smem_desc_start_addr(sB[None, None, 0].iterator))
    smem_desc_start_b_lo = const_expr(smem_desc_base_b_lo)
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"

    if const_expr(not is_ts):
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                # Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(sA_base_addr_for_desc).ir_value(),
                Int32(sA_stage).ir_value(),
                # Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(sB_base_addr_for_desc).ir_value(),
                Int32(sB_stage).ir_value(),
                Int32(not zero_init).ir_value(),
                mask[0].ir_value(),
                mask[1].ir_value(),
                mask[2].ir_value(),
                mask[3].ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            # "mov.b32 smem_desc_a_lo, $0;\n\t"
            # f"add.u32 smem_desc_a_lo, $0, {hex(smem_desc_start_a_lo)};\n\t"
            f"mad.lo.u32 smem_desc_a_lo, $1, {hex(sA_addr_offset_for_desc)}, $0;\n\t"
            # "mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mad.lo.u32 smem_desc_b_lo, $3, {hex(sB_addr_offset_for_desc)}, $2;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $4, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t"
                )
                for k in range(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                Int32(tCrA[None, None, 0].iterator.toint()).ir_value(),
                Int32(smem_desc_start_b_lo).ir_value(),
                Int32(not zero_init).ir_value(),
                mask[0].ir_value(),
                mask[1].ir_value(),
                mask[2].ir_value(),
                mask[3].ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            f"mov.b32 tmem_a, $1;\n\t"
            f"mov.b32 smem_desc_b_lo, $2;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $3, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t"
            + "".join(
                (
                    f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t"
                )
                for k in range(1, cute.size(tCrA.shape[2]))
            )
            + "}\n",
            "r,r,r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )


@cute.jit
def gemm_ptx_precomputed(
    acc_tmem_addr: Int32,
    smem_desc_start_a: Int32,  # If TS, then this is the tmem start address for A
    smem_desc_start_b: Int32,
    idesc: int,
    smem_desc_base_a: Optional[int],
    smem_desc_base_b: int,
    tCrA_layout: cute.Layout,
    tCrB_layout: cute.Layout,
    mbar_ptr: Optional[cutlass.Pointer] = None,
    mbar_phase: Optional[Int32] = None,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
) -> None:
    # acc_tmem_addr += acc_offset
    is_ts = const_expr(smem_desc_base_a is None)
    num_k_tile = cute.size(tCrA_layout.shape[2])
    if const_expr(not is_ts):
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
    else:
        smem_desc_base_a_lo, smem_desc_a_hi = None, None
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)

    tCrA_layout = (
        tCrA_layout
        if const_expr(not is_ts)
        # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)
        # currently hard-coding the width to 16
        else cute.recast_layout(32, 16, tCrA_layout)
    )
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
    offset_a_diff = [offset_a[k] - offset_a[k - 1] for k in range(1, num_k_tile)]
    offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]
    offset_b_diff = [offset_b[k] - offset_b[k - 1] for k in range(1, num_k_tile)]

    smem_desc_start_a_lo = None
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)
        # smem_desc_start_a_lo = smem_desc_start_a
    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"

    if const_expr(not is_ts):
        assert mbar_ptr is None, "mbar_ptr must be None when a_src is not TMEM"
        llvm.inline_asm(
            None,
            [
                # acc.iterator.toint().ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value(),
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_a_lo_start, smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_a, smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            "mov.b32 smem_desc_a_lo_start, $0;\n\t"
            "mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_a_hi, {hex(smem_desc_a_hi)};\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.s32 smem_desc_a_lo, smem_desc_a_lo_start, {hex(offset_a[k])};\n\t"
                    f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            # "r,r,r",
            "r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    else:
        input_args = [
            Int32(cute.arch.make_warp_uniform(smem_desc_start_a)).ir_value(),
            Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
            Int32(not zero_init).ir_value(),
            Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
        ]
        if const_expr(mbar_ptr is not None):
            assert mbar_phase is not None, "mbar_phase must be provided when mbar_ptr is not None"
            input_args.append(mbar_ptr.toint().ir_value())
            input_args.append(Int32(mbar_phase).ir_value())
            mbar_wait_str = (
                ".reg .pred P1; \n\t"
                "LAB_WAIT: \n\t"
                "mbarrier.try_wait.parity.shared::cta.b64 P1, [$4], $5, 10000000; \n\t"
                "@P1 bra DONE; \n\t"
                "bra LAB_WAIT; \n\t"
                "DONE: \n\t"
            )
        else:
            mbar_wait_str = ""
        llvm.inline_asm(
            None,
            # [
            #     # acc.iterator.toint().ir_value(),
            #     Int32(tCrA_layout[None, None, 0].iterator.toint()).ir_value(),
            #     Int32(smem_desc_start_b_lo).ir_value(),
            #     Int32(not zero_init).ir_value(),
            # ],
            input_args,
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 tmem_a;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_b_hi;\n\t"
            ".reg .b64 smem_desc_b;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $3;\n\t"
            f"mov.b32 tmem_a, $0;\n\t"
            f"mov.b32 smem_desc_b_lo_start, $1;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            "setp.ne.b32 p, $2, 0;\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
            + "".join(
                (
                    # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
                    # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                    f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                )
                for k in range(
                    1,
                    num_k_tile if const_expr(mbar_ptr is None) else num_k_tile // 4 * 3,
                )
            )
            + mbar_wait_str
            + (
                "".join(
                    (
                        # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
                        f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                        f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                        f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
                    )
                    for k in range(num_k_tile // 4 * 3, num_k_tile)
                )
                if const_expr(mbar_ptr is not None)
                else ""
            )
            + "}\n",
            "r,r,r,r" if const_expr(mbar_ptr is None) else "r,r,r,r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )


@cute.jit
def declare_ptx_smem_desc(
    smem_desc_start_a: Int32,  # If TS, then this is the tmem start address for A
    smem_desc_base_a: Optional[int],
    tCrA_layout: cute.Layout,
    var_name_prefix: str = "smem_desc",
) -> None:
    is_ts = const_expr(smem_desc_base_a is None)
    num_k_tile = cute.size(tCrA_layout.shape[2])
    smem_desc_base_a_lo, smem_desc_a_hi = None, None
    if const_expr(not is_ts):
        smem_desc_base_a_lo, smem_desc_a_hi = i64_to_i32x2(smem_desc_base_a)
    tCrA_layout = (
        tCrA_layout
        if const_expr(not is_ts)
        # else cute.recast_layout(32, tCrA.element_type.width, tCrA_layout)
        # currently hard-coding the width to 16
        else cute.recast_layout(32, 16, tCrA_layout)
    )
    offset_a = [cute.crd2idx((0, 0, k), tCrA_layout) for k in range(num_k_tile)]
    smem_desc_start_a_lo = None
    if const_expr(not is_ts):
        smem_desc_start_a_lo = Int32(smem_desc_base_a_lo | smem_desc_start_a)
    if const_expr(not is_ts):
        llvm.inline_asm(
            None,
            [Int32(cute.arch.make_warp_uniform(smem_desc_start_a_lo)).ir_value()],
            f".reg .b32 {var_name_prefix}_lo;\n\t"
            f".reg .b64 {var_name_prefix}_<{num_k_tile}>;\n\t"
            f"mov.b64 {var_name_prefix}_0, {{$0, {hex(smem_desc_a_hi)}}};\n\t"
            + "".join(
                (
                    f"add.s32 {var_name_prefix}_lo, $0, {hex(offset_a[k])};\n\t"
                    f"mov.b64 {var_name_prefix}_{k}, {{{var_name_prefix}_lo, {hex(smem_desc_a_hi)}}};\n\t"
                )
                for k in range(1, num_k_tile)
            ),
            "r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )


@cute.jit
def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = "idesc") -> None:
    idesc = const_expr(sm100_desc.mma_op_to_idesc(op))
    llvm.inline_asm(
        None,
        [],
        f".reg .b32 {var_name};\n\t"  # noqa
        f"mov.b32 {var_name}, {hex(idesc)};\n\t",
        constraints="",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@cute.jit
def gemm_ptx_precomputed_varname(
    acc_tmem_addr: Int32,
    smem_desc_start_b: Int32,
    # idesc: int,
    smem_desc_base_b: int,
    tCrB_layout: cute.Layout,
    smem_var_name_prefix: str,
    idesc_var_name: str,
    smem_offset: int,
    zero_init: bool | Boolean = False,
    cta_group: int = 1,
) -> None:
    is_ts = False
    num_k_tile = cute.size(tCrB_layout.shape[2])
    smem_desc_base_b_lo, smem_desc_b_hi = i64_to_i32x2(smem_desc_base_b)
    offset_b = [cute.crd2idx((0, 0, k), tCrB_layout) for k in range(num_k_tile)]

    smem_desc_start_b_lo = Int32(smem_desc_base_b_lo | smem_desc_start_b)
    pred_str = "p" if isinstance(zero_init, Boolean) else "0" if zero_init else "1"

    if const_expr(not is_ts):
        llvm.inline_asm(
            None,
            [
                Int32(cute.arch.make_warp_uniform(smem_desc_start_b_lo)).ir_value(),
                Int32(not zero_init).ir_value(),
                Int32(cute.arch.make_warp_uniform(acc_tmem_addr)).ir_value(),
            ],
            "{\n\t"
            ".reg .pred leader_thread;\n\t"
            ".reg .pred p;\n\t"
            # ".reg .b32 idesc;\n\t"
            ".reg .b32 tmem_acc;\n\t"
            ".reg .b32 smem_desc_b_lo_start;\n\t"
            ".reg .b32 smem_desc_a_lo, smem_desc_b_lo;\n\t"
            ".reg .b32 smem_desc_a_hi, smem_desc_b_hi;\n\t"
            # ".reg .b64 smem_desc_b;\n\t"
            f".reg .b64 smem_desc_b_<{num_k_tile}>;\n\t"
            "elect.sync _|leader_thread, -1;\n\t"
            # f"mov.b32 idesc, {hex(idesc)};\n\t"
            # f"mov.b32 tmem_acc, {hex(acc_tmem_addr)};\n\t"
            f"mov.b32 tmem_acc, $2;\n\t"
            "mov.b32 smem_desc_b_lo_start, $0;\n\t"
            f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
            f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_0;\n\t"
            f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
            f"mov.b64 {smem_var_name_prefix}_0, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
            f"mov.b64 smem_desc_b_0, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
            + "".join(
                (
                    f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t"
                    f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
                    f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    f"mov.b64 smem_desc_b_{k}, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "setp.ne.b32 p, $1, 0;\n\t"
            # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t"
            f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t"
            + "".join(
                (
                    # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t"
                    # f"add.s32 smem_desc_a_lo, smem_desc_a_lo, {smem_offset};\n\t"
                    # f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
                    # f"mov.b64 {smem_var_name_prefix}_{k}, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
                    # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t"
                    # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t"
                    f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t"
                )
                for k in range(1, num_k_tile)
            )
            + "}\n",
            "r,r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )

================================================
FILE: flash_attn/cute/block_info.py
================================================
# Copyright (c) 2025, Jay Shah,
Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. from typing import Tuple, Optional from dataclasses import dataclass import cutlass import cutlass.cute as cute from cutlass import Int32, const_expr from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) class BlockInfo: tile_m: cutlass.Constexpr[int] tile_n: cutlass.Constexpr[int] is_causal: cutlass.Constexpr[bool] is_local: cutlass.Constexpr[bool] = False is_split_kv: cutlass.Constexpr[bool] = False window_size_left: Optional[Int32] = None window_size_right: Optional[Int32] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 @cute.jit def get_n_block_min_max( self, seqlen_info: SeqlenInfoQK, m_block: Int32, split_idx: Int32 = 0, num_splits: Int32 = 1, ) -> Tuple[Int32, Int32]: n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n) if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): m_idx_max = (m_block + 1) * self.tile_m if const_expr(self.qhead_per_kvhead_packgqa > 1): m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q n_idx_right = n_idx if const_expr(self.is_causal) else n_idx + self.window_size_right n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n)) n_block_min = 0 if const_expr(self.is_local and self.window_size_left is not None): m_idx_min = m_block * self.tile_m if const_expr(self.qhead_per_kvhead_packgqa > 1): m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q n_idx_left = n_idx - self.window_size_left n_block_min = cutlass.max(n_idx_left // self.tile_n, 0) if cutlass.const_expr(self.is_split_kv): num_n_blocks_per_split = ( Int32(0) if n_block_max <= n_block_min else (n_block_max - n_block_min + num_splits - 1) // num_splits ) n_block_min = n_block_min + split_idx * num_n_blocks_per_split n_block_max = cutlass.min(n_block_min 
+ num_n_blocks_per_split, n_block_max) return n_block_min, n_block_max @cute.jit def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int32) -> Tuple[Int32, Int32]: m_block_max = cute.ceil_div(seqlen_info.seqlen_q, self.tile_m) m_block_min = 0 if const_expr(self.is_causal or (self.is_local and self.window_size_right is not None)): n_idx_min = n_block * self.tile_n m_idx = n_idx_min + seqlen_info.seqlen_q - seqlen_info.seqlen_k m_idx_right = m_idx if const_expr(self.is_causal) else m_idx - self.window_size_right m_block_min = max(m_block_min, m_idx_right // self.tile_m) if const_expr(self.is_local and self.window_size_left is not None): n_idx_max = (n_block + 1) * self.tile_n m_idx = n_idx_max + seqlen_info.seqlen_q - seqlen_info.seqlen_k m_idx_left = m_idx + self.window_size_left m_block_max = min(m_block_max, cute.ceil_div(m_idx_left, self.tile_m)) return m_block_min, m_block_max @cute.jit def get_n_block_k_new_min_max( self, seqlen_info: SeqlenInfoQKNewK, m_block: Int32, split_idx: Int32 = 0, num_splits: Int32 = 1, ) -> Tuple[Int32, Int32]: """Get the block range for new K tokens (append KV). First computes the full n_block range via get_n_block_min_max, then maps those blocks into the new-K index space by subtracting seqlen_k_og. 
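
Illustrative sketch (not used by the kernel): the hypothetical plain-Python helpers
below mirror the integer arithmetic of get_n_block_min_max (without PackGQA) and of
this method. They assume the combined key length is seqlen_k_og + seqlen_k_new, and
they use Python floor division, which may round negative intermediates differently
from the kernel's Int32 arithmetic.

```python
def ceil_div(a: int, b: int) -> int:
    # Ceiling division for b > 0, using Python floor-division semantics.
    return (a + b - 1) // b


def n_block_range(seqlen_q, seqlen_k, m_block, tile_m, tile_n, causal=False,
                  window_left=None, window_right=None, split_idx=0, num_splits=1):
    # Hypothetical mirror of get_n_block_min_max (no PackGQA). A non-None
    # window_left / window_right plays the role of is_local with that window.
    n_block_max = ceil_div(seqlen_k, tile_n)
    if causal or window_right is not None:
        # Right edge seen by the last query row of this tile (bottom-right aligned).
        n_idx = (m_block + 1) * tile_m + seqlen_k - seqlen_q
        n_idx_right = n_idx if causal else n_idx + window_right
        n_block_max = min(n_block_max, ceil_div(n_idx_right, tile_n))
    n_block_min = 0
    if window_left is not None:
        # Left edge seen by the first query row of this tile.
        n_idx = m_block * tile_m + seqlen_k - seqlen_q
        n_block_min = max((n_idx - window_left) // tile_n, 0)
    if num_splits > 1:
        # Split-KV: each split takes a ceil-divided share of the block range.
        total = n_block_max - n_block_min
        per_split = 0 if total <= 0 else ceil_div(total, num_splits)
        n_block_min = n_block_min + split_idx * per_split
        n_block_max = min(n_block_min + per_split, n_block_max)
    return n_block_min, n_block_max


def n_block_k_new_range(seqlen_q, seqlen_k_og, seqlen_k_new, m_block, tile_m, tile_n, **kwargs):
    # Hypothetical mirror of get_n_block_k_new_min_max.
    lo, hi = n_block_range(seqlen_q, seqlen_k_og + seqlen_k_new, m_block, tile_m, tile_n, **kwargs)
    # Shift into new-K token indices and clamp to [0, seqlen_k_new].
    idx_k_new_min = max(lo * tile_n - seqlen_k_og, 0)
    idx_k_new_max = min(hi * tile_n - seqlen_k_og, seqlen_k_new)
    new_min = idx_k_new_min // tile_n
    new_max = ceil_div(idx_k_new_max, tile_n) if idx_k_new_max > idx_k_new_min else new_min
    return new_min, new_max
```

For example, with 512 queries and keys, tile 128, window_left=64 and window_right=0,
m_block=2 covers query rows 256..383, which see keys 192..383, so the range is (1, 3).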
        """
        n_block_min, n_block_max = self.get_n_block_min_max(
            seqlen_info,
            m_block,
            split_idx,
            num_splits,
        )
        idx_k_new_min = cutlass.max(n_block_min * self.tile_n - seqlen_info.seqlen_k_og, 0)
        idx_k_new_max = cutlass.min(
            n_block_max * self.tile_n - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new
        )
        n_block_new_min = idx_k_new_min // self.tile_n
        n_block_new_max = (
            cute.ceil_div(idx_k_new_max, self.tile_n)
            if idx_k_new_max > idx_k_new_min
            else n_block_new_min
        )
        return n_block_new_min, n_block_new_max

    @cute.jit
    def get_n_block_min_causal_local_mask(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        n_block_min: Int32,
    ) -> Int32:
        """If we have separate iterations with causal or local masking at the start, where do we stop"""
        m_idx_min = m_block * self.tile_m
        if const_expr(self.qhead_per_kvhead_packgqa > 1):
            m_idx_min = m_idx_min // self.qhead_per_kvhead_packgqa
        n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q
        n_idx_right = (
            n_idx
            if const_expr(not self.is_local or self.window_size_right is None)
            else n_idx + self.window_size_right
        )
        return cutlass.max(n_block_min, n_idx_right // self.tile_n)

    @cute.jit
    def get_n_block_min_before_local_mask(
        self,
        seqlen_info: SeqlenInfoQK,
        m_block: Int32,
        n_block_min: Int32,
    ) -> Int32:
        """If we have separate iterations with local masking at the end, where do we stop the non-masked iterations"""
        if const_expr(not self.is_local or self.window_size_left is None):
            return n_block_min
        else:
            m_idx_max = (m_block + 1) * self.tile_m
            if const_expr(self.qhead_per_kvhead_packgqa > 1):
                m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
            n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
            n_idx_left = n_idx - self.window_size_left
            return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n))

================================================
FILE: flash_attn/cute/block_sparse_utils.py
================================================
"""
Block-sparse runtime utilities for CUTE DSL kernels.

This module contains runtime execution functions for block-sparse attention kernels.
These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads.
"""

from typing import Callable, Optional
from functools import partial
import math
import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, const_expr
from quack import copy_utils

# Import data structures from block_sparsity
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.named_barrier import NamedBarrierBwd


# NOTE [SM100 block-sparse empty tiles: mbarrier contract]
#
# For block-sparse SM100 forward, a given (m_block, stage) Q tile can have zero active
# KV blocks (total_block_cnt == 0). In that case there is no seqlen_kv iteration, so
# the softmax warp-group has no row stats to publish.
#
# The correction warp-group seeds fully-masked-row stats and runs the usual correction
# epilogue so output/LSE have well-defined values. Both warp-groups must still perform
# the softmax<->correction mbarrier handshake so phases advance correctly across
# empty->empty and empty->non-empty tile sequences.
#
# In the no-sink case, this corresponds to the usual fully-masked-row convention:
# output is zero and LSE is -inf.
#
# Barrier contract (each is `mbar_ptr + + stage`):
#
# Producer/consumer pairs:
# - `mbar_softmax_corr_full`     : softmax arrive -> correction wait
# - `mbar_softmax_corr_empty`    : correction arrive -> softmax wait
# - `mbar_P_full_O_rescaled`     : softmax arrive (+ correction arrive) -> MMA wait
# - `mbar_P_full_2`              : softmax arrive -> MMA wait
# - `mbar_corr_epi_full_/empty`  : correction <-> epilogue (only when epilogue is separate)
#
# Empty tile (`total_block_cnt == 0`):
# - Softmax: skips the seqlen_kv softmax path entirely (no P stores, no `mbar_P_full_*`).
#   It only arrives `mbar_softmax_corr_full` once per stage as a synthetic "no work" signal.
#   At the `softmax_loop` level, softmax unconditionally waits `mbar_softmax_corr_empty`
#   before each tile (when block-sparse) to drain a prior correction arrival and keep
#   phases aligned across non-empty -> empty transitions.
# - Correction: waits `mbar_softmax_corr_full`, seeds stats + runs `correction_epilogue(scale=0)`,
#   and arrives `mbar_softmax_corr_empty` (and `mbar_corr_epi_full_/empty` when applicable).
# - No `mbar_P_full_*` barriers are arrived (no P, no MMA O); only the softmax<->correction
#   (and correction<->epilogue) handshakes advance phases.
#
# Non-empty tile:
# - Softmax: runs `softmax_step` (produces P) and uses `mbar_softmax_corr_full/empty` to
#   publish row_max (during seqlen_kv) and final row stats (once per tile), and to advance phases;
#   arrives `mbar_P_full_*` when P is stored.
# - Correction: waits `mbar_softmax_corr_full`, may rescale/release O, arrives `mbar_softmax_corr_empty`
#   to ack/advance, and arrives `mbar_P_full_O_rescaled` when MMA can proceed.
#
# Backward (SM100):
# - Empty KV tile: for a given `n_block`, `total_m_block_cnt == 0` means no Q tiles contribute.
# - Both the load and compute loops guard all pipeline work on `process_tile`, so empty tiles
#   skip producer/consumer operations entirely (no per-tile mbarrier phase handshake like forward).
# - In the `not dKV_postprocess` path, dK/dV for empty KV tiles are explicitly written as zeros
#   even when `process_tile == False` (see `flash_bwd_sm100.py` `should_zero_dKV`).


@cute.jit
def load_block_list(
    block_indices: cute.Tensor,
    block_count,
    first_block_preloaded: cutlass.Constexpr,
    kv_producer_state,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    intra_wg_overlap: cutlass.Constexpr,
):
    """Iterate over the sparse blocks and load K, V into the pipeline.

    For the intra_wg_overlap case, the K and V loads are overlapped: each iteration
    issues K for the next block before V for the current one. As a result, the last
    V load of the masked (partial) block list has to be pipelined with the loads
    for the full blocks.
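
Illustrative only: a host-side simulation (hypothetical `simulate_loads`, not part of
this module) of the issue order produced by this function followed by
finish_overlap_v_load. The stage counter here simply increments; the real
kv_producer_state also wraps around the pipeline stages and tracks a phase bit.

```python
def simulate_loads(block_indices, intra_wg_overlap):
    # Returns (events, final_stage), where each event is (tensor, n_block, stage).
    events = []
    stage = 0
    count = len(block_indices)
    if count == 0:
        return events, stage
    if not intra_wg_overlap:
        # Lists are walked in reverse; K and V of one block share a stage.
        for offset in range(count):
            n_block = block_indices[count - 1 - offset]
            events.append(("K", n_block, stage))
            events.append(("V", n_block, stage))
            stage += 1
        return events, stage
    # Overlap: K for the next block is issued before V for the previous one.
    events.append(("K", block_indices[count - 1], stage))
    for idx in range(count - 1):
        n_block_prev = block_indices[count - 1 - idx]
        n_block = block_indices[count - 2 - idx]
        stage_prev = stage
        stage += 1
        events.append(("K", n_block, stage))
        events.append(("V", n_block_prev, stage_prev))
    # finish_overlap_v_load drains the V still pending for the last block.
    events.append(("V", block_indices[0], stage))
    stage += 1
    return events, stage
```

With block_indices [2, 5, 7] and overlap enabled, the order is K7, K5, V7, K2, V5, V2,
so every V load trails the next K load by one stage.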

    Set first_block_preloaded when the caller has already issued the first K load for the list.

    Q is loaded separately on its own mbarrier before this function is called.

    Note: we iterate along the block_n indices in reverse.

    Returns:
        Updated kv_producer_state after processing the block list.

    """
    if block_count > 0:
        if const_expr(not intra_wg_overlap):
            for offset in cutlass.range(block_count):
                n_block = block_indices[block_count - 1 - offset]
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state)
                load_V(src_idx=n_block, producer_state=kv_producer_state)
                kv_producer_state.advance()
        else:
            n_block_first = block_indices[block_count - 1]
            if const_expr(not first_block_preloaded):
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block_first, producer_state=kv_producer_state)

            for idx in cutlass.range(block_count - 1, unroll=1):
                n_block_prev = block_indices[block_count - 1 - idx]
                n_block = block_indices[block_count - 2 - idx]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_prev, producer_state=kv_producer_state_prev)

    return kv_producer_state


@cute.jit
def finish_overlap_v_load(
    block_indices: cute.Tensor,
    block_count,
    load_V,
    pipeline_v,
    kv_producer_state,
):
    """Load the final V block after overlapped K/V loads."""
    if block_count > 0:
        n_block_last = block_indices[0]
        pipeline_v.producer_acquire(kv_producer_state)
        load_V(src_idx=n_block_last, producer_state=kv_producer_state)
        kv_producer_state.advance()

    return kv_producer_state


@cute.jit
def sparse_tensor_m_block(
    m_block,
    qhead_per_kvhead: cutlass.Constexpr[int],
    q_subtile_factor: cutlass.Constexpr[int],
):
    """Map packed m_block indices to block-sparse tensor indices."""
    block = m_block
    if const_expr(qhead_per_kvhead != 1):
        block = block // qhead_per_kvhead
    if const_expr(q_subtile_factor != 1):
        block = block // q_subtile_factor
    return block


@cute.jit
def produce_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_K,
    load_V,
    pipeline_k,
    pipeline_v,
    intra_wg_overlap: cutlass.Constexpr,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Iterate over the mask and full block lists for a single tile.

    Q is loaded separately on its own mbarrier before this function is called.

    The masked (partial) list may leave the last V load pending when intra-warp-group
    overlap is enabled. The first full block must consume that pending V while
    issuing its own K load on the next pipeline stage.

    In the intra-wg-overlap path, the last masked block leaves its V copy in flight
    while we advance the producer state to start the next full K. Either the full list
    overlaps that pending V load, or, if no full blocks exist, we explicitly drain it.

    Args:
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
    """

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0

    if mask_empty:
        # No masked blocks: the full list owns the initial K load.
        kv_producer_state = load_block_list(
            curr_full_block_idx,
            curr_full_block_cnt,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            intra_wg_overlap=intra_wg_overlap,
        )

        if const_expr(intra_wg_overlap) and curr_full_block_cnt > 0:
            kv_producer_state = finish_overlap_v_load(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_V,
                pipeline_v,
                kv_producer_state,
            )
    else:
        # Masked blocks present. When overlap is disabled this fully drains the list.
        kv_producer_state = load_block_list(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            first_block_preloaded=False,
            kv_producer_state=kv_producer_state,
            load_K=load_K,
            load_V=load_V,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            intra_wg_overlap=intra_wg_overlap,
        )

        if full_empty:
            if const_expr(intra_wg_overlap):
                kv_producer_state = finish_overlap_v_load(
                    curr_mask_block_idx,
                    curr_mask_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
        else:
            if const_expr(intra_wg_overlap):
                # Bridge the masked list to the full list by overlapping the pending masked V
                # with the first full K load.
                n_block_mask_last = curr_mask_block_idx[0]
                n_block_full_first = curr_full_block_idx[curr_full_block_cnt - 1]
                kv_producer_state_prev = kv_producer_state.clone()
                kv_producer_state.advance()
                pipeline_k.producer_acquire(kv_producer_state)
                load_K(src_idx=n_block_full_first, producer_state=kv_producer_state)
                pipeline_v.producer_acquire(kv_producer_state_prev)
                load_V(src_idx=n_block_mask_last, producer_state=kv_producer_state_prev)

                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    first_block_preloaded=True,
                    kv_producer_state=kv_producer_state,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    intra_wg_overlap=intra_wg_overlap,
                )

                kv_producer_state = finish_overlap_v_load(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    load_V,
                    pipeline_v,
                    kv_producer_state,
                )
            else:
                # Non-overlap path with both lists: run the full list normally.
                kv_producer_state = load_block_list(
                    curr_full_block_idx,
                    curr_full_block_cnt,
                    first_block_preloaded=False,
                    kv_producer_state=kv_producer_state,
                    load_K=load_K,
                    load_V=load_V,
                    pipeline_k=pipeline_k,
                    pipeline_v=pipeline_v,
                    intra_wg_overlap=intra_wg_overlap,
                )

    return kv_producer_state


@cute.jit
def consume_block_sparse_loads(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    seqlen,
    kv_consumer_state,
    mma_pv_fn,
    mma_one_n_block,
    process_first_half_block,
    process_last_half_block,
    mask_fn,
    score_mod_fn,
    O_should_accumulate,
    mask_mod,
    fastdiv_mods,
    intra_wg_overlap: cutlass.Constexpr,
    warp_scheduler_barrier_sync: Callable,
    warp_scheduler_barrier_arrive: Callable,
    qhead_per_kvhead: cutlass.Constexpr[int] = 1,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    """Consume the mask and full block lists for a single tile on the consumer side.

    Mirrors `produce_block_sparse_loads` so that the consumer pipeline uses
    the same sparse tensor indexing.

    Args:
        qhead_per_kvhead: Pack-GQA factor. When > 1, m_block is in packed space and
            must be converted to unpacked for sparse tensor indexing.
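
Illustrative only: plain-Python sketches with hypothetical names (not part of this
module). `sparse_m_block_py` repeats the index mapping of sparse_tensor_m_block, and
`consume_order` summarizes the order in which this function visits blocks: masked
blocks before full blocks, each list walked from its last entry to its first, with
mask_mod passed only for masked blocks and sequence-length masking requested only
for the first block of each list.

```python
def sparse_m_block_py(m_block, qhead_per_kvhead=1, q_subtile_factor=1):
    # Packed (Pack-GQA) tile index -> unpacked tile index -> sparse-tensor row.
    block = m_block
    if qhead_per_kvhead != 1:
        block = block // qhead_per_kvhead
    if q_subtile_factor != 1:
        block = block // q_subtile_factor
    return block


def consume_order(mask_block_idx, full_block_idx):
    # Returns (n_block, uses_mask_mod, mask_seqlen) tuples in visiting order.
    order = []
    for i, n_block in enumerate(reversed(mask_block_idx)):
        order.append((n_block, True, i == 0))
    for i, n_block in enumerate(reversed(full_block_idx)):
        order.append((n_block, False, i == 0))
    return order
```

With a Pack-GQA factor of 4, packed tiles 0..3 all read sparse row 0 and tiles 4..7
read row 1.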
    """

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]
    curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]

    processed_any = curr_mask_block_cnt + curr_full_block_cnt > 0

    if const_expr(not intra_wg_overlap):
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            warp_scheduler_barrier_sync()
            kv_consumer_state = mma_one_n_block(
                kv_consumer_state,
                n_block=mask_n_block,
                mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                is_first_n_block=True,
            )
            O_should_accumulate = True
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
            if curr_full_block_cnt == 0:
                warp_scheduler_barrier_arrive()

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                warp_scheduler_barrier_sync()
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_seqlen=True),
                    is_first_n_block=True,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            else:
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    is_first_n_block=False,
                )
                O_should_accumulate = True
                for i in cutlass.range(1, curr_full_block_cnt):
                    full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=full_n_block,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                        is_first_n_block=False,
                    )
                    O_should_accumulate = True
            warp_scheduler_barrier_arrive()
    else:
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            kv_consumer_state = process_first_half_block(
                n_block=mask_n_block,
                seqlen=seqlen,
                kv_consumer_state=kv_consumer_state,
                mask_fn=partial(
                    mask_fn,
                    mask_mod=mask_mod,
                    mask_seqlen=True,
                    fastdiv_mods=fastdiv_mods if cutlass.const_expr(mask_mod is not None) else None,
                ),
                score_mod_fn=score_mod_fn,
                is_first_block=True,
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=mask_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=mask_mod, mask_seqlen=False),
                )
                O_should_accumulate = True

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                kv_consumer_state = process_first_half_block(
                    n_block=full_n_block,
                    seqlen=seqlen,
                    kv_consumer_state=kv_consumer_state,
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                    score_mod_fn=score_mod_fn,
                    is_first_block=True,
                )
            else:
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=True),
                )
                O_should_accumulate = True
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                kv_consumer_state = mma_one_n_block(
                    kv_consumer_state,
                    n_block=full_n_block,
                    seqlen=seqlen,
                    mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                    mask_fn=partial(mask_fn, mask_mod=None, mask_seqlen=False),
                )
                O_should_accumulate = True

        if curr_mask_block_cnt + curr_full_block_cnt > 0:
            kv_consumer_state = process_last_half_block(
                kv_consumer_state=kv_consumer_state,
                zero_init=not O_should_accumulate,
            )
            O_should_accumulate = True

    return kv_consumer_state, O_should_accumulate, processed_any


@cute.jit
def load_block_list_sm100(
    block_indices: cute.Tensor,
    block_count,
    load_q_with_first: cutlass.Constexpr,
    q_stage: cutlass.Constexpr,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
):
    """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count)."""
    if block_count > 0:
        # First iteration: load Q alongside K if requested
        n_block_first = block_indices[block_count - 1]

        if const_expr(load_q_with_first):
            # SM100 loads Q0 and optionally Q1
            load_Q(block=0, stage=0)
            if const_expr(q_stage == 2):
                load_Q(block=1, stage=1)

        # SM100 doesn't use producer_acquire for pipeline_kv in load path
        # The pipeline barriers are handled inside load_KV
        load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()
        load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None)
        kv_producer_state.advance()

        # Remaining blocks
        for offset in cutlass.range(1, block_count):
            n_block = block_indices[block_count - 1 - offset]
            load_K(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()
            load_V(block=n_block, producer_state=kv_producer_state, page_idx=None)
            kv_producer_state.advance()

    return kv_producer_state


# SM100-specific tile processor using SM100 helpers
@cute.jit
def produce_block_sparse_loads_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    kv_producer_state,
    load_Q,
    load_K,
    load_V,
    pipeline_kv,
    q_stage: cutlass.Constexpr,
    q_producer_phase: Int32,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr,
):
    """SM100 entry point for sparse block iteration.

    SM100 uses PipelineTmaUmma, which doesn't support extra_tx_count, so this uses
    the simplified block processing of load_block_list_sm100: no extra_tx_count, and
    the pipeline_kv barriers are handled inside load_K/load_V rather than through
    explicit producer_acquire calls here.

    Args:
        m_block: Index of the M (query) tile being processed.
        qhead_per_kvhead: Pack-GQA factor (constexpr).
    """
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    mask_empty = curr_mask_block_cnt == 0
    full_empty = curr_full_block_cnt == 0

    q_phase_flipped = False

    if mask_empty:
        # No masked blocks: process full list with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_full_block_idx,
            curr_full_block_cnt,
            load_q_with_first=True,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        q_phase_flipped = not full_empty
    else:
        # Process masked blocks with Q loading
        kv_producer_state = load_block_list_sm100(
            curr_mask_block_idx,
            curr_mask_block_cnt,
            load_q_with_first=True,
            q_stage=q_stage,
            kv_producer_state=kv_producer_state,
            load_Q=load_Q,
            load_K=load_K,
            load_V=load_V,
            pipeline_kv=pipeline_kv,
        )
        q_phase_flipped = True

        if not full_empty:
            # Process full blocks without Q loading
            kv_producer_state = load_block_list_sm100(
                curr_full_block_idx,
                curr_full_block_cnt,
                load_q_with_first=False,
                q_stage=q_stage,
                kv_producer_state=kv_producer_state,
                load_Q=load_Q,
                load_K=load_K,
                load_V=load_V,
                pipeline_kv=pipeline_kv,
            )

    if q_phase_flipped:
        q_producer_phase ^= 1

    return kv_producer_state, q_producer_phase


@cute.jit
def get_total_block_count(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr,
):
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
    if const_expr(full_block_cnt is not None):
        return (
            mask_block_cnt[batch_idx, head_idx, m_block_sparse]
            + full_block_cnt[batch_idx, head_idx, m_block_sparse]
        )
    else:
        return mask_block_cnt[batch_idx, head_idx, m_block_sparse]


@cute.jit
def handle_block_sparse_empty_tile_correction_sm100(
    tidx: Int32,
    q_stage: cutlass.Constexpr,
    m_block_size: cutlass.Constexpr,
    qhead_per_kvhead,
    pack_gqa: cutlass.Constexpr,
    is_split_kv: cutlass.Constexpr,
    learnable_sink,
    mLSE,
    seqlen,
    m_block: Int32,
    head_idx: Int32,
    batch_idx: Int32,
    split_idx: Int32,
    sScale: cute.Tensor,
    stats: list,
    correction_epilogue: Callable,
    thr_mma_pv: cute.core.ThrMma,
    tOtO: cute.Tensor,
    sO: cute.Tensor,
    pipeline_sm_stats: cutlass.pipeline.PipelineAsync,
    sm_stats_barrier: cutlass.pipeline.NamedBarrier,
    pipeline_o_epi: cutlass.pipeline.PipelineAsync,
    sm_stats_consumer_phase: Int32,
    o_corr_consumer_phase: Int32,
    corr_epi_producer_phase: Int32,
    softmax_scale_log2: Float32,
    mO_cur: Optional[cute.Tensor] = None,
    gO: Optional[cute.Tensor] = None,
    gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
):
    """Handle SM100 forward block-sparse tiles with no active KV blocks.

    This path is taken when `total_block_cnt == 0`.
    The softmax warp-group still arrives `mbar_softmax_corr_full` (synthetic "no work")
    so the correction warp-group can:

    - seed fully-masked-row stats (row_sum=1; row_max=-inf when tracked) for LSE
    - run `correction_epilogue` with `scale=0` so the output tile is written as zeros
      (independent of any prior tmem contents)
    - wait on `mbar_softmax_corr_full` and arrive `mbar_softmax_corr_empty`
      (and `mbar_corr_epi_*` when applicable) so phases stay aligned across tiles

    This helper intentionally does not touch `mbar_P_full_*` since no P is produced.
    See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
    """
    LOG2_E = Float32(math.log2(math.e))
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4

    for stage in cutlass.range_constexpr(q_stage):
        row_sum_value = Float32(1.0)
        row_max_value = (
            -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None
        )
        if const_expr(learnable_sink is not None):
            sink_val = -Float32.inf
            if const_expr(not pack_gqa):
                sink_val = Float32(learnable_sink[head_idx])
            elif tidx < m_block_size:
                q_head_idx = (
                    (q_stage * m_block + stage) * m_block_size + tidx
                ) % qhead_per_kvhead + head_idx * qhead_per_kvhead
                sink_val = Float32(learnable_sink[q_head_idx])
            if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0):
                if row_max_value == -Float32.inf:
                    row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
                    row_sum_value = Float32(1.0)
                else:
                    row_sum_value = row_sum_value + cute.math.exp2(
                        sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True
                    )
        if tidx < m_block_size:
            scale_row_idx = tidx + stage * m_block_size
            sScale[scale_row_idx] = row_sum_value
            if const_expr(mLSE is not None or learnable_sink is not None):
                sScale[scale_row_idx + q_stage * m_block_size] = row_max_value
        acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value
        stats[stage] = (row_sum_value, row_max_value, acc_flag)

        # See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
        # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase)
        sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx)
        pipeline_sm_stats.consumer_release_w_index(stage)

        if const_expr(gmem_tiled_copy_O is None):
            pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase)
        correction_epilogue(
            thr_mma_pv,
            tOtO[None, None, None, stage],
            tidx,
            stage,
            m_block,
            seqlen.seqlen_q,
            Float32(0.0),  # zero scale ensures empty tile writes zeros into staged outputs
            sO[None, None, stage],
            mO_cur,
            gO[None, None, stage],
            gmem_tiled_copy_O,
        )
        if const_expr(gmem_tiled_copy_O is None):
            pipeline_o_epi.producer_commit_w_index(stage)

    sm_stats_consumer_phase ^= 1
    corr_epi_producer_phase ^= 1

    return (
        sm_stats_consumer_phase,
        o_corr_consumer_phase,
        corr_epi_producer_phase,
    )


@cute.jit
def softmax_block_sparse_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    m_block,
    softmax_step: Callable,
    mask_fn: Callable,
    mask_fn_none: Callable,
    mma_si_consumer_phase: Int32,
    si_corr_producer_phase: Int32,
    s0_s1_sequence_phase: Int32,
    pipeline_sm_stats: cutlass.pipeline.PipelineAsync,
    sm_stats_barrier: cutlass.pipeline.NamedBarrier,
    q_stage: cutlass.Constexpr,
    stage_idx: Int32,
    check_m_boundary: bool,
    qhead_per_kvhead: cutlass.Constexpr,
    q_subtile_factor: cutlass.Constexpr[int] = 1,
):
    warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
    m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors

    curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse]
    curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None]

    if const_expr(full_block_cnt is not None):
        curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block_sparse]
        curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block_sparse, None]
    else:
        curr_full_block_cnt = Int32(0)
        curr_full_block_idx = None

    total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt

    if total_block_cnt == 0:
        # See NOTE [SM100 block-sparse empty tiles: mbarrier contract].
        # pipeline_sm_stats.producer_commit_w_index(stage_idx)
        sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx)
    else:
        if curr_mask_block_cnt > 0:
            mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1]
            (
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
            ) = softmax_step(
                mma_si_consumer_phase,
                si_corr_producer_phase,
                s0_s1_sequence_phase,
                mask_n_block,
                is_first=True,
                mask_fn=partial(mask_fn, mask_seqlen=True, check_q_boundary=check_m_boundary),
            )
            for i in cutlass.range(1, curr_mask_block_cnt):
                mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    mask_n_block,
                    mask_fn=partial(mask_fn, mask_seqlen=False, check_q_boundary=check_m_boundary),
                )

        if curr_full_block_cnt > 0:
            full_n_block = curr_full_block_idx[curr_full_block_cnt - 1]
            if curr_mask_block_cnt == 0:
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=True,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=True, check_q_boundary=check_m_boundary
                    ),
                )
            else:
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    is_first=False,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )
            for i in cutlass.range(1, curr_full_block_cnt):
                full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i]
                (
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                ) = softmax_step(
                    mma_si_consumer_phase,
                    si_corr_producer_phase,
                    s0_s1_sequence_phase,
                    full_n_block,
                    mask_fn=partial(
                        mask_fn_none, mask_seqlen=False, check_q_boundary=check_m_boundary
                    ),
                )

    return (
        mma_si_consumer_phase,
        si_corr_producer_phase,
        s0_s1_sequence_phase,
        total_block_cnt == 0,
    )


# =============================================================================
# Backward-specific block-sparse helpers (SM100)
# =============================================================================
#
# In backward, iteration is transposed compared to forward:
# - Forward: outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles)
# - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles)
#
# The backward block-sparse tensors use "Q direction" indexing:
# - q_block_cnt[batch, head, n_block] → count of m_blocks to process for this KV tile
# - q_block_idx[batch, head, n_block, :] → indices of m_blocks to process
#


@cute.jit
def get_total_q_block_count_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Count total tile iterations for given n_block (KV tile) in backward."""
    q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors
    total = q_block_cnt[batch_idx, head_idx, n_block]
    if const_expr(full_block_cnt is not None):
        total = total + full_block_cnt[batch_idx, head_idx, n_block]
    return total * subtile_factor


@cute.jit
def produce_block_sparse_q_loads_bwd_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    # Pipeline states (will be returned after advancing)
    producer_state_Q_LSE,
    producer_state_dO_dPsum,
    # Pipelines
    pipeline_Q,
    pipeline_LSE,
    pipeline_dO,
    pipeline_dPsum,
    # Load functions
    load_K,
    load_V,
    load_Q,
    load_dO,
    copy_stats,
    # Global tensors for LSE/dPsum
    gLSE,
    sLSE,
    gdPsum,
    sdPsum,
    # TMA copy bytes for extra_tx_count
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    # Flags for which loads to perform
    should_load_Q: cutlass.Constexpr,
    should_load_dO: cutlass.Constexpr,
    # Subtiling factor and bounds
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """SM100 backward block sparse loading with subtiling.

    Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).
    First iteration loads K/V alongside Q/dO; subsequent iterations load only Q/dO.
    """
    (
        curr_q_cnt,
        curr_q_idx,
        curr_full_cnt,
        curr_full_idx,
        loop_count,
    ) = get_block_sparse_iteration_info_bwd(
        blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
    )

    for iter_idx in cutlass.range(loop_count, unroll=1):
        m_block, _ = get_m_block_from_iter_bwd(
            iter_idx,
            curr_q_cnt,
            curr_q_idx,
            curr_full_cnt,
            curr_full_idx,
            subtile_factor,
            m_block_max,
        )
        m_block_safe = m_block
        if m_block_max > 0:
            m_block_safe = cutlass.min(m_block, m_block_max - 1)

        if iter_idx == 0:
            # First block: load K/V alongside Q/dO
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
                load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(
                    producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
                )
                load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()
        else:
            # Subsequent blocks: just load Q/dO (K/V already loaded)
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE)
                load_Q(m_block_safe, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block_safe],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(producer_state_dO_dPsum)
                load_dO(m_block_safe, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block_safe],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()

    return producer_state_Q_LSE, producer_state_dO_dPsum


@cute.jit
def get_block_sparse_iteration_info_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Extract block-sparse iteration info for backward pass.

    Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    sparse_block_count = curr_q_cnt
    if const_expr(full_cnt is not None):
        sparse_block_count = sparse_block_count + curr_full_cnt
    total_count = sparse_block_count * subtile_factor

    return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count


@cute.jit
def get_m_block_from_iter_bwd(
    iter_idx,
    curr_q_cnt,
    curr_q_idx: cute.Tensor,
    curr_full_cnt,
    curr_full_idx: Optional[cute.Tensor],
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Derive m_block index and is_full_block flag from iteration index.

    Returns (m_block, is_full_block):
    - m_block: The actual Q-tile block index
    - is_full_block: True if this is a full block (no mask_mod needed)
    """
    sparse_iter_idx = iter_idx // subtile_factor
    subtile_offset = iter_idx % subtile_factor

    sparse_m_block = Int32(0)
    is_full_block = False
    if const_expr(curr_full_idx is not None):
        if sparse_iter_idx < curr_q_cnt:
            sparse_m_block = curr_q_idx[sparse_iter_idx]
        else:
            sparse_m_block = curr_full_idx[sparse_iter_idx - curr_q_cnt]
            is_full_block = True
    else:
        sparse_m_block = curr_q_idx[sparse_iter_idx]

    return sparse_m_block * subtile_factor + subtile_offset, is_full_block


@cute.jit
def _load_q_do_block_sm90(
    m_block,
    producer_state_Q,
    producer_state_dO,
    pipeline_Q,
    pipeline_dO,
    load_K,
    load_V,
    load_Q,
    load_dO,
    load_LSE,
    load_dPsum,
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    Q_stage_eq_dO_stage: cutlass.Constexpr,
    load_kv: bool,
):
    """Load one Q/dO block, optionally loading K/V on first iteration."""
    if load_kv:
        pipeline_Q.producer_acquire(producer_state_Q, extra_tx_count=tma_copy_bytes_K)
        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
    else:
        pipeline_Q.producer_acquire(producer_state_Q)
    load_Q(m_block, producer_state=producer_state_Q)
    load_LSE(m_block, producer_state=producer_state_Q)

    producer_state_dO_cur = (
        producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q
    )
    if load_kv:
        pipeline_dO.producer_acquire(producer_state_dO_cur, extra_tx_count=tma_copy_bytes_V)
        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
    else:
        pipeline_dO.producer_acquire(producer_state_dO_cur)
    load_dO(m_block, producer_state=producer_state_dO_cur)
    load_dPsum(m_block, producer_state=producer_state_dO_cur)

    producer_state_Q.advance()
    producer_state_dO.advance()
    return producer_state_Q, producer_state_dO


@cute.jit
def produce_block_sparse_q_loads_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    producer_state_Q,
    producer_state_dO,
    pipeline_Q,
    pipeline_dO,
    load_K,
    load_V,
    load_Q,
    load_dO,
    load_LSE,
    load_dPsum,
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    Q_stage_eq_dO_stage: cutlass.Constexpr,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
):
    """SM90 backward block sparse loading with separate partial/full loops.

    K/V are loaded with the first valid block. Iterates partial blocks first,
    then full blocks, matching consumer order.

    Returns updated (producer_state_Q, producer_state_dO).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    kv_loaded = False

    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset

        if m_block < m_block_max:
            producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
                m_block,
                producer_state_Q,
                producer_state_dO,
                pipeline_Q,
                pipeline_dO,
                load_K,
                load_V,
                load_Q,
                load_dO,
                load_LSE,
                load_dPsum,
                tma_copy_bytes_K,
                tma_copy_bytes_V,
                Q_stage_eq_dO_stage,
                load_kv=not kv_loaded,
            )
            kv_loaded = True

    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset

            if m_block < m_block_max:
                producer_state_Q, producer_state_dO = _load_q_do_block_sm90(
                    m_block,
                    producer_state_Q,
                    producer_state_dO,
                    pipeline_Q,
                    pipeline_dO,
                    load_K,
                    load_V,
                    load_Q,
                    load_dO,
                    load_LSE,
                    load_dPsum,
                    tma_copy_bytes_K,
                    tma_copy_bytes_V,
                    Q_stage_eq_dO_stage,
                    load_kv=not kv_loaded,
                )
                kv_loaded = True

    return producer_state_Q, producer_state_dO


@cute.jit
def consume_block_sparse_mma_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    consumer_state_Q,
    consumer_state_dO,
    mma_one_m_block_fn,
    mask,
    mask_mod,
    is_causal: cutlass.Constexpr,
    is_local: cutlass.Constexpr,
    thr_mma_SdP,
    score_mod_fn=None,
    score_mod_bwd_fn=None,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
    aux_tensors=None,
    fastdiv_mods=(None, None),
):
    """SM90 backward block sparse MMA consumption with separate partial/full loops.

    Partial blocks are processed first (with mask_mod applied), then full blocks
    (without mask_mod). This ensures mask_mod is only applied where needed.

    Returns updated (consumer_state_Q, consumer_state_dO).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    dKV_accumulate = False

    mask_fn_partial = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        mask_mod=mask_mod,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )

    mask_fn_full = partial(
        mask.apply_mask,
        batch_idx=batch_idx,
        head_idx=head_idx,
        n_block=n_block,
        thr_mma=thr_mma_SdP,
        mask_seqlen=True,
        mask_causal=is_causal,
        mask_local=is_local,
        aux_tensors=aux_tensors,
        fastdiv_mods=fastdiv_mods,
    )

    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset

        if m_block < m_block_max:
            consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                m_block,
                consumer_state_Q,
                consumer_state_dO,
                mask_fn=mask_fn_partial,
                score_mod_fn=score_mod_fn,
                score_mod_bwd_fn=score_mod_bwd_fn,
                dKV_accumulate=dKV_accumulate,
            )
            dKV_accumulate = True

    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset

            if m_block < m_block_max:
                consumer_state_Q, consumer_state_dO = mma_one_m_block_fn(
                    m_block,
                    consumer_state_Q,
                    consumer_state_dO,
                    mask_fn=mask_fn_full,
                    score_mod_fn=score_mod_fn,
                    score_mod_bwd_fn=score_mod_bwd_fn,
                    dKV_accumulate=dKV_accumulate,
                )
                dKV_accumulate = True

    return consumer_state_Q, consumer_state_dO


@cute.jit
def _store_one_dQaccum_sm90(
    m_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    num_mma_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """Store dQaccum for a single m_block."""
    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
        cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
        cute.arch.barrier_arrive(
            barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )
    for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
        cute.arch.barrier(
            barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
            number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
        )
        with cute.arch.elect_one():
            copy_utils.cpasync_reduce_bulk_add_f32(
                sdQaccum[None, warp_group_idx].iterator,
                gdQaccum[(None, warp_group_idx), m_block].iterator,
                tma_copy_bytes_dQ,
            )
        cute.arch.cp_async_bulk_commit_group()


@cute.jit
def dQaccum_store_block_sparse_bwd_sm90(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    sdQaccum: cute.Tensor,
    gdQaccum: cute.Tensor,
    subtile_factor: cutlass.Constexpr,
    m_block_max: int,
    num_mma_warp_groups: cutlass.Constexpr,
    num_threads_per_warp_group: cutlass.Constexpr,
    tma_copy_bytes_dQ,
):
    """SM90 backward block sparse dQaccum store with separate partial/full loops.

    Iterates partial blocks first, then full blocks, matching producer/consumer order.
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    for iter_idx in cutlass.range(curr_q_cnt * subtile_factor, unroll=1):
        sparse_idx = iter_idx // subtile_factor
        subtile_offset = iter_idx % subtile_factor
        m_block = curr_q_idx[sparse_idx] * subtile_factor + subtile_offset
        if m_block < m_block_max:
            _store_one_dQaccum_sm90(
                m_block,
                sdQaccum,
                gdQaccum,
                num_mma_warp_groups,
                num_threads_per_warp_group,
                tma_copy_bytes_dQ,
            )

    if const_expr(full_cnt is not None):
        for iter_idx in cutlass.range(curr_full_cnt * subtile_factor, unroll=1):
            sparse_idx = iter_idx // subtile_factor
            subtile_offset = iter_idx % subtile_factor
            m_block = curr_full_idx[sparse_idx] * subtile_factor + subtile_offset
            if m_block < m_block_max:
                _store_one_dQaccum_sm90(
                    m_block,
                    sdQaccum,
                    gdQaccum,
                    num_mma_warp_groups,
                    num_threads_per_warp_group,
                    tma_copy_bytes_dQ,
                )


================================================
FILE: flash_attn/cute/block_sparsity.py
================================================
"""
Block-sparsity utilities for FlexAttention
"""

from typing import Callable, NamedTuple, Tuple

import cutlass.cute as cute
import torch

from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor


def ceildiv(a: int, b: int) -> int:
    return (a + b - 1) // b


class BlockSparseTensors(NamedTuple):
    mask_block_cnt: cute.Tensor
    mask_block_idx: cute.Tensor
    full_block_cnt: cute.Tensor | None
    full_block_idx: cute.Tensor | None

    def __new_from_mlir_values__(self, values):
        if len(values) == 2:
            values = (*values, None, None)
        return BlockSparseTensors(*values)


class BlockSparseTensorsTorch(NamedTuple):
    mask_block_cnt: torch.Tensor
    mask_block_idx: torch.Tensor
    full_block_cnt: torch.Tensor | None = None
    full_block_idx: torch.Tensor | None = None
    block_size: tuple[int, int] | None = None


def _expand_sparsity_tensor(
    tensor: torch.Tensor,
    expected_shape: Tuple[int, ...],
    tensor_name: str,
    context: str | None,
    hint: str | Callable[[], str] | None,
) -> torch.Tensor:
    """Check if we need to expand the tensor to expected shape, and do so if possible."""
    needs_expand = tensor.shape != expected_shape
    if not needs_expand:
        return tensor
    can_expand = all(map(lambda cur, tgt: cur == tgt or cur == 1, tensor.shape, expected_shape))
    if not can_expand:
        context_clause = f" ({context})" if context else ""
        resolved_hint = hint() if callable(hint) else hint
        hint_clause = f" Hint: {resolved_hint}" if resolved_hint else ""
        raise ValueError(
            f"{tensor_name}{context_clause} with shape {tensor.shape} cannot be expanded to expected shape {expected_shape}."
            f"{hint_clause}"
        )
    return tensor.expand(*expected_shape)


def _check_and_expand_block(
    name: str,
    cnt: torch.Tensor | None,
    idx: torch.Tensor | None,
    expected_count_shape: Tuple[int, int, int],
    expected_index_shape: Tuple[int, int, int, int],
    context: str | None,
    hint: str | Callable[[], str] | None,
) -> Tuple[torch.Tensor | None, torch.Tensor | None]:
    if (cnt is None) != (idx is None):
        raise ValueError(
            f"{name}_block_cnt and {name}_block_idx must both be provided or both be None"
        )
    if cnt is None or idx is None:
        return None, None
    if cnt.dtype != torch.int32 or idx.dtype != torch.int32:
        raise ValueError(f"{name}_block tensors must have dtype torch.int32")
    if cnt.device != idx.device:
        raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device")
    if not cnt.is_cuda or not idx.is_cuda:
        raise ValueError(f"{name}_block tensors must live on CUDA")
    expanded_cnt = _expand_sparsity_tensor(
        cnt, expected_count_shape, f"{name}_block_cnt", context, hint
    )
    expanded_idx = _expand_sparsity_tensor(
        idx, expected_index_shape, f"{name}_block_idx", context, hint
    )
    return expanded_cnt, expanded_idx


def get_block_sparse_expected_shapes(
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    q_stage: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
    """Return (expected_count_shape, expected_index_shape) for block sparse normalization."""
    m_block_size_effective = q_stage * m_block_size
    expected_m_blocks = ceildiv(seqlen_q, m_block_size_effective)
    expected_n_blocks = ceildiv(seqlen_k, n_block_size)
    expected_count_shape = (batch_size, num_head, expected_m_blocks)
    expected_index_shape = (batch_size, num_head, expected_m_blocks, expected_n_blocks)
    return expected_count_shape, expected_index_shape


def infer_block_sparse_expected_shapes(
    tensors: BlockSparseTensorsTorch,
    *,
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    q_stage: int,
    context: str,
    sparse_block_size_q: int | None = None,
    sparse_block_size_kv: int | None = None,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int], int]:
    """Infer shapes and scaling for block-sparse tensors.

    Expectations:
    - mask_block_cnt is (B, H, M) and mask_block_idx is (B, H, M, N).
    - Batch/head dims may be 1 for broadcast, or match the requested sizes.
    - sparse_block_size_kv must match tile_n.
    - sparse_block_size_q must be a multiple of q_stage * tile_m.
    - If sparse_block_size_q is omitted and seqlen_q/num_m_blocks is ambiguous,
      the caller must provide block_size to disambiguate.
      TODO will make this required in a future PR.
    """
    base_m_block = q_stage * m_block_size
    base_n_block = n_block_size
    if sparse_block_size_kv is None:
        sparse_block_size_kv = base_n_block
    if sparse_block_size_kv != base_n_block:
        raise ValueError(f"Block sparse tensors{context} require BLOCK_SIZE_KV={base_n_block}.")
    if tensors.mask_block_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
    num_m_blocks = tensors.mask_block_idx.shape[2]

    if sparse_block_size_q is None:
        min_block_size = ceildiv(seqlen_q, num_m_blocks)
        if num_m_blocks == 1:
            max_block_size = seqlen_q
        else:
            max_block_size = (seqlen_q - 1) // (num_m_blocks - 1)
        if max_block_size != min_block_size and base_m_block != 1:
            raise ValueError(
                f"Block sparse tensors{context} require explicit sparse_block_size[0] "
                f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}."
            )
        sparse_block_size_q = min_block_size

    if sparse_block_size_q % base_m_block != 0:
        raise ValueError(
            f"Block sparse tensors{context} have block size {sparse_block_size_q}, "
            f"which must be a multiple of {base_m_block}."
        )

    expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q)
    expected_n_blocks = ceildiv(seqlen_k, sparse_block_size_kv)
    q_subtile_factor = sparse_block_size_q // base_m_block
    expected_count_shape = (batch_size, num_head, expected_m_blocks)
    expected_index_shape = (batch_size, num_head, expected_m_blocks, expected_n_blocks)

    mask_block_cnt = tensors.mask_block_cnt
    mask_block_idx = tensors.mask_block_idx
    if mask_block_cnt is None or mask_block_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")
    if mask_block_cnt.ndim != 3 or mask_block_idx.ndim != 4:
        raise ValueError(
            f"Block sparse tensors{context} must have shapes (B, H, M) and (B, H, M, N)."
        )

    for dim_name, cur, tgt in (
        ("batch", mask_block_cnt.shape[0], expected_count_shape[0]),
        ("head", mask_block_cnt.shape[1], expected_count_shape[1]),
    ):
        if cur != tgt and cur != 1:
            raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
    for dim_name, cur, tgt in (
        ("batch", mask_block_idx.shape[0], expected_index_shape[0]),
        ("head", mask_block_idx.shape[1], expected_index_shape[1]),
    ):
        if cur != tgt and cur != 1:
            raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
    if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
        raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
    if mask_block_idx.shape[3] != expected_n_blocks:
        raise ValueError(
            f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}."
        )
    if expected_m_blocks != num_m_blocks:
        raise ValueError(
            f"Block sparse tensors{context} m-block dimension {num_m_blocks} does not match "
            f"sparse_block_size_q={sparse_block_size_q}. "
            f"Set BlockSparseTensorsTorch.block_size to match the BlockMask BLOCK_SIZE."
        )

    return expected_count_shape, expected_index_shape, q_subtile_factor


def get_block_sparse_expected_shapes_bwd(
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    m_block_size: int,
    n_block_size: int,
    subtile_factor: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int, int]]:
    """Return (expected_count_shape, expected_index_shape) for backward block sparse normalization.

    Backward uses Q-direction indexing (transposed from forward), where shapes are
    indexed by N-blocks first, then M-blocks. The sparse_block_size_q is determined
    by subtile_factor * m_block_size.
    """
    sparse_block_size_q = subtile_factor * m_block_size
    expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q)
    expected_n_blocks = ceildiv(seqlen_k, n_block_size)
    expected_count_shape = (batch_size, num_head, expected_n_blocks)
    expected_index_shape = (batch_size, num_head, expected_n_blocks, expected_m_blocks)
    return expected_count_shape, expected_index_shape


def normalize_block_sparse_tensors(
    tensors: BlockSparseTensorsTorch,
    *,
    expected_count_shape: Tuple[int, int, int],
    expected_index_shape: Tuple[int, int, int, int],
    context: str | None = None,
    hint: str | Callable[[], str] | None = None,
) -> BlockSparseTensorsTorch:
    if tensors.mask_block_cnt is None or tensors.mask_block_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")

    mask_cnt, mask_idx = _check_and_expand_block(
        "mask",
        tensors.mask_block_cnt,
        tensors.mask_block_idx,
        expected_count_shape,
        expected_index_shape,
        context,
        hint,
    )
    if mask_cnt is None or mask_idx is None:
        raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.")

    full_cnt, full_idx = _check_and_expand_block(
        "full",
        tensors.full_block_cnt,
        tensors.full_block_idx,
        expected_count_shape,
        expected_index_shape,
        context,
        hint,
    )
    if full_cnt is not None and mask_cnt.device != full_cnt.device:
        raise ValueError("All block sparse tensors must be on the same device")

    return BlockSparseTensorsTorch(
        mask_block_cnt=mask_cnt,
        mask_block_idx=mask_idx,
        full_block_cnt=full_cnt,
        full_block_idx=full_idx,
        block_size=tensors.block_size,
    )


def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
    return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt))


def get_block_sparse_broadcast_pattern(
    tensors: BlockSparseTensorsTorch,
) -> Tuple[Tuple[bool, ...], ...] | None:
    """Return broadcast pattern for block sparse tensors by checking actual strides.
    Returns a tuple of broadcast patterns (one per tensor) where each pattern
    is a tuple of bools indicating which dims have stride=0.
    This is used in compile keys to ensure kernels are recompiled when
    broadcast patterns change, since CuTe's mark_layout_dynamic() keeps
    stride=0 as static.

    The tensors should already be expanded/normalized before calling this function.

    Returns None if block sparsity is not enabled.
    """
    if not is_block_sparsity_enabled(tensors):
        return None

    patterns = []
    for tensor in (
        tensors.mask_block_cnt,
        tensors.mask_block_idx,
        tensors.full_block_cnt,
        tensors.full_block_idx,
    ):
        if tensor is not None:
            patterns.append(get_broadcast_dims(tensor))
        else:
            patterns.append(None)
    return tuple(patterns)


def normalize_block_sparse_config(
    tensors: BlockSparseTensorsTorch,
    *,
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    block_size: tuple[int, int],
    q_stage: int,
) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]:
    m_block_size, n_block_size = block_size
    if tensors.block_size is None:
        sparse_block_size_q, sparse_block_size_kv = q_stage * m_block_size, n_block_size
    else:
        sparse_block_size_q, sparse_block_size_kv = tensors.block_size
    if sparse_block_size_kv != n_block_size:
        raise ValueError(
            f"Block sparsity requires sparse_block_size[1]={n_block_size} to match tile_n."
        )

    expected_count_shape, expected_index_shape, q_subtile_factor = (
        infer_block_sparse_expected_shapes(
            tensors,
            batch_size=batch_size,
            num_head=num_head,
            seqlen_q=seqlen_q,
            seqlen_k=seqlen_k,
            m_block_size=m_block_size,
            n_block_size=n_block_size,
            q_stage=q_stage,
            context="forward",
            sparse_block_size_q=sparse_block_size_q,
            sparse_block_size_kv=sparse_block_size_kv,
        )
    )
    normalized_tensors = normalize_block_sparse_tensors(
        tensors,
        expected_count_shape=expected_count_shape,
        expected_index_shape=expected_index_shape,
    )
    return (
        normalized_tensors,
        get_block_sparse_broadcast_pattern(normalized_tensors),
        q_subtile_factor,
    )


def normalize_block_sparse_config_bwd(
    tensors: BlockSparseTensorsTorch,
    *,
    batch_size: int,
    num_head: int,
    seqlen_q: int,
    seqlen_k: int,
    block_size: tuple[int, int],
    subtile_factor: int,
) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None]:
    m_block_size, n_block_size = block_size
    if tensors.block_size is None:
        sparse_block_size_q, sparse_block_size_kv = subtile_factor * m_block_size, n_block_size
    else:
        sparse_block_size_q, sparse_block_size_kv = tensors.block_size
    if sparse_block_size_q != subtile_factor * m_block_size:
        raise ValueError(
            f"Block sparsity expects sparse_block_size_q={subtile_factor * m_block_size} "
            f"for subtile_factor={subtile_factor}."
        )
    if sparse_block_size_kv != n_block_size:
        raise ValueError(
            f"Block sparsity expects sparse_block_size[1]={n_block_size} to match tile_n."
        )

    expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
        batch_size,
        num_head,
        seqlen_q,
        seqlen_k,
        m_block_size,
        n_block_size,
        subtile_factor,
    )
    normalized_tensors = normalize_block_sparse_tensors(
        tensors,
        expected_count_shape=expected_count_shape,
        expected_index_shape=expected_index_shape,
        context="_flash_attn_bwd",
        hint=lambda: (
            f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, "
            f"and optionally full_q_cnt/full_q_idx). Regenerate the backward BlockMask with "
            f"BLOCK_SIZE=({subtile_factor * m_block_size}, {n_block_size})."
        ),
    )
    return normalized_tensors, get_block_sparse_broadcast_pattern(normalized_tensors)


def to_cute_block_sparse_tensors(
    tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True
) -> BlockSparseTensors | None:
    """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
    if not is_block_sparsity_enabled(tensors):
        return None
    (
        mask_block_cnt,
        mask_block_idx,
        full_block_cnt,
        full_block_idx,
        *_,
    ) = tensors

    (
        mask_block_cnt_tensor,
        mask_block_idx_tensor,
    ) = [
        to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)
        for t in (mask_block_cnt, mask_block_idx)
    ]
    (
        full_block_cnt_tensor,
        full_block_idx_tensor,
    ) = [
        to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi)
        if t is not None
        else None
        for t in (full_block_cnt, full_block_idx)
    ]

    return BlockSparseTensors(
        mask_block_cnt_tensor,
        mask_block_idx_tensor,
        full_block_cnt_tensor,
        full_block_idx_tensor,
    )


def fast_sampling(mask_mod):
    """Convenience decorator to mark mask_mod as safe for 5-point fast sampling"""
    mask_mod.use_fast_sampling = True
    return mask_mod


================================================
FILE: flash_attn/cute/cache_utils.py
================================================
# Manage Ahead-of-Time (AOT) compiled kernels
import fcntl
import hashlib
import logging
import os
import pickle
import sys
import tempfile
import time
from functools import lru_cache
from getpass import getuser
from pathlib import Path
from typing import Hashable, TypeAlias

import ctypes

import cutlass
import cutlass.cute as cute
import tvm_ffi
from cutlass.cutlass_dsl import JitCompiledFunction

# Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols
# (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen.
# Upstream cute.runtime.load_module loads these without RTLD_GLOBAL, which causes # "undefined symbol" errors when loading cached kernels from disk. for _lib_path in cute.runtime.find_runtime_libraries(enable_tvm_ffi=False): if Path(_lib_path).exists(): ctypes.CDLL(_lib_path, mode=ctypes.RTLD_GLOBAL) CompileKeyType: TypeAlias = tuple[Hashable, ...] CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function logger = logging.getLogger(__name__) _handler = logging.StreamHandler() _handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")) logger.addHandler(_handler) logger.setLevel(logging.DEBUG) # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1" # Customize cache dir via `FLASH_ATTENTION_CUTE_DSL_CACHE_DIR`; the default is # `<tempfile.gettempdir()>/<user>/flash_attention_cute_dsl_cache`, usually `/tmp/${USER}/flash_attention_cute_dsl_cache` CUTE_DSL_CACHE_DIR: str | None = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_DIR", None) def get_cache_path() -> Path: if CUTE_DSL_CACHE_DIR is not None: cache_dir = Path(CUTE_DSL_CACHE_DIR) else: cache_dir = Path(tempfile.gettempdir()) / getuser() / "flash_attention_cute_dsl_cache" cache_dir.mkdir(parents=True, exist_ok=True) return cache_dir @lru_cache(maxsize=1) def _compute_source_fingerprint() -> str: """ Hash all CuTe Python sources plus runtime ABI stamps into a fingerprint (a SHA-256 hex digest). The fingerprint changes whenever: - Any .py file under flash_attn/cute is added, removed, renamed, or modified. - The Python minor version changes (e.g. 3.13 -> 3.14). - The cutlass or tvm_ffi package version changes. Computed once per process and cached.
""" cute_root = Path(__file__).resolve().parent h = hashlib.sha256() h.update(f"py{sys.version_info.major}.{sys.version_info.minor}".encode()) h.update(f"cutlass={cutlass.__version__}".encode()) h.update(f"tvm_ffi={tvm_ffi.__version__}".encode()) for src in sorted(cute_root.rglob("*.py")): if not src.is_file(): continue h.update(src.relative_to(cute_root).as_posix().encode()) content = src.read_bytes() h.update(len(content).to_bytes(8, "little")) h.update(content) return h.hexdigest() class FileLock: """Context manager for advisory file locks using fcntl.flock. Supports exclusive (write) and shared (read) locks. Blocks, retrying a non-blocking flock every 0.1 s, until the lock is acquired or the timeout is reached. Usage: with FileLock(lock_path, exclusive=True, timeout=15, label="abc"): # do work under lock """ def __init__( self, lock_path: Path, exclusive: bool, timeout: float = 15, label: str = "", ): """ Args: lock_path: Path to the lock file on disk. exclusive: True for exclusive (write) lock, False for shared (read) lock. timeout: Max seconds to wait for lock acquisition before raising RuntimeError. label: Optional human-readable label for error messages.
""" self.lock_path: Path = lock_path self.exclusive: bool = exclusive self.timeout: float = timeout self.label: str = label self._fd: int | None = None @property def _lock_label(self) -> str: kind = "exclusive" if self.exclusive else "shared" return f"{kind} {self.label}" if self.label else kind def __enter__(self) -> "FileLock": open_flags = os.O_WRONLY | os.O_CREAT if self.exclusive else os.O_RDONLY | os.O_CREAT lock_type = fcntl.LOCK_EX if self.exclusive else fcntl.LOCK_SH self._fd = os.open(str(self.lock_path), open_flags) deadline = time.monotonic() + self.timeout acquired = False while time.monotonic() < deadline: try: fcntl.flock(self._fd, lock_type | fcntl.LOCK_NB) acquired = True break except OSError: time.sleep(0.1) if not acquired: os.close(self._fd) self._fd = None raise RuntimeError( f"Timed out after {self.timeout}s waiting for " f"{self._lock_label} lock: {self.lock_path}" ) return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: if self._fd is not None: fcntl.flock(self._fd, fcntl.LOCK_UN) os.close(self._fd) self._fd = None class JITCache: """ In-memory cache for compiled functions. """ def __init__(self): self.cache: dict[CompileKeyType, CallableFunction] = {} def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None: self.cache[key] = fn def __getitem__(self, key: CompileKeyType) -> CallableFunction: return self.cache[key] def __contains__(self, key: CompileKeyType) -> bool: return key in self.cache def clear(self) -> None: """ Clear the in-memory cache of compiled functions. """ self.cache.clear() class JITPersistentCache(JITCache): """ In-memory cache for compiled functions, also backed by persistent on-disk storage.
Uses CuTe DSL ahead-of-time (AOT) compilation; only enable_tvm_ffi=True is supported. """ EXPORT_FUNCTION_PREFIX = "func" LOCK_TIMEOUT_SECONDS = 15 def __init__(self, cache_path: Path): super().__init__() cache_path.mkdir(parents=True, exist_ok=True) self.cache_path: Path = cache_path def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) -> None: JITCache.__setitem__(self, key, fn) self._try_export_to_storage(key, fn) def __getitem__(self, key: CompileKeyType) -> CallableFunction: # Call __contains__ so the in-memory cache is populated from persistent storage if possible self.__contains__(key) return JITCache.__getitem__(self, key) def __contains__(self, key: CompileKeyType) -> bool: # Checks in-memory cache first, then tries loading from storage. # When returning True, guarantees the in-memory cache is populated. if JITCache.__contains__(self, key): return True return self._try_load_from_storage(key) def _try_load_from_storage(self, key: CompileKeyType) -> bool: """ Try to load a function from persistent storage into the in-memory cache. Returns True if loaded successfully, False if not found on disk. Holds a shared lock while loading so that no concurrent export can write the same object file.
""" sha256_hex = self._key_to_hash(key) obj_path = self.cache_path / f"{sha256_hex}.o" with FileLock( self._lock_path(sha256_hex), exclusive=False, timeout=self.LOCK_TIMEOUT_SECONDS, label=sha256_hex, ): if obj_path.exists(): logger.debug("Loading compiled function from disk: %s", obj_path) m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True) fn = getattr(m, self.EXPORT_FUNCTION_PREFIX) JITCache.__setitem__(self, key, fn) return True else: logger.debug("Cache miss on disk for key hash %s", sha256_hex) return False def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None: """Export a compiled function to persistent storage under an exclusive lock.""" sha256_hex = self._key_to_hash(key) with FileLock( self._lock_path(sha256_hex), exclusive=True, timeout=self.LOCK_TIMEOUT_SECONDS, label=sha256_hex, ): obj_path = self.cache_path / f"{sha256_hex}.o" if obj_path.exists(): # Another process already exported. logger.debug("Skipping export, already on disk: %s", obj_path) return logger.debug("Exporting compiled function to disk: %s", obj_path) fn.export_to_c( object_file_path=str(obj_path), function_name=self.EXPORT_FUNCTION_PREFIX, ) logger.debug("Successfully exported compiled function to disk: %s", obj_path) def _key_to_hash(self, key: CompileKeyType) -> str: return hashlib.sha256(pickle.dumps(key)).hexdigest() def _lock_path(self, sha256_hex: str) -> Path: return self.cache_path / f"{sha256_hex}.lock" def clear(self) -> None: """ Clear the in-memory cache and also purge the persistent compilation cache on disk. """ logger.debug("Clearing persistent cache at %s", self.cache_path) super().clear() for child in self.cache_path.iterdir(): child.unlink() def get_jit_cache(name: str | None = None) -> JITCache: """ JIT cache factory. `name` is an optional identifier used to create a subdirectory for this cache.
When persistent caching is enabled, artifacts are namespaced under a source fingerprint directory so that code or dependency changes automatically invalidate stale entries. """ if CUTE_DSL_CACHE_ENABLED: path = get_cache_path() / _compute_source_fingerprint() if name: path = path / name logger.debug("Creating persistent JIT cache at %s", path) return JITPersistentCache(path) else: logger.debug("Persistent cache disabled, using in-memory JIT cache") return JITCache() ================================================ FILE: flash_attn/cute/compute_block_sparsity.py ================================================ from functools import partial from typing import Callable, Optional, Tuple import cutlass import cutlass.cute as cute import torch from cutlass import Boolean, Int8, Int32, const_expr from flash_attn.cute.block_sparsity import ( BlockSparseTensors, BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar from flash_attn.cute.seqlen_info import SeqlenInfoQK class BlockSparsityKernel: """Block sparsity kernel for FlexAttention. This kernel evaluates `mask_mod` at every (q, kv) position of each tile to classify each n block as full (nothing masked), partially masked, or fully masked (skipped). Writes block counts and indices to a BlockSparseTensors object. When use_fast_sampling=True, it evaluates only 5 points per tile (4 corners + center), which is much faster but only correct for masks where those 5 samples determine the tile's classification.
TODO: - optimize mask_mod evaluation - varlen support - transposed tensors for bwd pass """ def __init__( self, mask_mod: Callable, tile_mn: Tuple[int, int], compute_full_blocks: bool = True, use_aux_tensors: bool = False, use_fast_sampling: bool = False, ): self.mask_mod = mask_mod self.tile_mn = tile_mn self.compute_full_blocks = compute_full_blocks self.use_aux_tensors = use_aux_tensors self.use_fast_sampling = use_fast_sampling @cute.jit def __call__( self, blocksparse_tensors: BlockSparseTensors, seqlen_q: Int32, seqlen_k: Int32, aux_tensors: Optional[list] = None, ): self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors if const_expr(self.compute_full_blocks): assert self.full_cnt is not None and self.full_idx is not None, ( "full block tensors must be provided when computing full blocks" ) batch_size, num_heads, num_m_blocks, num_n_blocks = self.mask_idx.shape # launch 1 CTA per m block grid = [num_m_blocks, num_heads, batch_size] if const_expr(self.use_fast_sampling): num_threads = 5 self.num_warps = 1 else: num_threads = self.tile_mn[0] self.num_warps = (num_threads + 32 - 1) // 32 self.kernel( self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx, num_n_blocks, seqlen_q, seqlen_k, aux_tensors, ).launch(grid=grid, block=[num_threads, 1, 1]) @cute.kernel def kernel( self, mask_cnt: cute.Tensor, mask_idx: cute.Tensor, full_cnt: cute.Tensor, full_idx: cute.Tensor, num_n_blocks: Int32, seqlen_q: Int32, seqlen_k: Int32, aux_tensors: Optional[list] = None, ): tidx, _, _ = cute.arch.thread_idx() warp_idx = cute.arch.warp_idx() lane_id = cute.arch.lane_idx() m_block, head_idx, batch_idx = cute.arch.block_idx() ssa = partial(scalar_to_ssa, dtype=Int32) seqlen = SeqlenInfoQK.create( batch_idx, seqlen_q, seqlen_k, mCuSeqlensQ=None, mCuSeqlensK=None, mSeqUsedQ=None, mSeqUsedK=None, ) @cute.struct class SharedStorage: reduction_buffer_smem: cute.struct.Align[ cute.struct.MemRange[cutlass.Int8, 2 * self.num_warps], 1024 ] smem = 
cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage, 16) reduction_buffer = storage.reduction_buffer_smem.get_tensor( cute.make_layout((self.num_warps, 2)) ) num_mask_blocks = Int32(0) num_full_blocks = Int32(0) for n_block in cutlass.range(num_n_blocks, unroll_full=True): m_base = m_block * self.tile_mn[0] n_base = n_block * self.tile_mn[1] if const_expr(self.use_fast_sampling): # Fast path: 5-point sampling (4 corners + center) # Clamps OOB indices to nearest in bounds. thread_result = Boolean(False) thread_is_valid = Boolean(False) q_idx = Int32(0) kv_idx = Int32(0) if tidx == 0: # Top-left corner (0, 0); always in bounds q_idx = m_base kv_idx = n_base elif tidx == 1: # Top-right corner q_idx = m_base kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1) elif tidx == 2: # Bottom-left corner q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1) kv_idx = n_base elif tidx == 3: # Bottom-right corner q_idx = cutlass.min(m_base + self.tile_mn[0] - 1, seqlen_q - 1) kv_idx = cutlass.min(n_base + self.tile_mn[1] - 1, seqlen_k - 1) elif tidx == 4: # Center point q_idx = m_base + (cutlass.min(seqlen_q - m_base, self.tile_mn[0])) // 2 kv_idx = n_base + (cutlass.min(seqlen_k - n_base, self.tile_mn[1])) // 2 else: thread_is_valid = Boolean(False) # Check bounds and determine if this thread has a valid index pair if tidx < 5 and q_idx < seqlen_q and kv_idx < seqlen_k: thread_is_valid = Boolean(True) q_idx_ssa = ssa(q_idx) kv_idx_ssa = ssa(kv_idx) thread_result = ssa_to_scalar( self.mask_mod( ssa(batch_idx), ssa(head_idx), q_idx_ssa, kv_idx_ssa, seqlen, aux_tensors, ) ) else: thread_is_valid = Boolean(False) # Use vote_any_sync to see if any valid thread found unmasked or masked # Only count results from threads that checked valid indices has_unmasked = cute.arch.vote_any_sync(thread_result & thread_is_valid) has_masked = cute.arch.vote_any_sync((Boolean(not thread_result)) & thread_is_valid) else: # Full path: check all elements in the 
block # Track if this thread's row has any masked or unmasked elements thread_has_unmasked = Boolean(False) thread_has_masked = Boolean(False) thread_is_valid = Boolean(False) # Each thread handles 1 row q_idx = m_base + tidx kv_idx = Int32(0) if tidx < self.tile_mn[0] and q_idx < seqlen_q: thread_is_valid = Boolean(True) q_idx_ssa = ssa(q_idx) # Loop over all columns in this row for c in cutlass.range(self.tile_mn[1], unroll_full=True): kv_idx = n_base + c kv_idx_ssa = ssa(kv_idx) # Only check elements within valid sequence bounds if kv_idx < seqlen_k: # Direct scalar call mask_val = ssa_to_scalar( self.mask_mod( ssa(batch_idx), ssa(head_idx), q_idx_ssa, kv_idx_ssa, seqlen, aux_tensors, ) ) # Update tracking flags if mask_val: thread_has_unmasked = Boolean(True) else: thread_has_masked = Boolean(True) # Block-level reduction to combine results across all threads # Only count votes from threads that checked valid indices warp_has_unmasked_mask = cute.arch.vote_any_sync( thread_has_unmasked & thread_is_valid ) warp_has_masked_mask = cute.arch.vote_any_sync(thread_has_masked & thread_is_valid) # lane 0 writes the ballot mask to shared memory lane_id = tidx % 32 if lane_id == 0: # Store as Int8 reduction_buffer[warp_idx, 0] = Int8(1) if warp_has_unmasked_mask else Int8(0) reduction_buffer[warp_idx, 1] = Int8(1) if warp_has_masked_mask else Int8(0) cute.arch.sync_threads() # Thread 0 ORs all warp results together has_unmasked = Boolean(False) has_masked = Boolean(False) if tidx == 0: for w in cutlass.range(self.num_warps): if reduction_buffer[w, 0]: has_unmasked = Boolean(True) if reduction_buffer[w, 1]: has_masked = Boolean(True) # Only thread 0 updates the output arrays (common to both paths) if tidx == 0: # Block classification based on what we found: # - If has_masked and has_unmasked: partial block (needs masking) # - If only has_unmasked: full block (no masking needed) # - If only has_masked: skip this block entirely is_partial = Boolean(has_masked and 
has_unmasked) is_full = Boolean(has_unmasked and (not has_masked)) if is_partial: mask_idx[batch_idx, head_idx, m_block, num_mask_blocks] = n_block num_mask_blocks += 1 elif is_full and const_expr(self.compute_full_blocks): full_idx[batch_idx, head_idx, m_block, num_full_blocks] = n_block num_full_blocks += 1 # Only thread 0 writes back the counts if tidx == 0: mask_cnt[batch_idx, head_idx, m_block] = num_mask_blocks if const_expr(self.compute_full_blocks): full_cnt[batch_idx, head_idx, m_block] = num_full_blocks def compute_block_sparsity( tile_m, tile_n, batch_size, num_heads, seqlen_q, seqlen_k, mask_mod: Callable, aux_tensors: Optional[list], # list[cute.Tensor] device, compute_full_blocks: bool = True, use_fast_sampling: bool = False, ) -> Tuple[BlockSparseTensors, BlockSparseTensorsTorch]: """ Computes block sparsity for a given `mask_mod`. Args: tile_m: The tile size for the m dimension. tile_n: The tile size for the n dimension. batch_size: The batch size. num_heads: The number of heads. seqlen_q: The sequence length for the query. seqlen_k: The sequence length for the key. mask_mod: The `mask_mod` callable to use. aux_tensors: A list of auxiliary tensors. device: The device to use. compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient. Returns: A tuple of `BlockSparseTensors` and `BlockSparseTensorsTorch`. 
""" # Check if mask_mod is marked as suitable for 5-point fast sampling use_fast_sampling = getattr(mask_mod, "use_fast_sampling", use_fast_sampling) num_m_blocks = (seqlen_q + tile_m - 1) // tile_m num_n_blocks = (seqlen_k + tile_n - 1) // tile_n mask_block_cnt = torch.zeros( (batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32 ) mask_block_idx = torch.zeros( (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 ) full_block_cnt = ( torch.zeros((batch_size, num_heads, num_m_blocks), device=device, dtype=torch.int32) if compute_full_blocks else None ) full_block_idx = ( torch.zeros( (batch_size, num_heads, num_m_blocks, num_n_blocks), device=device, dtype=torch.int32 ) if compute_full_blocks else None ) blocksparse_tensors_torch = BlockSparseTensorsTorch( mask_block_cnt=mask_block_cnt, mask_block_idx=mask_block_idx, full_block_cnt=full_block_cnt, full_block_idx=full_block_idx, block_size=(tile_m, tile_n), ) mask_mod_hash = hash_callable(mask_mod) blocksparse_tensors = to_cute_block_sparse_tensors( blocksparse_tensors_torch, enable_tvm_ffi=True ) compile_key = ( tile_m, tile_n, mask_mod_hash, compute_full_blocks, aux_tensors is not None, use_fast_sampling, ) if compile_key not in compute_block_sparsity.compile_cache: kernel = BlockSparsityKernel( mask_mod, tile_mn=(tile_m, tile_n), compute_full_blocks=compute_full_blocks, use_aux_tensors=aux_tensors is not None, use_fast_sampling=use_fast_sampling, ) compute_block_sparsity.compile_cache[compile_key] = cute.compile( kernel, blocksparse_tensors, seqlen_q, seqlen_k, aux_tensors, options="--enable-tvm-ffi" ) compute_block_sparsity.compile_cache[compile_key]( blocksparse_tensors_torch[:4], seqlen_q, seqlen_k, aux_tensors, ) return blocksparse_tensors, blocksparse_tensors_torch compute_block_sparsity.compile_cache = {} ================================================ FILE: flash_attn/cute/copy_utils.py ================================================ # Copyright (c) 2025, 
Wentao Guo, Ted Zadouri, Tri Dao. import math from typing import Optional, Type, Callable import cutlass import cutlass.cute as cute from cutlass import Float32, Int32, const_expr from cutlass.cute.nvgpu import cpasync import cutlass.utils.blackwell_helpers as sm100_utils from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import llvm import cutlass.pipeline @dsl_user_op def cvt_copy( atom: cute.CopyAtom, src: cute.Tensor, dst: cute.Tensor, *, pred: Optional[cute.Tensor] = None, loc=None, ip=None, **kwargs, ) -> None: assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem if const_expr(src.element_type != dst.element_type): src_cvt = cute.make_fragment_like(src, dst.element_type, loc=loc, ip=ip) src_cvt.store(src.load().to(dst.element_type)) src = src_cvt cute.copy(atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) @dsl_user_op def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip) cute.autovec_copy(src, dst, loc=loc, ip=ip) return dst @dsl_user_op def get_copy_atom( dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None ) -> cute.CopyAtom: num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) @dsl_user_op def make_tmem_copy( tmem_copy_atom: cute.CopyAtom, num_wg: int = 1, *, loc=None, ip=None ) -> cute.CopyAtom: num_dp, num_bits, num_rep, _ = sm100_utils.get_tmem_copy_properties(tmem_copy_atom) assert num_dp == 32 assert num_bits == 32 tiler_mn = (cute.make_layout((128 * num_rep * num_wg // 32, 32), stride=(32, 1)),) layout_tv = cute.make_layout( ((32, 4, num_wg), (num_rep, 32)), stride=((0, 1, 4 * num_rep), (4, 4 * num_rep * num_wg)) ) return cute.make_tiled_copy(tmem_copy_atom, layout_tv, tiler_mn) @dsl_user_op def copy( 
src: cute.Tensor, dst: cute.Tensor, *, pred: Optional[cute.Tensor] = None, num_copy_elems: int = 1, is_async: bool = False, loc=None, ip=None, **kwargs, ) -> None: copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) def tiled_copy_1d( dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False ) -> cute.TiledCopy: num_copy_bits = num_copy_elems * dtype.width copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) thr_layout = cute.make_layout(num_threads) val_layout = cute.make_layout(num_copy_elems) return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) def tiled_copy_2d( dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False ) -> cute.TiledCopy: num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width copy_elems = num_copy_bits // dtype.width copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) gmem_threads_per_row = major_mode_size // copy_elems assert num_threads % gmem_threads_per_row == 0 thr_layout = cute.make_ordered_layout( (num_threads // gmem_threads_per_row, gmem_threads_per_row), order=(1, 0), ) val_layout = cute.make_layout((1, copy_elems)) return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) @dsl_user_op def atomic_add_fp32x4( a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None ) -> None: gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() # cache_hint = cutlass.Int64(0x12F0000000000000) llvm.inline_asm( None, [ gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), Float32(c).ir_value(loc=loc, ip=ip), Float32(d).ir_value(loc=loc, ip=ip), ], # [gmem_ptr_i64, 
Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()], "{\n\t" # ".reg .b128 abcd;\n\t" # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" ".reg .v4 .f32 abcd;\n\t" # "mov.b128 abcd, {$1, $2, $3, $4};\n\t" "mov.f32 abcd.x, $1;\n\t" "mov.f32 abcd.y, $2;\n\t" "mov.f32 abcd.z, $3;\n\t" "mov.f32 abcd.w, $4;\n\t" "red.global.add.v4.f32 [$0], abcd;\n\t" # "red.global.add.L2::cache_hint.v4.f32 [$0], abcd, 0x14F0000000000000;\n\t" "}\n", # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;", # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;", "l,f,f,f,f", # "l,f,l", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) @dsl_user_op def set_block_rank( smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None ) -> Int32: """Map the given smem pointer to the address at another CTA rank in the cluster.""" smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() return Int32( llvm.inline_asm( T.i32(), [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], "mapa.shared::cluster.u32 $0, $1, $2;", "=r,r,r", has_side_effects=False, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) ) @dsl_user_op def store_shared_remote_fp32x4( a: Float32, b: Float32, c: Float32, d: Float32, smem_ptr: cute.Pointer, mbar_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None, ) -> None: remote_smem_ptr_i32 = set_block_rank( smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip ).ir_value() remote_mbar_ptr_i32 = set_block_rank( mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip ).ir_value() llvm.inline_asm( None, [ remote_smem_ptr_i32, remote_mbar_ptr_i32, Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), Float32(c).ir_value(loc=loc, ip=ip), Float32(d).ir_value(loc=loc, ip=ip), ], "{\n\t" ".reg .v4 .f32 abcd;\n\t" "mov.f32 abcd.x, $2;\n\t" "mov.f32 abcd.y, $3;\n\t" "mov.f32 abcd.z, $4;\n\t" "mov.f32 abcd.w, $5;\n\t" "st.async.shared::cluster.mbarrier::complete_tx::bytes.v4.f32 [$0], abcd, 
[$1];\n\t" "}\n", "r,r,f,f,f,f", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) @dsl_user_op def cpasync_bulk_s2cluster( smem_src_ptr: cute.Pointer, smem_dst_ptr: cute.Pointer, mbar_ptr: cute.Pointer, size: int | Int32, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None, ): smem_src_ptr_i32 = smem_src_ptr.toint(loc=loc, ip=ip).ir_value() smem_dst_ptr_i32 = set_block_rank( smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip ).ir_value() mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value() llvm.inline_asm( None, [ smem_dst_ptr_i32, smem_src_ptr_i32, mbar_ptr_i32, Int32(size).ir_value(loc=loc, ip=ip), ], "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];", "r,r,r,r", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) @dsl_user_op def cpasync_bulk_g2s( gmem_ptr: cute.Pointer, smem_ptr: cute.Pointer, tma_bar_ptr: cute.Pointer, size: int | Int32, *, loc=None, ip=None, ): gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value() smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() mbar_ptr_i32 = tma_bar_ptr.toint(loc=loc, ip=ip).ir_value() llvm.inline_asm( None, [gmem_ptr_i64, smem_ptr_i32, mbar_ptr_i32, Int32(size).ir_value()], "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$0], $3, [$2];", "l,r,r,r", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) @dsl_user_op def cpasync_reduce_bulk_add_f32( smem_ptr: cute.Pointer, gmem_ptr: cute.Pointer, store_bytes: int | Int32, *, loc=None, ip=None, ): smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() # cache_hint = cutlass.Int64(0x14F0000000000000) # EVICT_LAST llvm.inline_asm( None, [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value()], "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [$0], [$1], $2;", "l,r,r", # [gmem_ptr.llvm_ptr, smem_ptr_i32, Int32(store_bytes).ir_value(), 
cache_hint.ir_value()], # "cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [$0], [$1], $2, $3;", # "l,r,r,l", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, ) def cpasync_bulk_get_copy_fn( src_tensor: cute.Tensor, dst_tensor: cute.Tensor, single_stage: bool = False, **kwargs, ) -> Callable: # src_is_smem = const_expr( # isinstance(src_tensor.iterator, cute.Pointer) # and src_tensor.memspace == cute.AddressSpace.smem # ) group_rank_src = const_expr(cute.rank(src_tensor) - (1 if not single_stage else 0)) group_rank_dst = const_expr(cute.rank(dst_tensor) - (1 if not single_stage else 0)) # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) src = cute.group_modes(src_tensor, 0, group_rank_src) dst = cute.group_modes(dst_tensor, 0, group_rank_dst) def copy_bulk(src_idx, dst_idx, **new_kwargs): size = const_expr(cute.size(src.shape[:-1]) * src.element_type.width // 8) cpasync_bulk_g2s( src[None, src_idx].iterator, dst[None, dst_idx].iterator, size=size, **new_kwargs, **kwargs, ) def copy_bulk_single_stage(**new_kwargs): size = const_expr(cute.size(src.shape) * src.element_type.width // 8) cpasync_bulk_g2s(src.iterator, dst.iterator, size=size, **new_kwargs, **kwargs) return copy_bulk if const_expr(not single_stage) else copy_bulk_single_stage def tma_get_copy_fn( atom: cute.CopyAtom, cta_coord: cute.Coord, cta_layout: cute.Layout, src_tensor: cute.Tensor, dst_tensor: cute.Tensor, filter_zeros: bool = False, single_stage: bool = False, **kwargs, ) -> Callable: src_is_smem = const_expr( isinstance(src_tensor.iterator, cute.Pointer) and src_tensor.memspace == cute.AddressSpace.smem ) smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0)) # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) s, g = 
cpasync.tma_partition( atom, cta_coord, cta_layout, cute.group_modes(smem_tensor, 0, group_rank_smem), cute.group_modes(gmem_tensor, 0, group_rank_gmem), ) if const_expr(filter_zeros): s = cute.filter_zeros(s) g = cute.filter_zeros(g) src, dst = (s, g) if src_is_smem else (g, s) def copy_tma(src_idx, dst_idx, **new_kwargs): cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) def copy_tma_single_stage(**new_kwargs): cute.copy(atom, src, dst, **new_kwargs, **kwargs) return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): copy( src_idx=src_idx, dst_idx=producer_state.index, tma_bar_ptr=pipeline.producer_get_barrier(producer_state), **new_kwargs, ) return copy_fn ================================================ FILE: flash_attn/cute/cute_dsl_ptxas.py ================================================ """ System ptxas replacement for CUTLASS DSL. 
Environment variables: CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas) CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output """ import os import sys import re import ctypes import subprocess from pathlib import Path import cutlass CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None) VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1" _original_load_cuda_library = None _user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1 def _log(msg): if VERBOSE: print(f"[ptxas] {msg}", file=sys.stderr) def _get_ptx(compiled_func) -> tuple[str, Path] | None: """Find and read PTX file, stripping null bytes.""" func_name = getattr(compiled_func, "function_name", None) if not func_name: return None dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd()) for ptx_path in Path(dump_dir).glob(f"*{func_name}*.ptx"): content = ptx_path.read_text().rstrip("\x00") if ".entry " in content and content.rstrip().endswith("}"): _log(f"Found PTX: {ptx_path}") return content, ptx_path return None def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes: """Compile PTX to cubin using system ptxas.""" # Extract arch from PTX match = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx_content) arch = match.group(1) if match else "sm_90a" # Write stripped content back if needed if ptx_path.read_text() != ptx_content: ptx_path.write_text(ptx_content) # Compile cubin_tmp = ptx_path.with_suffix(".cubin.tmp") try: assert CUTE_DSL_PTXAS_PATH is not None result = subprocess.run( [CUTE_DSL_PTXAS_PATH, f"-arch={arch}", "-O3", "-o", str(cubin_tmp), str(ptx_path)], capture_output=True, text=True, ) if result.returncode != 0: raise RuntimeError(f"ptxas failed: {result.stderr}") cubin_data = cubin_tmp.read_bytes() _log(f"Compiled {ptx_path.name} -> {len(cubin_data)} bytes ({arch})") # Save cubin if CUTE_DSL_KEEP_CUBIN is set if os.environ.get("CUTE_DSL_KEEP_CUBIN", "0") == "1": cubin_out = ptx_path.with_suffix(".cubin") 
cubin_out.write_bytes(cubin_data) _log(f"Saved: {cubin_out}") return cubin_data finally: cubin_tmp.unlink(missing_ok=True) def _patched_load_cuda_library(self): """Replacement for _load_cuda_library that uses system ptxas.""" result = _get_ptx(self) if not result: _log("PTX not found, falling back to embedded ptxas") return _original_load_cuda_library(self) ptx_content, ptx_path = result try: cubin = _compile_ptx(ptx_path, ptx_content) except Exception as e: _log(f"Compilation failed ({e}), falling back to embedded ptxas") return _original_load_cuda_library(self) # Load cubin import cuda.bindings.runtime as cuda_runtime err, library = cuda_runtime.cudaLibraryLoadData(cubin, None, None, 0, None, None, 0) if err != cuda_runtime.cudaError_t.cudaSuccess: _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas") return _original_load_cuda_library(self) # Register kernels on all devices _, cuda_load_to_device = self._get_cuda_init_and_load() lib_ptr = ctypes.c_void_p(int(library)) dev_id = ctypes.c_int32(0) err_val = ctypes.c_int32(0) args = (ctypes.c_void_p * 3)( ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p), ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p), ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p), ) for dev in range(self.num_devices): dev_id.value = dev cuda_load_to_device(args) if err_val.value != 0: _log("cuda_load_to_device failed, falling back to embedded ptxas") return _original_load_cuda_library(self) _log(f"Loaded kernel from {ptx_path.name}") # Delete PTX if user didn't originally want it kept if not _user_wanted_ptx: ptx_path.unlink(missing_ok=True) return [cuda_runtime.cudaLibrary_t(lib_ptr.value)] def patch(): """Install system ptxas hook. 
Call before importing cutlass.""" global _original_load_cuda_library, _user_wanted_ptx assert CUTE_DSL_PTXAS_PATH is not None if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK): raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}") # Track if user originally wanted PTX kept _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1" # os.environ['CUTE_DSL_KEEP_PTX'] = '1' assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", ( "Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas" ) cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction _original_load_cuda_library = cls._load_cuda_library cls._load_cuda_library = _patched_load_cuda_library _log("Patch applied") return ================================================ FILE: flash_attn/cute/cute_dsl_utils.py ================================================ # Copyright (c) 2025, Tri Dao. import os import pathlib from typing import Tuple from functools import partial, lru_cache from dataclasses import dataclass, fields import torch try: from triton.tools.disasm import extract except ImportError: extract = None import cutlass import cutlass.cute as cute from cutlass.base_dsl.typing import JitArgument from cutlass.cutlass_dsl import NumericMeta from cutlass.cute.runtime import from_dlpack StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data cute_compile_og = cute.compile torch2cute_dtype_map = { torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, } @lru_cache def get_max_active_clusters(cluster_size): return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) @lru_cache def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: return torch.cuda.get_device_capability(device) @dataclass class ArgumentsBase(JitArgument): def __c_pointers__(self): all_fields 
= [getattr(self, field.name) for field in fields(self)] non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] c_ptrs = [] for obj in non_constexpr_fields: if hasattr(obj, "__c_pointers__"): c_ptrs.extend(obj.__c_pointers__()) return c_ptrs def __get_mlir_types__(self): all_fields = [getattr(self, field.name) for field in fields(self)] non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] types, self._values_pos = [], [] for obj in non_constexpr_fields: if hasattr(obj, "__get_mlir_types__"): obj_types = obj.__get_mlir_types__() types.extend(obj_types) self._values_pos.append(len(obj_types)) else: self._values_pos.append(0) return types def __new_from_mlir_values__(self, values): all_fields = {field.name: getattr(self, field.name) for field in fields(self)} constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} non_constexpr_fields = { n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) } for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) values = values[n_items:] return self.__class__(**non_constexpr_fields, **constexpr_fields) def load_cubin_module_data_patched(cubin_data, filepath): pathlib.Path(filepath).write_bytes(cubin_data) return load_cubin_module_data_og(cubin_data) def cute_compile_patched(*args, **kwargs): """A patched version of cute.compile that, if CUTE_CUBIN_PATH is set, writes the compiled cubin to that path and, when triton's disassembler is available, also writes the annotated SASS next to it.""" cubin_path = os.getenv("CUTE_CUBIN_PATH", None) if cubin_path is not None: cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( load_cubin_module_data_patched, filepath=cubin_path ) output = cute_compile_og(*args, **kwargs) if cubin_path is not None: cutlass.base_dsl.runtime.cuda.load_cubin_module_data = 
load_cubin_module_data_og if extract is not None: sass = extract(cubin_path, None)
pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) return output def assume_strides_aligned(t): """Assume all strides except the last are divisible by 128 bits. Python int strides (e.g., stride=0 from GQA expand) are kept as-is since they're static and don't need alignment assumptions. """ divby = 128 // t.element_type.width strides = tuple(s if isinstance(s, int) else cute.assume(s, divby=divby) for s in t.stride[:-1]) return (*strides, t.stride[-1]) def assume_tensor_aligned(t): """Rebuild a tensor with 128-bit aligned stride assumptions. Passes through None.""" if t is None: return None return cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=assume_strides_aligned(t))) def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): """Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1.""" tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi) if fully_dynamic: return tensor.mark_layout_dynamic() if leading_dim == -1: leading_dim = t.ndim - 1 return tensor.mark_layout_dynamic(leading_dim=leading_dim) def to_cute_aux_tensor(t, enable_tvm_ffi=True): """Convert torch tensor to cute tensor for TVM FFI, tailored to FlexAttention aux tensors. This allows the user to specify alignment and leading dimension for aux tensors used in custom score_mod callables. 
""" assumed_align: int = getattr(t, "__assumed_align__", None) leading_dim: int = getattr(t, "__leading_dim__", None) fully_dynamic: bool = leading_dim is None return to_cute_tensor( t, assumed_align=assumed_align, leading_dim=leading_dim, fully_dynamic=fully_dynamic, enable_tvm_ffi=enable_tvm_ffi, ) def get_aux_tensor_metadata(aux_tensors): return tuple( ( getattr(t, "__assumed_align__", 0), getattr(t, "__leading_dim__", -1), hasattr(t, "__leading_dim__"), ) for t in aux_tensors ) def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: """Return tuple of bools indicating which dims have stride=0 (broadcast). This is useful for compile keys since CuTe's mark_layout_dynamic() keeps stride=0 as static, meaning kernels compiled with different broadcast patterns are not interchangeable. """ return tuple(s == 0 for s in tensor.stride()) ================================================ FILE: flash_attn/cute/fa_logging.py ================================================ # Copyright (c) 2025, Tri Dao. """Unified FlashAttention logging controlled by a single ``FA_LOG_LEVEL`` env var. Host-side messages go through Python ``logging`` (logger name ``flash_attn``). A default ``StreamHandler`` is attached automatically when ``FA_LOG_LEVEL >= 1`` so that standalone scripts get output without extra setup; applications that configure their own logging can remove or replace it via the standard API. FA_LOG_LEVEL mapping:: 0 off nothing logged 1 host host-side summaries only (no kernel printf) 2 kernel host + curated kernel traces 3 max host + all kernel traces (noisy, perf hit) Set via environment variable:: FA_LOG_LEVEL=1 python train.py Device-side ``cute.printf`` calls are compile-time eliminated via ``cutlass.const_expr`` when the log level is below the callsite threshold, so there is zero performance cost when device logging is off. Changing the log level after kernel compilation requires a recompile (the level participates in the forward compile key). 
""" import logging import os import sys import cutlass.cute as cute from cutlass import const_expr _LOG_LEVEL_NAMES = {"off": 0, "host": 1, "kernel": 2, "max": 3} def _parse_log_level(raw: str) -> int: if raw in _LOG_LEVEL_NAMES: return _LOG_LEVEL_NAMES[raw] try: level = int(raw) except ValueError: return 0 return max(0, min(level, 3)) _fa_log_level: int = _parse_log_level(os.environ.get("FA_LOG_LEVEL", "0")) _logger = logging.getLogger("flash_attn") _logger.addHandler(logging.NullHandler()) _default_handler: logging.Handler | None = None def _configure_default_handler() -> None: global _default_handler if _fa_log_level >= 1: if _default_handler is None: _default_handler = logging.StreamHandler(sys.stdout) _default_handler.setFormatter(logging.Formatter("[FA] %(message)s")) _logger.addHandler(_default_handler) _logger.setLevel(logging.DEBUG) else: if _default_handler is not None: _logger.removeHandler(_default_handler) _default_handler = None _logger.setLevel(logging.WARNING) _configure_default_handler() def get_fa_log_level() -> int: return _fa_log_level def set_fa_log_level(level: int | str) -> None: """Set the FA log level programmatically. Host logging takes effect immediately. Device logging changes only affect kernels compiled after this call (new compile-key selection). """ global _fa_log_level if isinstance(level, str): level = _parse_log_level(level) _fa_log_level = max(0, min(int(level), 3)) _configure_default_handler() def fa_log(level: int, msg: str): if _fa_log_level >= level: _logger.info(msg) def fa_printf(level: int, fmt, *args): if const_expr(_fa_log_level >= level): cute.printf(fmt, *args) ================================================ FILE: flash_attn/cute/fast_math.py ================================================ # Copyright (c) 2025, Tri Dao. 
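# Hypothetical host-side reference for clz() below (an illustrative addition, not part of
# FlashAttention): it counts the leading zero bits of a 32-bit value using the same scan as
# the device loop, from bit 31 down to bit 0, and returns 32 for zero. Unlike the device
# version it can return early, since plain Python has no restriction on early exit. The
# name clz_reference is an assumption chosen for this sketch; it can serve as an oracle
# when testing the device function.
def clz_reference(x: int) -> int:
    # Reinterpret as an unsigned 32-bit pattern so negative Int32 values behave like
    # their two's-complement bits (e.g. -1 -> 0xFFFFFFFF, which has 0 leading zeros).
    x &= 0xFFFFFFFF
    for i in range(32):
        if x & (1 << (31 - i)):
            return i  # index of the first set bit counted from the MSB
    return 32  # x == 0: all 32 bits are zero


# For nonzero x this matches the closed form 32 - x.bit_length().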
import cutlass import cutlass.cute as cute from cutlass import Int32 @cute.jit def clz(x: Int32) -> Int32: # for i in cutlass.range_constexpr(32): # if (1 << (31 - i)) & x: # return Int32(i) # return Int32(32) # Early exit is not supported yet res = Int32(32) done = False for i in cutlass.range(32): if ((1 << (31 - i)) & x) and not done: res = Int32(i) done = True return res ================================================ FILE: flash_attn/cute/flash_bwd.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_bwd_sm80.hpp # from Cutlass C++ to Cute-DSL. import math from types import SimpleNamespace from typing import Type, Callable, Optional from functools import partial import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, warp from cutlass import Float32, Int32 import cutlass.utils as utils_basic from quack import layout_utils from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments from flash_attn.cute.block_sparsity import BlockSparseTensors class FlashAttentionBackwardSm80: def __init__( self, dtype: Type[cutlass.Numeric], head_dim: int, head_dim_v: Optional[int] = None, qhead_per_kvhead: int = 1, m_block_size: int = 64, n_block_size: int = 128, num_stages_Q: int = 2, num_stages_dO: int = 2, num_threads: int = 256, pack_gqa: bool = False, is_causal: bool = False, SdP_swapAB: bool = False, dKV_swapAB: bool = False, dQ_swapAB: bool = False, AtomLayoutMSdP: 
int = 1, AtomLayoutNdKV: int = 8, AtomLayoutMdQ: int = 1, V_in_regs: bool = False, ): """Initializes the configuration for a FlashAttention-2-style backward kernel on SM80. All contiguous dimensions must be 16-byte aligned, which means the head dimension must be a multiple of 8 for fp16/bf16. :param head_dim: head dimension :type head_dim: int :param m_block_size: m block size :type m_block_size: int :param n_block_size: n block size :type n_block_size: int :param num_threads: number of threads :type num_threads: int :param is_causal: whether to apply a causal mask :type is_causal: bool """ self.dtype = dtype # pad head_dim (and head_dim_v) up to a multiple of 32 hdim_multiple_of = 32 self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) # Can save registers (and hence be faster) if we don't have to check hdim predication self.check_hdim_oob = head_dim != self.head_dim_padded self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.qhead_per_kvhead = qhead_per_kvhead self.m_block_size = m_block_size self.n_block_size = n_block_size self.num_threads = num_threads self.pack_gqa = pack_gqa self.is_causal = is_causal self.num_stages_Q = num_stages_Q self.num_stages_dO = num_stages_dO self.SdP_swapAB = SdP_swapAB self.dKV_swapAB = dKV_swapAB self.dQ_swapAB = dQ_swapAB self.AtomLayoutMSdP = AtomLayoutMSdP self.AtomLayoutNdKV = AtomLayoutNdKV self.AtomLayoutMdQ = AtomLayoutMdQ num_mma_warps = self.num_threads // cute.arch.WARP_SIZE self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB self.V_in_regs = V_in_regs self.share_QV_smem = V_in_regs @staticmethod def can_implement( dtype, head_dim, head_dim_v, m_block_size, n_block_size, num_stages_Q, num_stages_dO, num_threads, is_causal, 
V_in_regs=False ) -> bool: """Check if the kernel can be
implemented with the given parameters. :param dtype: data type :type dtype: cutlass.Numeric :param head_dim: head dimension :type head_dim: int :param m_block_size: m block size :type m_block_size: int :param n_block_size: n block size :type n_block_size: int :param num_threads: number of threads :type num_threads: int :param is_causal: whether to apply a causal mask :type is_causal: bool :return: True if the kernel can be implemented, False otherwise :rtype: bool """ if dtype not in [cutlass.Float16, cutlass.BFloat16]: return False if head_dim % 8 != 0: return False if head_dim_v % 8 != 0: return False if n_block_size % 16 != 0: return False if num_threads % 32 != 0: return False # Check whether the block sizes fit in shared memory capacity # Shared memory usage (2 bytes per fp16/bf16 element): Q and dO tiles (one per stage) + K tile + V tile, where Q and V share smem when V_in_regs smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2 smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2 smem_usage_K = n_block_size * head_dim * 2 smem_usage_V = n_block_size * head_dim_v * 2 smem_usage_QV = (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V) smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") if smem_usage > smem_capacity: return False return True def _check_type( self, mQ_type: Type[cutlass.Numeric], mK_type: Type[cutlass.Numeric], mV_type: Type[cutlass.Numeric], mdO_type: Type[cutlass.Numeric], mLSE_type: Type[cutlass.Numeric], mdPsum_type: Type[cutlass.Numeric], mdQaccum_type: Type[cutlass.Numeric], mdK_type: Type[cutlass.Numeric], mdV_type: Type[cutlass.Numeric], mCuSeqlensQ_type: Type[cutlass.Numeric] | None, mCuSeqlensK_type: Type[cutlass.Numeric] | None, mSeqUsedQ_type: Type[cutlass.Numeric] | None, mSeqUsedK_type: Type[cutlass.Numeric] | None, ): if cutlass.const_expr(not (mQ_type == mK_type == mV_type == mdO_type)): raise TypeError("All 
tensors must have the same data type") if cutlass.const_expr(self.qhead_per_kvhead
== 1): if cutlass.const_expr(not (mdK_type == mdV_type == mQ_type)): raise TypeError("mdK and mdV tensors must have the same data type as mQ") else: if cutlass.const_expr(not (mdK_type == mdV_type == cutlass.Float32)): raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32") if cutlass.const_expr(not mQ_type in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") if cutlass.const_expr(not mLSE_type in [cutlass.Float32]): raise TypeError("LSE tensor must be Float32") if cutlass.const_expr(not mdPsum_type in [cutlass.Float32]): raise TypeError("dPsum tensor must be Float32") if cutlass.const_expr(not mdQaccum_type in [cutlass.Float32]): raise TypeError("dQaccum tensor must be Float32") if cutlass.const_expr(mCuSeqlensQ_type not in [None, cutlass.Int32]): raise TypeError("cuSeqlensQ tensor must be Int32") if cutlass.const_expr(mCuSeqlensK_type not in [None, cutlass.Int32]): raise TypeError("cuSeqlensK tensor must be Int32") if cutlass.const_expr(mSeqUsedQ_type not in [None, cutlass.Int32]): raise TypeError("SeqUsedQ tensor must be Int32") if cutlass.const_expr(mSeqUsedK_type not in [None, cutlass.Int32]): raise TypeError("SeqUsedK tensor must be Int32") assert mQ_type == self.dtype def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: Q/K/V # /////////////////////////////////////////////////////////////////////////////// sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_padded) self.sQ_layout = cute.tile_to_shape( sQ_layout_atom, (self.m_block_size, self.head_dim_padded, self.num_stages_Q), (0, 1, 2), ) sK_layout_atom = sQ_layout_atom self.sK_layout = cute.tile_to_shape( sK_layout_atom, (self.n_block_size, self.head_dim_padded), (0, 1), ) sV_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.head_dim_v_padded) self.sV_layout = cute.tile_to_shape( sV_layout_atom, (self.n_block_size, 
self.head_dim_v_padded), (0, 1), ) sdO_layout_atom = sV_layout_atom self.sdO_layout = cute.tile_to_shape( sdO_layout_atom, (self.m_block_size, self.head_dim_v_padded, self.num_stages_dO), (0, 1, 2), ) # TODO: do we set swizzle to be 3 here explicitly? sPdS_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.n_block_size) self.sPdS_layout = cute.tile_to_shape( sPdS_layout_atom, (self.m_block_size, self.n_block_size), (0, 1), ) # We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, # it's still a valid smem address. self.sLSE_layout = cute.make_layout( (self.m_block_size, self.num_stages_Q), stride=(1, cute.round_up(self.m_block_size, 64)), ) sLSEMma_layout = cute.make_layout( (self.m_block_size, self.n_block_size, self.num_stages_Q), stride=(1, 0, cute.round_up(self.m_block_size, 64)), ) sLSEMma_layout_transposed = cute.make_layout( (self.n_block_size, self.m_block_size, self.num_stages_Q), stride=(0, 1, cute.round_up(self.m_block_size, 64)), ) self.sLSEMma_layout = sLSEMma_layout if not self.SdP_swapAB else sLSEMma_layout_transposed # /////////////////////////////////////////////////////////////////////////////// # GMEM Tiled copy: # /////////////////////////////////////////////////////////////////////////////// # Thread layouts for copies universal_copy_bits = 128 async_copy_elems = universal_copy_bits // self.dtype.width # atom_async_copy: async copy atom for QKV load atom_async_copy = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), self.dtype, num_bits_per_copy=universal_copy_bits, ) # atom_universal_copy: universal copy atom for O store atom_universal_copy = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, ) # tQK_layout: thread layout for QK load tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems assert self.num_threads % tQK_shape_dim_1 == 0, "num_threads must be divisible by tQK_shape_dim_1" 
tQK_layout = cute.make_ordered_layout( (self.num_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), ) # Do we need to check if we overshot kBlockM when we load Q? self.is_even_m_smem_q = self.m_block_size % tQK_layout.shape[0] == 0 # Do we need to check if we overshot kBlockN when we load K? self.is_even_n_smem_k = self.n_block_size % tQK_layout.shape[0] == 0 tVdO_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems assert self.num_threads % tVdO_shape_dim_1 == 0, "num_threads must be divisible by tVdO_shape_dim_1" tVdO_layout = cute.make_ordered_layout( (self.num_threads // tVdO_shape_dim_1, tVdO_shape_dim_1), order=(1, 0), ) # Do we need to check if we overshot kBlockN when we load V? self.is_even_n_smem_v = self.n_block_size % tVdO_layout.shape[0] == 0 self.is_even_m_smem_do = self.m_block_size % tVdO_layout.shape[0] == 0 # Value layouts for copies vQKVdO_layout = cute.make_layout((1, async_copy_elems)) # gmem_tiled_copy_QK: tiled copy for QK load self.gmem_tiled_copy_QK = cute.make_tiled_copy_tv(atom_async_copy, tQK_layout, vQKVdO_layout) self.gmem_tiled_copy_VdO = cute.make_tiled_copy_tv(atom_async_copy, tVdO_layout, vQKVdO_layout) self.gmem_tiled_copy_dK = cute.make_tiled_copy_tv(atom_universal_copy, tQK_layout, vQKVdO_layout) self.gmem_tiled_copy_dV = cute.make_tiled_copy_tv(atom_universal_copy, tVdO_layout, vQKVdO_layout) async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width # I think we wouldn't require this with smarter padding if cutlass.const_expr(not self.varlen_q): async_copy_elems_accum = universal_copy_bits // cutlass.Float32.width atom_async_copy_accum = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), cutlass.Float32, num_bits_per_copy=universal_copy_bits, ) else: async_copy_elems_accum = 1 atom_async_copy_accum = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width, ) self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( 
atom_async_copy_accum, cute.make_layout(self.num_threads), cute.make_layout(async_copy_elems_accum), ) self.gmem_tiled_copy_dQaccum = cute.make_tiled_copy_tv( cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=cutlass.Float32.width ), cute.make_layout(self.num_threads), cute.make_layout(1) ) if cutlass.const_expr(self.qhead_per_kvhead > 1): self.gmem_tiled_copy_dK = self.gmem_tiled_copy_dQaccum self.gmem_tiled_copy_dV = self.gmem_tiled_copy_dQaccum def _get_tiled_mma(self): num_mma_warps = self.num_threads // 32 AtomLayoutSdP = (self.AtomLayoutMSdP, num_mma_warps // self.AtomLayoutMSdP, 1) if cutlass.const_expr(not self.SdP_swapAB) else (num_mma_warps // self.AtomLayoutMSdP, self.AtomLayoutMSdP, 1) tiled_mma_sdp = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), AtomLayoutSdP, permutation_mnk=(AtomLayoutSdP[0] * 16, AtomLayoutSdP[1] * 16, 16), ) AtomLayoutdKV = (self.AtomLayoutNdKV, num_mma_warps // self.AtomLayoutNdKV, 1) if cutlass.const_expr(not self.dKV_swapAB) else (num_mma_warps // self.AtomLayoutNdKV, self.AtomLayoutNdKV, 1) tiled_mma_dkv = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), AtomLayoutdKV, permutation_mnk=(AtomLayoutdKV[0] * 16, AtomLayoutdKV[1] * 16, 16), ) AtomLayoutdQ = (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if cutlass.const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) tiled_mma_dq = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, cutlass.Float32, (16, 8, 16)), AtomLayoutdQ, permutation_mnk=(AtomLayoutdQ[0] * 16, AtomLayoutdQ[1] * 16, 16), ) return tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq def _get_shared_storage_cls(self): sQ_struct, sK_struct, sV_struct, sdO_struct = [ cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] for layout in (self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout) ] cosize_sQV = 
max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] sLSE_struct, sdPsum_struct = [ cute.struct.Align[cute.struct.MemRange[cutlass.Float32, cute.cosize(layout)], 128] for layout in (self.sLSE_layout, self.sLSE_layout) ] sP_struct, sdS_struct = [ cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 128] for layout in (self.sPdS_layout, self.sPdS_layout) ] @cute.struct class SharedStorageSeparateQV: sK: sK_struct sV: sV_struct sQ: sQ_struct sdO: sdO_struct sLSE: sLSE_struct sdPsum: sdPsum_struct sP: sP_struct sdS: sdS_struct # TODO: the case where there's no sP @cute.struct class SharedStorageSharedQV: sK: sK_struct sV: sV_struct sQ: sQV_struct sdO: sdO_struct sLSE: sLSE_struct sdPsum: sdPsum_struct sP: sP_struct sdS: sdS_struct return SharedStorageSeparateQV if cutlass.const_expr(not self.share_QV_smem) else SharedStorageSharedQV @cute.jit def __call__( self, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, mdO: cute.Tensor, mLSE: cute.Tensor, mdPsum: cute.Tensor, mdQaccum: cute.Tensor, mdK: cute.Tensor, mdV: cute.Tensor, softmax_scale: cutlass.Float32, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, mdK_semaphore: Optional[cute.Tensor] = None, mdV_semaphore: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). 
stream: cuda.CUstream = None, ): assert mdQ_semaphore is None and mdK_semaphore is None and mdV_semaphore is None, ( "determinism not supported yet for Sm80" ) # Get the data type and check if it is fp16 or bf16 self._check_type(*(t.element_type if t is not None else None for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK))) mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [ assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV) ] self.varlen_q = (mCuSeqlensQ is not None) self._setup_attributes() SharedStorage = self._get_shared_storage_cls() tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq = self._get_tiled_mma() num_head = mQ.shape[1] if cutlass.const_expr(mCuSeqlensQ is not None) else mQ.shape[2] if cutlass.const_expr(mCuSeqlensK is not None): TileScheduler = SingleTileVarlenScheduler num_batch = mCuSeqlensK.shape[0] - 1 else: TileScheduler = SingleTileScheduler num_batch = mK.shape[0] # Uses seqlen k, etc. 
since main bwd kernel's blocks are over n tile_sched_args = TileSchedulerArguments( num_block=cute.ceil_div(mK.shape[1], self.n_block_size), num_head=num_head, num_batch=num_batch, num_splits=1, seqlen_k=0, headdim=mK.shape[2], headdim_v=mV.shape[2], total_q=mK.shape[0], tile_shape_mn=(self.n_block_size, self.m_block_size), qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1, mCuSeqlensQ=mCuSeqlensK, mSeqUsedQ=mSeqUsedK, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) softmax_scale_log2 = softmax_scale * math.log2(math.e) self.kernel( mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK, softmax_scale, softmax_scale_log2, self.sQ_layout, self.sK_layout, self.sV_layout, self.sdO_layout, self.sPdS_layout, self.sLSE_layout, self.sLSEMma_layout, self.gmem_tiled_copy_QK, self.gmem_tiled_copy_VdO, self.gmem_tiled_copy_dK, self.gmem_tiled_copy_dV, self.gmem_tiled_copy_LSE, self.gmem_tiled_copy_dQaccum, tiled_mma_sdp, tiled_mma_dkv, tiled_mma_dq, SharedStorage, tile_sched_params, TileScheduler, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], smem=SharedStorage.size_in_bytes(), stream=stream, ) @cute.kernel def kernel( self, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, mdO: cute.Tensor, mLSE: cute.Tensor, mdPsum: cute.Tensor, mdQaccum: cute.Tensor, mdK: cute.Tensor, mdV: cute.Tensor, mCuSeqlensQ: Optional[cute.Tensor], mCuSeqlensK: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], mSeqUsedK: Optional[cute.Tensor], softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, sdO_layout: cute.ComposedLayout, sPdS_layout: cute.ComposedLayout, sLSE_layout: cute.Layout, sLSEMma_layout: cute.Layout, gmem_tiled_copy_QK: cute.TiledCopy, gmem_tiled_copy_VdO: cute.TiledCopy, 
gmem_tiled_copy_dK: cute.TiledCopy, gmem_tiled_copy_dV: cute.TiledCopy, gmem_tiled_copy_LSE: cute.TiledCopy, gmem_tiled_copy_dQaccum: cute.TiledCopy, tiled_mma_sdp: cute.TiledMma, tiled_mma_dkv: cute.TiledMma, tiled_mma_dq: cute.TiledMma, SharedStorage: cutlass.Constexpr, tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() tile_scheduler = TileScheduler.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() n_block, head_idx, batch_idx, _ = work_tile.tile_idx if work_tile.is_valid_tile: seqlen = SeqlenInfoQK.create( batch_idx, mQ.shape[1], mK.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, tile_m=self.m_block_size, tile_n=self.n_block_size, ) m_block_max = cute.ceil_div(seqlen.seqlen_q, self.m_block_size) m_block_min = 0 if cutlass.const_expr(self.is_causal): m_block_min = max( (n_block * self.n_block_size + seqlen.seqlen_q - seqlen.seqlen_k) // self.m_block_size, m_block_min, ) # TODO: return early if m_block_max == 0 # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. 
            # ///////////////////////////////////////////////////////////////////////////////
            blkQ_shape = (self.m_block_size, self.head_dim_padded)
            blkK_shape = (self.n_block_size, self.head_dim_padded)
            blkV_shape = (self.n_block_size, self.head_dim_v_padded)
            blkdO_shape = (self.m_block_size, self.head_dim_v_padded)

            if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
                mQ_cur = mQ[batch_idx, None, head_idx, None]
                mLSE_cur = mLSE[batch_idx, head_idx, None]
                mdO_cur = mdO[batch_idx, None, head_idx, None]
                mdPsum_cur = mdPsum[batch_idx, head_idx, None]
                mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
            else:
                padded_offset_q = seqlen.padded_offset_q
                mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, head_idx, None])
                mLSE_cur = cute.domain_offset((padded_offset_q,), mLSE[head_idx, None])
                mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
                mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
                mdQaccum_cur = cute.domain_offset((padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None])
            head_idx_kv = head_idx // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else head_idx
            if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
                mK_cur, mV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mK, mV)]
            else:
                mK_cur, mV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mK, mV)]

            # (m_block_size, head_dim, m_block)
            gQ = cute.local_tile(mQ_cur, blkQ_shape, (None, 0))
            # (n_block_size, head_dim)
            gK = cute.local_tile(mK_cur, blkK_shape, (n_block, 0))
            # (n_block_size, head_dim_v)
            gV = cute.local_tile(mV_cur, blkV_shape, (n_block, 0))
            # (m_block_size, head_dim_v, m_block)
            gdO = cute.local_tile(mdO_cur, blkdO_shape, (None, 0))
            gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (None,))
            gdPsum = cute.local_tile(mdPsum_cur, (self.m_block_size,), (None,))
            gdQaccum = cute.local_tile(mdQaccum_cur, (self.m_block_size * self.head_dim_padded,), (None,))

            # ///////////////////////////////////////////////////////////////////////////////
            # Get shared memory buffer
            # ///////////////////////////////////////////////////////////////////////////////
            smem = cutlass.utils.SmemAllocator()
            storage = smem.allocate(SharedStorage)
            sQ = storage.sQ.get_tensor(sQ_layout)
            sK = storage.sK.get_tensor(sK_layout)
            if cutlass.const_expr(not self.share_QV_smem):
                sV = storage.sV.get_tensor(sV_layout)
            else:
                sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout)
            sdO = storage.sdO.get_tensor(sdO_layout)
            sP = storage.sP.get_tensor(sPdS_layout)
            sdS = storage.sdS.get_tensor(sPdS_layout)
            sLSE = storage.sLSE.get_tensor(sLSE_layout)
            sdPsum = storage.sdPsum.get_tensor(sLSE_layout)
            sLSEMma = storage.sLSE.get_tensor(sLSEMma_layout)
            sdPsumMma = storage.sdPsum.get_tensor(sLSEMma_layout)

            # Transpose view of tensors for tiled mma
            sQt, sdOt, sKt, sPt, sdSt = [layout_utils.transpose_view(t) for t in (sQ, sdO, sK, sP, sdS)]

            gmem_thr_copy_QK = gmem_tiled_copy_QK.get_slice(tidx)
            gmem_thr_copy_VdO = gmem_tiled_copy_VdO.get_slice(tidx)
            gmem_thr_copy_lse = gmem_tiled_copy_LSE.get_slice(tidx)
            gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx)
            # (CPY_Atom, CPY_M, CPY_K, m_block)
            tQgQ = gmem_thr_copy_QK.partition_S(gQ)
            tQsQ = gmem_thr_copy_QK.partition_D(sQ)
            # (CPY_Atom, CPY_N, CPY_K)
            tKgK = gmem_thr_copy_QK.partition_S(gK)
            tKsK = gmem_thr_copy_QK.partition_D(sK)
            # (CPY_Atom, CPY_N, CPY_K)
            tVgV = gmem_thr_copy_VdO.partition_S(gV)
            tVsV = gmem_thr_copy_VdO.partition_D(sV)
            # (CPY_Atom, CPY_M, CPY_K, m_block)
            tdOgdO = gmem_thr_copy_VdO.partition_S(gdO)
            tdOsdO = gmem_thr_copy_VdO.partition_D(sdO)
            tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE)
            tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE)
            tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum)
            tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum)
            tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum)

            # ///////////////////////////////////////////////////////////////////////////////
            # Tile MMA compute thread partitions and allocate accumulators
            # ///////////////////////////////////////////////////////////////////////////////
            thr_mma_sdp = tiled_mma_sdp.get_slice(tidx)
            thr_mma_dkv = tiled_mma_dkv.get_slice(tidx)
            thr_mma_dq = tiled_mma_dq.get_slice(tidx)
            acc_shape_dK = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_padded))
            acc_shape_dV = thr_mma_dkv.partition_shape_C((self.n_block_size, self.head_dim_v_padded))
            acc_dK = cute.make_fragment(acc_shape_dK, cutlass.Float32)
            acc_dV = cute.make_fragment(acc_shape_dV, cutlass.Float32)
            acc_dK.fill(0.0)
            acc_dV.fill(0.0)

            tSrQ = utils.mma_make_fragment_A(sQ[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)
            tSrK = utils.mma_make_fragment_B(sK, thr_mma_sdp, swapAB=self.SdP_swapAB)
            tdPrdO = utils.mma_make_fragment_A(sdO[None, None, 0], thr_mma_sdp, swapAB=self.SdP_swapAB)
            tdPrV = utils.mma_make_fragment_B(sV, thr_mma_sdp, swapAB=self.SdP_swapAB)
            tdVrP = utils.mma_make_fragment_A(sPt, thr_mma_dkv, swapAB=self.dKV_swapAB)
            tdVrdO = utils.mma_make_fragment_B(sdOt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)
            tdKrdS = utils.mma_make_fragment_A(sdSt, thr_mma_dkv, swapAB=self.dKV_swapAB)
            tdKrQ = utils.mma_make_fragment_B(sQt[None, None, 0], thr_mma_dkv, swapAB=self.dKV_swapAB)
            tdQrdS = utils.mma_make_fragment_A(sdS, thr_mma_dq, swapAB=self.dQ_swapAB)
            tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB)

            LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None)
            tSsLSEMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sLSEMma))[LSEslice]
            tSsdPsumMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice]

            # ///////////////////////////////////////////////////////////////////////////////
            # Smem copy atom tiling
            # ///////////////////////////////////////////////////////////////////////////////
            smem_copy_atom = cute.make_copy_atom(
                warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype,
            )
            smem_copy_atom_transposed = cute.make_copy_atom(
                warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype,
            )
            smem_thr_copy_QdO = utils.make_tiled_copy_A(
                smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB
            ).get_slice(tidx)
            smem_thr_copy_KV = utils.make_tiled_copy_B(
                smem_copy_atom, tiled_mma_sdp, swapAB=self.SdP_swapAB
            ).get_slice(tidx)
            # TODO: should this be smem_copy_atom_transposed?
            smem_thr_copy_PdSt = utils.make_tiled_copy_A(
                smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB
            ).get_slice(tidx)
            smem_thr_copy_QdOt = utils.make_tiled_copy_B(
                smem_copy_atom_transposed, tiled_mma_dkv, swapAB=self.dKV_swapAB
            ).get_slice(tidx)
            smem_thr_copy_dS = utils.make_tiled_copy_A(
                smem_copy_atom, tiled_mma_dq, swapAB=self.dQ_swapAB
            ).get_slice(tidx)
            smem_thr_copy_Kt = utils.make_tiled_copy_B(
                smem_copy_atom_transposed, tiled_mma_dq, swapAB=self.dQ_swapAB
            ).get_slice(tidx)
            # TODO: what's the number of bits? What if SdP_swapAB
            r2s_thr_copy_PdS = cute.make_tiled_copy_C(
                cute.make_copy_atom(
                    cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width
                ),
                tiled_mma_sdp,
            ).get_slice(tidx)

            tSsQ = smem_thr_copy_QdO.partition_S(sQ)
            tdPsdO = smem_thr_copy_QdO.partition_S(sdO)
            tSsK = smem_thr_copy_KV.partition_S(sK)
            tdPsV = smem_thr_copy_KV.partition_S(sV)
            tdVsPt = smem_thr_copy_PdSt.partition_S(sPt)
            tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt)
            tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt)
            tdKsQt = smem_thr_copy_QdOt.partition_S(sQt)
            tdQsdS = smem_thr_copy_dS.partition_S(sdS)
            tdQsKt = smem_thr_copy_Kt.partition_S(sKt)
            tPsP = r2s_thr_copy_PdS.partition_D(sP)
            tdSsdS = r2s_thr_copy_PdS.partition_D(sdS)

            # ///////////////////////////////////////////////////////////////////////////////
            # Predicate: Mark indices that need to be copied when problem_shape isn't a multiple
            # of tile_shape
            # ///////////////////////////////////////////////////////////////////////////////

            # Construct identity layout for Q
            cQ = cute.make_identity_tensor((self.m_block_size,
                self.head_dim_padded))
            tQcQ = gmem_thr_copy_QK.partition_S(cQ)
            t0QcQ = gmem_thr_copy_QK.get_slice(0).partition_S(cQ)
            if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):
                tdOcdO = tQcQ
                t0dOcdO = t0QcQ
            else:
                cdO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
                tdOcdO = gmem_thr_copy_VdO.partition_S(cdO)
                t0dOcdO = gmem_thr_copy_VdO.get_slice(0).partition_S(cdO)
            cLSE = cute.make_identity_tensor((self.m_block_size,))
            tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE)

            # Allocate predicate tensors only for the k dimension and use "if" on the m/n
            # dimension instead. This reduces register pressure and gives a 2-3% performance gain.
            d_head = mQ.shape[cute.rank(mQ) - 1]
            d_head_v = mdO.shape[cute.rank(mdO) - 1]

            tQpQ = utils.predicate_k(tQcQ, limit=d_head)
            if cutlass.const_expr(self.same_hdim_kv):
                tdOpdO = tQpQ
            else:
                tdOpdO = utils.predicate_k(tdOcdO, limit=d_head_v)

            # group parameters for compute_one_m_block
            mma_params = SimpleNamespace(
                thr_mma_sdp=thr_mma_sdp, thr_mma_dkv=thr_mma_dkv, thr_mma_dq=thr_mma_dq,
                tSrQ=tSrQ, tSrK=tSrK, tdPrdO=tdPrdO, tdPrV=tdPrV,
                tdVrP=tdVrP, tdVrdO=tdVrdO, tdKrdS=tdKrdS, tdKrQ=tdKrQ,
                tdQrdS=tdQrdS, tdQrK=tdQrK,
                acc_dK=acc_dK, acc_dV=acc_dV,
            )
            smem_copy_params = SimpleNamespace(
                smem_thr_copy_QdO=smem_thr_copy_QdO,
                smem_thr_copy_KV=smem_thr_copy_KV,
                smem_thr_copy_PdSt=smem_thr_copy_PdSt,
                smem_thr_copy_QdOt=smem_thr_copy_QdOt,
                smem_thr_copy_dS=smem_thr_copy_dS,
                smem_thr_copy_Kt=smem_thr_copy_Kt,
                r2s_thr_copy_PdS=r2s_thr_copy_PdS,
                tSsQ=tSsQ, tSsK=tSsK, tdPsdO=tdPsdO, tdPsV=tdPsV,
                tSsLSEMma=tSsLSEMma, tSsdPsumMma=tSsdPsumMma,
                tPsP=tPsP, tdSsdS=tdSsdS,
                tdVsPt=tdVsPt, tdVsdOt=tdVsdOt, tdKsdSt=tdKsdSt, tdKsQt=tdKsQt,
                tdQsdS=tdQsdS, tdQsKt=tdQsKt,
            )
            gmem_copy_params = SimpleNamespace(
                gmem_thr_copy_dQaccum=gmem_thr_copy_dQaccum, tdQgdQaccum=tdQgdQaccum
            )
            load_Q_LSE = partial(
                self.load_Q_LSE, gmem_tiled_copy_QK, gmem_tiled_copy_LSE,
                tQgQ, tQsQ, tQcQ, t0QcQ, tQpQ,
                tLSEgLSE, tLSEsLSE, tLSEcLSE,
                seqlen=seqlen.seqlen_q
            )
            load_dO_dPsum = partial(
                self.load_dO_dPsum, gmem_tiled_copy_VdO, gmem_tiled_copy_LSE,
                tdOgdO, tdOsdO, tdOcdO, t0dOcdO, tdOpdO,
                tLSEgdPsum, tLSEsdPsum, tLSEcLSE,
                seqlen=seqlen.seqlen_q
            )
            compute_one_m_block = partial(
                self.compute_one_m_block, mma_params=mma_params,
                smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params,
                load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum,
                m_block_max=m_block_max,
                softmax_scale_log2=softmax_scale_log2,
            )

            # ///////////////////////////////////////////////////////////////////////////////
            # Prologue
            # ///////////////////////////////////////////////////////////////////////////////
            # Start async loads of the last mn-tile, where we take care of the mn residue
            self.load_V(gmem_thr_copy_VdO, tVgV, tVsV, n_block, seqlen=seqlen.seqlen_k, headdim=d_head_v)
            if cutlass.const_expr(self.V_in_regs):
                cute.arch.cp_async_commit_group()
            self.load_K(gmem_thr_copy_QK, tKgK, tKsK, n_block, seqlen=seqlen.seqlen_k, headdim=d_head)
            cute.arch.cp_async_commit_group()

            if cutlass.const_expr(self.V_in_regs):
                cute.arch.cp_async_wait_group(1)
                cute.arch.barrier()
                tdPrV_copy_view = smem_thr_copy_KV.retile(tdPrV)
                cute.copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view)
                # Sync to avoid loading Q to smem_q, which overlaps with smem_v
                cute.arch.barrier()

            m_block = m_block_min
            assert self.num_stages_Q >= self.num_stages_dO
            for stage in cutlass.range_constexpr(self.num_stages_Q):
                if cutlass.const_expr(self.num_stages_Q == 1 or stage < self.num_stages_Q - 1):
                    if stage == 0 or m_block + stage < m_block_max:
                        load_Q_LSE(m_block + stage, smem_pipe_write_q=stage)
                    cute.arch.cp_async_commit_group()
                if cutlass.const_expr(stage < self.num_stages_dO):
                    if stage == 0 or m_block + stage < m_block_max:
                        load_dO_dPsum(m_block + stage, smem_pipe_write_q=stage)
                    cute.arch.cp_async_commit_group()

            # ///////////////////////////////////////////////////////////////////////////////
            # Mainloop
            # ///////////////////////////////////////////////////////////////////////////////
            # Start processing of the first n-block.
            mask = AttentionMask(self.m_block_size, self.n_block_size, seqlen)
            mask_fn = partial(
                mask.apply_mask, n_block=n_block, thr_mma=thr_mma_sdp,
                batch_idx=batch_idx, head_idx=head_idx,
                mask_seqlen=True, mask_causal=self.is_causal
            )
            smem_pipe_read_q = cutlass.Int32(0)
            smem_pipe_read_do = cutlass.Int32(0)
            smem_pipe_write_q = cutlass.Int32(self.num_stages_Q - 1)
            smem_pipe_write_do = cutlass.Int32(0)
            for m_tile in cutlass.range(m_block_min, m_block_max, unroll=1):
                compute_one_m_block(
                    m_tile, smem_pipe_read_q, smem_pipe_read_do, smem_pipe_write_q, smem_pipe_write_do,
                    mask_fn=mask_fn,
                )
                smem_pipe_read_q = self.advance_pipeline(smem_pipe_read_q, self.num_stages_Q)
                smem_pipe_read_do = self.advance_pipeline(smem_pipe_read_do, self.num_stages_dO)
                smem_pipe_write_q = self.advance_pipeline(smem_pipe_write_q, self.num_stages_Q)
                smem_pipe_write_do = self.advance_pipeline(smem_pipe_write_do, self.num_stages_dO)

            # ///////////////////////////////////////////////////////////////////////////////
            # Epilogue
            # ///////////////////////////////////////////////////////////////////////////////
            # If GQA, we scale dK in the postprocessing kernel instead
            if cutlass.const_expr(self.qhead_per_kvhead == 1):
                acc_dK.store(acc_dK.load() * softmax_scale)
            # reuse sK and sV data iterator
            sdK = cute.make_tensor(sK.iterator, sK_layout)
            sdV = cute.make_tensor(sV.iterator, sV_layout)
            self.epilogue(
                acc_dK, acc_dV, mdK, mdV, sdK, sdV,
                gmem_tiled_copy_dK, gmem_tiled_copy_dV, tiled_mma_dkv,
                tidx, n_block, head_idx, batch_idx, seqlen, d_head, d_head_v
            )

    @cute.jit
    def compute_one_m_block(
        self,
        m_block: cutlass.Int32,
        smem_pipe_read_q: cutlass.Int32,
        smem_pipe_read_do: cutlass.Int32,
        smem_pipe_write_q: cutlass.Int32,
        smem_pipe_write_do: cutlass.Int32,
        mma_params: SimpleNamespace,
        smem_copy_params: SimpleNamespace,
        gmem_copy_params: SimpleNamespace,
        load_Q_LSE: Callable,
        load_dO_dPsum: Callable,
        m_block_max: cutlass.Int32,
        softmax_scale_log2: cutlass.Float32,
        mask_fn: Optional[Callable] = None,
    ):
        def load_Q_next():
            m_block_next = m_block + (self.num_stages_Q - 1 if cutlass.const_expr(self.num_stages_Q > 1) else 1)
            if m_block_next < m_block_max:
                load_Q_LSE(m_block_next, smem_pipe_write_q)
            cute.arch.cp_async_commit_group()

        def load_dO_next():
            if m_block + self.num_stages_dO < m_block_max:
                load_dO_dPsum(m_block + self.num_stages_dO, smem_pipe_write_do)
            cute.arch.cp_async_commit_group()

        # MMA S
        acc_shape_SdP = mma_params.thr_mma_sdp.partition_shape_C(
            (self.m_block_size, self.n_block_size) if cutlass.const_expr(not self.SdP_swapAB)
            else (self.n_block_size, self.m_block_size)
        )
        acc_S = cute.make_fragment(acc_shape_SdP, cutlass.Float32)
        acc_S.fill(0.0)
        cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_Q > 1) else 0)
        cute.arch.barrier()
        sm80_utils.gemm(
            mma_params.thr_mma_sdp, acc_S, mma_params.tSrQ, mma_params.tSrK,
            smem_copy_params.tSsQ[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
            smem_copy_params.tSsK,
            smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
            swap_AB=self.SdP_swapAB,
        )
        tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0])
        cute.autovec_copy(
            smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE
        )
        if cutlass.const_expr(mask_fn is not None):
            mask_fn(acc_S, m_block=m_block)
        acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
        bidx = 0
        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)
        assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE)
        for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True):
            acc_S_mn[r, None].store(cute.math.exp2(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r], fastmath=True))
        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)

        # MMA dP
        acc_dP = cute.make_fragment(acc_shape_SdP, cutlass.Float32)
        acc_dP.fill(0.0)
        cute.arch.cp_async_wait_group(1 if cutlass.const_expr(self.num_stages_dO > 1) else 0)
        cute.arch.barrier()
        sm80_utils.gemm(
            mma_params.thr_mma_sdp, acc_dP, mma_params.tdPrdO, mma_params.tdPrV,
            smem_copy_params.tdPsdO[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0],
            smem_copy_params.tdPsV,
            smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV,
            hook_fn=load_Q_next if cutlass.const_expr(self.num_stages_Q > 1) else None,
            swap_AB=self.SdP_swapAB,
        )
        tLSErdPsum = cute.make_fragment_like(smem_copy_params.tSsdPsumMma[None, 0])
        cute.autovec_copy(
            smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum
        )
        acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP)
        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
        assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum)
        for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True):
            acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]))
        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
        rP = cute.make_fragment_like(acc_S, self.dtype)
        rP.store(acc_S.load().to(self.dtype))
        if cutlass.const_expr(not self.Mma_dKV_is_RS):
            tPrP = smem_copy_params.r2s_thr_copy_PdS.retile(rP)  # ((Atom,AtomNum), MMA_N, MMA_N)
            cute.copy(smem_copy_params.r2s_thr_copy_PdS, tPrP, smem_copy_params.tPsP)
        rdS = cute.make_fragment_like(acc_dP, self.dtype)
        rdS.store(acc_dP.load().to(self.dtype))
        if cutlass.const_expr(not self.Mma_dKV_is_RS):
            cute.arch.barrier()  # Make sure P is written
        # For hdim 64, it's faster to write to smem_dS first before the dV gemm
        if cutlass.const_expr(not self.Mma_dKV_is_RS):
            tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS)
            cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS)
        if cutlass.const_expr(self.Mma_dKV_is_RS):
            tdVrP = layout_utils.reshape_acc_to_frgA(rP)
        else:
            tdVrP = mma_params.tdVrP

        # MMA dV
        sm80_utils.gemm(
            mma_params.thr_mma_dkv, mma_params.acc_dV, tdVrP, mma_params.tdVrdO,
            smem_copy_params.tdVsPt,
            smem_copy_params.tdVsdOt[None, None, None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0],
            smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,
            A_in_regs=self.Mma_dKV_is_RS,
            swap_AB=self.dKV_swapAB,
        )
        # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(mma_params.acc_dV)
        cute.arch.barrier()  # Make sure dS is written

        # MMA dQ
        def dQ_mma(hook_fn):
            acc_shape_dQ = mma_params.thr_mma_dq.partition_shape_C(
                (self.m_block_size, self.head_dim_padded) if cutlass.const_expr(not self.dQ_swapAB)
                else (self.head_dim_padded, self.m_block_size)
            )
            acc_dQ = cute.make_fragment(acc_shape_dQ, cutlass.Float32)
            acc_dQ.fill(0.0)
            sm80_utils.gemm(
                mma_params.thr_mma_dq, acc_dQ, mma_params.tdQrdS, mma_params.tdQrK,
                smem_copy_params.tdQsdS, smem_copy_params.tdQsKt,
                smem_copy_params.smem_thr_copy_dS, smem_copy_params.smem_thr_copy_Kt,
                swap_AB=self.dQ_swapAB,
                hook_fn=hook_fn
            )
            # ((1, 1), num_elements)
            acc_dQ_atomic = gmem_copy_params.gmem_thr_copy_dQaccum.retile(acc_dQ)
            tdQgdQaccum_atomic = gmem_copy_params.tdQgdQaccum[None, None, m_block]
            assert cute.size(acc_dQ_atomic) == cute.size(tdQgdQaccum_atomic)
            for i in cutlass.range(cute.size(acc_dQ_atomic), unroll_full=True):
                utils.atomic_add_fp32(acc_dQ_atomic[i], utils.elem_pointer(tdQgdQaccum_atomic, i))
                # utils.atomic_add_fp32(acc_dQ[i], tdQgdQaccum_atomic.iterator + i * tdQgdQaccum_atomic.stride[1])
            # if cute.arch.thread_idx()[0] == 64 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dQ)

        # If num_stages_Q == 1, we want to do Mma_dK first so we can start loading Q for the next iteration
        if cutlass.const_expr(self.num_stages_Q > 1):
            dQ_mma(load_dO_next)

        # MMA dK
        if cutlass.const_expr(self.Mma_dKV_is_RS):
            tdKrdS = layout_utils.reshape_acc_to_frgA(rdS)
        else:
            tdKrdS = mma_params.tdKrdS
        sm80_utils.gemm(
            mma_params.thr_mma_dkv, mma_params.acc_dK, tdKrdS, mma_params.tdKrQ,
            smem_copy_params.tdKsdSt,
            smem_copy_params.tdKsQt[None, None, None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
            smem_copy_params.smem_thr_copy_PdSt, smem_copy_params.smem_thr_copy_QdOt,
            A_in_regs=self.Mma_dKV_is_RS,
            swap_AB=self.dKV_swapAB,
            hook_fn=load_dO_next if cutlass.const_expr(self.num_stages_Q == 1) else None,
        )
        # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(mma_params.acc_dK)
        if cutlass.const_expr(self.num_stages_Q == 1):
            cute.arch.barrier()
            dQ_mma(load_Q_next)

    @cute.jit
    def epilogue(
        self,
        acc_dK: cute.Tensor,
        acc_dV: cute.Tensor,
        mdK: cute.Tensor,
        mdV: cute.Tensor,
        sdK: cute.Tensor,
        sdV: cute.Tensor,
        gmem_tiled_copy_dK: cute.TiledCopy,
        gmem_tiled_copy_dV: cute.TiledCopy,
        tiled_mma: cute.TiledMma,
        tidx: cutlass.Int32,
        n_block: cutlass.Int32,
        num_head: cutlass.Int32,
        batch_size: cutlass.Int32,
        seqlen: SeqlenInfoQK,
        d_head: cutlass.Int32,
        d_head_v: cutlass.Int32,
    ):
        rdV = cute.make_fragment_like(acc_dV, self.dtype)
        rdV.store(acc_dV.load().to(self.dtype))
        rdK = cute.make_fragment_like(acc_dK, self.dtype)
        rdK.store(acc_dK.load().to(self.dtype))
        gmem_thr_copy_dK = gmem_tiled_copy_dK.get_slice(tidx)
        gmem_thr_copy_dV = gmem_tiled_copy_dV.get_slice(tidx)

        batch_idx = batch_size
        head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head

        if cutlass.const_expr(self.qhead_per_kvhead == 1):
            # Make sure all threads have finished reading K and V, otherwise we get racy dQ
            # because smem_q could be changed.
            cute.arch.barrier()
            # smem copy atom for dKV
            smem_copy_atom_dKV = cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=2 * self.dtype.width
            )
            smem_thr_copy_dKV = cute.make_tiled_copy_C(smem_copy_atom_dKV, tiled_mma).get_slice(tidx)
            taccdVrdV = smem_thr_copy_dKV.retile(rdV)
            taccdKrdK = smem_thr_copy_dKV.retile(rdK)
            taccdVsdV = smem_thr_copy_dKV.partition_D(sdV)
            taccdKsdK = smem_thr_copy_dKV.partition_D(sdK)
            # copy acc dK and dV from rmem to smem with the smem copy atom
            cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV)
            cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK)

            if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
                mdK_cur, mdV_cur = [t[batch_idx, None, head_idx_kv, None] for t in (mdK, mdV)]
            else:
                mdK_cur, mdV_cur = [cute.domain_offset((seqlen.offset_k, 0), t[None, head_idx_kv, None]) for t in (mdK, mdV)]

            blkdK_shape = (self.n_block_size, self.head_dim_padded)
            blkdV_shape = (self.n_block_size, self.head_dim_v_padded)
            gdK = cute.local_tile(mdK_cur, blkdK_shape, (n_block, 0))
            gdV = cute.local_tile(mdV_cur, blkdV_shape, (n_block, 0))
            tdKsdK = gmem_thr_copy_dK.partition_S(sdK)
            tdKgdK = gmem_thr_copy_dK.partition_D(gdK)
            tdVsdV = gmem_thr_copy_dV.partition_S(sdV)
            tdVgdV = gmem_thr_copy_dV.partition_D(gdV)
            tdKrdK = cute.make_fragment_like(tdKgdK, self.dtype)
            tdVrdV = cute.make_fragment_like(tdVgdV, self.dtype)
            # Sync so that all smem stores are done before reading them back.
            cute.arch.barrier()
            # load acc dK and dV from smem to rmem for wider vectorization
            # Need to check OOB when reading from smem if kBlockN isn't evenly tiled
            # TODO
            cute.autovec_copy(tdKsdK, tdKrdK)
            cute.autovec_copy(tdVsdV, tdVrdV)

            cdK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))
            tdKcdK = gmem_thr_copy_dK.partition_S(cdK)
            t0dKcdK = gmem_tiled_copy_dK.get_slice(0).partition_S(cdK)
            if cutlass.const_expr(self.head_dim_padded == self.head_dim_v_padded):
                tdVcdV = tdKcdK
                t0dVcdV = t0dKcdK
            else:
                cdV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))
                tdVcdV = gmem_thr_copy_dV.partition_S(cdV)
                t0dVcdV = gmem_tiled_copy_dV.get_slice(0).partition_S(cdV)
            tdKpdK = utils.predicate_k(tdKcdK, limit=d_head)
            if cutlass.const_expr(self.same_hdim_kv):
                tdVpdV = tdKpdK
            else:
                tdVpdV = utils.predicate_k(tdVcdV, limit=d_head_v)
            # copy dK and dV from rmem to gmem
            for rest_m in cutlass.range_constexpr(cute.size(tdKrdK.shape[1])):
                if t0dKcdK[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdKcdK[0][0]:
                    cute.copy(
                        gmem_tiled_copy_dK,
                        tdKrdK[None, rest_m, None],
                        tdKgdK[None, rest_m, None],
                        pred=tdKpdK[None, rest_m, None] if cutlass.const_expr(self.check_hdim_oob) else None,
                    )
            for rest_m in cutlass.range_constexpr(cute.size(tdVrdV.shape[1])):
                if t0dVcdV[0, rest_m, 0][0] < seqlen.seqlen_k - n_block * self.n_block_size - tdVcdV[0][0]:
                    cute.copy(
                        gmem_tiled_copy_dV,
                        tdVrdV[None, rest_m, None],
                        tdVgdV[None, rest_m, None],
                        pred=tdVpdV[None, rest_m, None] if cutlass.const_expr(self.check_hdim_v_oob) else None,
                    )

        else:  # qhead_per_kvhead > 1, do atomic add
            # For Sm90, we need to sync to avoid racy writes to smem_q
            # For Sm80, we don't need to sync since we're not touching smem
            head_idx_kv = num_head // self.qhead_per_kvhead if cutlass.const_expr(not self.pack_gqa) else num_head

            if cutlass.const_expr(not seqlen.has_cu_seqlens_k):
                mdK_cur, mdV_cur = [t[batch_idx, head_idx_kv, None] for t in (mdK, mdV)]
            else:
                padded_offset_k = seqlen.offset_k + batch_idx * self.n_block_size
                mdK_cur = cute.domain_offset((padded_offset_k * self.head_dim_padded,), mdK[head_idx_kv, None])
                mdV_cur = cute.domain_offset((padded_offset_k * self.head_dim_v_padded,), mdV[head_idx_kv, None])

            gdV = cute.local_tile(mdV_cur, (self.n_block_size * self.head_dim_v_padded,), (n_block,))
            gdK = cute.local_tile(mdK_cur, (self.n_block_size * self.head_dim_padded,), (n_block,))
            tdVgdVaccum = gmem_thr_copy_dV.partition_S(gdV)
            tdKgdKaccum = gmem_thr_copy_dK.partition_S(gdK)
            acc_dV_atomic = gmem_thr_copy_dV.retile(acc_dV)
            acc_dK_atomic = gmem_thr_copy_dK.retile(acc_dK)
            assert cute.size(acc_dV_atomic) == cute.size(tdVgdVaccum)
            assert cute.size(acc_dK_atomic) == cute.size(tdKgdKaccum)
            for i in cutlass.range(cute.size(acc_dV_atomic), unroll_full=True):
                utils.atomic_add_fp32(acc_dV_atomic[i], utils.elem_pointer(tdVgdVaccum, i))
            for i in cutlass.range(cute.size(acc_dK_atomic), unroll_full=True):
                utils.atomic_add_fp32(acc_dK_atomic[i], utils.elem_pointer(tdKgdKaccum, i))

    @cute.jit
    def advance_pipeline(self, pipeline_index, num_stages: cutlass.Constexpr):
        return pipeline_index + 1 if pipeline_index < num_stages - 1 else 0

    @cute.jit
    def load_K(
        self,
        gmem_thr_copy: cute.TiledCopy,
        tKgK: cute.Tensor,
        tKsK: cute.Tensor,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
        headdim: cutlass.Int32,
    ):
        cK = cute.make_identity_tensor((self.n_block_size, self.head_dim_padded))
        tKcK = gmem_thr_copy.partition_S(cK)
        t0KcK = gmem_thr_copy.get_slice(0).partition_S(cK)
        tKpK = utils.predicate_k(tKcK, limit=headdim)
        for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
            # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
            if self.is_even_n_smem_k or n < cute.size(tKsK.shape[1]) - 1 or tKcK[0, n, 0][0] < self.n_block_size:
                # Instead of using tKcK, we use t0KcK and subtract the offset from the limit
                # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time.
                predicate_n = t0KcK[0, n, 0][0] < seqlen - block * self.n_block_size - tKcK[0][0]
                predicate = cute.make_fragment_like(tKpK[None, 0, None])
                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                        predicate[i, k] = (tKpK[i, n, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_n
                cute.copy(
                    gmem_thr_copy, tKgK[None, n, None], tKsK[None, n, None], pred=predicate,
                )
        # We need to clear the sK smem tiles since we'll use sKt for mma_dq

    @cute.jit
    def load_V(
        self,
        gmem_thr_copy: cute.TiledCopy,
        tVgV: cute.Tensor,
        tVsV: cute.Tensor,
        block: cutlass.Int32,
        seqlen: cutlass.Int32,
        headdim: cutlass.Int32,
    ):
        cV = cute.make_identity_tensor((self.n_block_size, self.head_dim_v_padded))
        tVcV = gmem_thr_copy.partition_S(cV)
        t0VcV = gmem_thr_copy.get_slice(0).partition_S(cV)
        tVpV = utils.predicate_k(tVcV, limit=headdim)
        for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
            # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked
            if self.is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.n_block_size:
                # Instead of using tVcV, we use t0VcV and subtract the offset from the limit
                # (seqlen - block * kBlockN). This is because the entries of t0VcV are known at compile time.
                predicate_n = t0VcV[0, n, 0][0] < seqlen - block * self.n_block_size - tVcV[0][0]
                predicate = cute.make_fragment_like(tVpV[None, 0, None])
                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                        # V has head dim head_dim_v, so gate the k-predicate on check_hdim_v_oob
                        predicate[i, k] = (tVpV[i, n, k] if cutlass.const_expr(self.check_hdim_v_oob) else True) and predicate_n
                cute.copy(
                    gmem_thr_copy, tVgV[None, n, None], tVsV[None, n, None], pred=predicate,
                )

    @cute.jit
    def load_Q_LSE(
        self,
        gmem_tiled_copy_Q: cute.TiledCopy,
        gmem_tiled_copy_LSE: cute.TiledCopy,
        tQgQ: cute.Tensor,
        tQsQ: cute.Tensor,
        tQcQ: cute.Tensor,
        t0QcQ: cute.Tensor,
        tQpQ: cute.Tensor,
        tLSEgLSE: cute.Tensor,
        tLSEsLSE: cute.Tensor,
        tLSEcLSE: cute.Tensor,
        block: cutlass.Int32,
        smem_pipe_write_q: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
            # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
            if self.is_even_m_smem_q or m < cute.size(tQsQ.shape[1]) - 1 or tQcQ[0, m, 0][0] < self.m_block_size:
                # Instead of using tQcQ, we use t0QcQ and subtract the offset from the limit
                # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time.
                predicate_m = t0QcQ[0, m, 0][0] < seqlen - block * self.m_block_size - tQcQ[0][0]
                predicate = cute.make_fragment_like(tQpQ[None, 0, None])
                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                        predicate[i, k] = (tQpQ[i, m, k] if cutlass.const_expr(self.check_hdim_oob) else True) and predicate_m
                cute.copy(
                    gmem_tiled_copy_Q,
                    tQgQ[None, m, None, block],
                    tQsQ[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
                    pred=predicate,
                )
        # We need to clear the sQ smem tiles since we'll use sQt for mma_dK
        # We made sure LSE length is padded so we read `kBlockM` elements so that all
        # elements in sLSE are filled. Without this we might have uninitialized sLSE values.
        for m in cutlass.range_constexpr(cute.size(tLSEsLSE.shape[1])):
            if tLSEcLSE[0, m][0] < self.m_block_size:
                cute.copy(
                    gmem_tiled_copy_LSE,
                    tLSEgLSE[None, m, block],
                    tLSEsLSE[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_Q > 1) else 0],
                )

    @cute.jit
    def load_dO_dPsum(
        self,
        gmem_tiled_copy_dO: cute.TiledCopy,
        gmem_tiled_copy_dPsum: cute.TiledCopy,
        tdOgdO: cute.Tensor,
        tdOsdO: cute.Tensor,
        tdOcdO: cute.Tensor,
        t0dOcdO: cute.Tensor,
        tdOpdO: cute.Tensor,
        tdPsumgdPsum: cute.Tensor,
        tdPsumsdPsum: cute.Tensor,
        tdPsumcdPsum: cute.Tensor,
        block: cutlass.Int32,
        smem_pipe_write_q: cutlass.Int32,
        seqlen: cutlass.Int32,
    ):
        for m in cutlass.range_constexpr(cute.size(tdOsdO.shape[1])):
            # If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
            if self.is_even_m_smem_do or m < cute.size(tdOsdO.shape[1]) - 1 or tdOcdO[0, m, 0][0] < self.m_block_size:
                # Instead of using tdOcdO, we use t0dOcdO and subtract the offset from the limit
                # (seqlen - block * kBlockM). This is because the entries of t0dOcdO are known at compile time.
                predicate_m = t0dOcdO[0, m, 0][0] < seqlen - block * self.m_block_size - tdOcdO[0][0]
                predicate = cute.make_fragment_like(tdOpdO[None, 0, None])
                for k in cutlass.range_constexpr(cute.size(predicate.shape[1])):
                    for i in cutlass.range_constexpr(cute.size(predicate.shape[0])):
                        # dO has head dim head_dim_v, so gate the k-predicate on check_hdim_v_oob
                        predicate[i, k] = (tdOpdO[i, m, k] if cutlass.const_expr(self.check_hdim_v_oob) else True) and predicate_m
                cute.copy(
                    gmem_tiled_copy_dO,
                    tdOgdO[None, m, None, block],
                    tdOsdO[None, m, None, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0],
                    pred=predicate,
                )
        # We need to clear the sdO smem tiles since we'll use sdOt for mma_dV
        # We made sure dPsum length is padded so we read `kBlockM` elements so that all
        # elements in sdPsum are filled. Without this we might have uninitialized sdPsum values.
        for m in cutlass.range_constexpr(cute.size(tdPsumgdPsum.shape[1])):
            if tdPsumcdPsum[0, m][0] < self.m_block_size:
                cute.copy(
                    gmem_tiled_copy_dPsum,
                    tdPsumgdPsum[None, m, block],
                    tdPsumsdPsum[None, m, smem_pipe_write_q if cutlass.const_expr(self.num_stages_dO > 1) else 0],
                )


================================================
FILE: flash_attn/cute/flash_bwd_postprocess.py
================================================
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_postprocess_kernel.h
# from Cutlass C++ to Cute-DSL.
import math
from typing import Callable, Optional, Type

import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute
import cutlass.utils.hopper_helpers as sm90_utils_basic
import cutlass.utils.blackwell_helpers as sm100_utils_basic
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
from cutlass import Float32, const_expr
from cutlass.utils import LayoutEnum

from quack import copy_utils
from quack import layout_utils
from quack import sm90_utils

from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import ampere_helpers as sm80_utils
from flash_attn.cute.seqlen_info import SeqlenInfoQK
import cutlass.cute.nvgpu.tcgen05 as tcgen05
from quack.cute_dsl_utils import ParamsBase
from flash_attn.cute.tile_scheduler import (
    SingleTileScheduler,
    SingleTileVarlenScheduler,
    TileSchedulerArguments,
)


class FlashAttentionBackwardPostprocess:
    def __init__(
        self,
        dtype: Type[cutlass.Numeric],
        head_dim: int,
        arch: int,
        tile_m: int = 128,
        num_threads: int = 256,
        AtomLayoutMdQ: int = 1,
        dQ_swapAB: bool = False,
        use_2cta_instrs: bool = False,
        cluster_size: int = 1,  # for varlen offsets
    ):
        """
        :param head_dim: head dimension
        :type head_dim: int
        :param tile_m: m block size
        :type tile_m: int
        """
        self.dtype = dtype
        self.tile_m = tile_m
assert arch // 10 in [8, 9, 10, 11, 12], ( "Only Ampere (8.x), Hopper (9.x), and Blackwell (10.x, 11.x, 12.x) are supported" ) self.arch = arch # padding head_dim to a multiple of 32 as k_block_size hdim_multiple_of = 32 self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) self.check_hdim_oob = head_dim != self.tile_hdim self.num_threads = num_threads self.AtomLayoutMdQ = AtomLayoutMdQ self.dQ_swapAB = dQ_swapAB self.use_2cta_instrs = use_2cta_instrs and arch // 10 == 10 and head_dim != 64 self.cluster_size = cluster_size @staticmethod def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: """Check if the kernel can be implemented with the given parameters. :param dtype: data type :type dtype: cutlass.Numeric :param head_dim: head dimension :type head_dim: int :param tile_m: m block size :type tile_m: int :return: True if the kernel can be implemented, False otherwise :rtype: bool """ if dtype not in [cutlass.Float16, cutlass.BFloat16]: return False if head_dim % 8 != 0: return False if num_threads % 32 != 0: return False return True def _get_tiled_mma(self): if const_expr(self.arch // 10 in [8, 12]): num_mma_warps = self.num_threads // 32 atom_layout_dQ = ( (self.AtomLayoutMdQ, num_mma_warps // self.AtomLayoutMdQ, 1) if const_expr(not self.dQ_swapAB) else (num_mma_warps // self.AtomLayoutMdQ, self.AtomLayoutMdQ, 1) ) tiled_mma = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), atom_layout_dQ, permutation_mnk=(atom_layout_dQ[0] * 16, atom_layout_dQ[1] * 16, 16), ) elif const_expr(self.arch // 10 == 9): num_wg_mma = self.num_threads // 128 atom_layout_dQ = (self.AtomLayoutMdQ, num_wg_mma // self.AtomLayoutMdQ) tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1]) tiled_mma = sm90_utils_basic.make_trivial_tiled_mma( self.dtype, self.dtype, warpgroup.OperandMajorMode.K, # These don't matter, we only care about the accum warpgroup.OperandMajorMode.K, Float32, 
atom_layout_mnk=(atom_layout_dQ if not self.dQ_swapAB else atom_layout_dQ[::-1]) + (1,), tiler_mn=tiler_mn_dQ if not self.dQ_swapAB else tiler_mn_dQ[::-1], ) else: cta_group = tcgen05.CtaGroup.ONE tiled_mma = sm100_utils_basic.make_trivial_tiled_mma( self.dtype, tcgen05.OperandMajorMode.MN, # dS_major_mode tcgen05.OperandMajorMode.MN, # Kt_major_mode Float32, cta_group, (self.tile_m, self.tile_hdim), ) if const_expr(self.arch // 10 in [8, 9, 12]): assert self.num_threads == tiled_mma.size return tiled_mma def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// # GMEM Tiled copy: # /////////////////////////////////////////////////////////////////////////////// # Thread layouts for copies universal_copy_bits = 128 async_copy_elems_accum = universal_copy_bits // Float32.width atom_async_copy_accum = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), Float32, num_bits_per_copy=universal_copy_bits, ) # We don't do bound checking for the gmem -> smem load so we just assert here. 
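# Worked example (hypothetical configuration, not taken from a specific launch):
# with tile_m = 128, tile_hdim = 128 and num_threads = 256, one 128-bit copy
# moves 128 // 32 = 4 Float32 values, so the dQaccum tile needs
# 128 * 128 // 4 = 4096 copies. Since 4096 % 256 == 0, each thread issues
# exactly 16 copies and the load needs no bounds predicate.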
assert (self.tile_m * self.tile_hdim // async_copy_elems_accum) % self.num_threads == 0 self.g2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv( atom_async_copy_accum, cute.make_layout(self.num_threads), cute.make_layout(async_copy_elems_accum), ) num_s2r_copy_elems = 1 if const_expr(self.arch // 10 in [8, 12]) else 4 if const_expr(self.arch // 10 in [8, 12]): self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( Float32, self.num_threads, num_s2r_copy_elems ) self.sdQaccum_layout = cute.make_layout(self.tile_m * self.tile_hdim) elif const_expr(self.arch // 10 == 9): num_threads_per_warp_group = 128 num_wg_mma = self.num_threads // 128 self.s2r_tiled_copy_dQaccum = cute.make_tiled_copy_tv( cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), cute.make_layout((num_threads_per_warp_group, num_wg_mma)), # thr_layout cute.make_layout(128 // Float32.width), # val_layout ) self.sdQaccum_layout = cute.make_layout( (self.tile_m * self.tile_hdim // num_wg_mma, num_wg_mma) ) else: self.dQ_reduce_ncol = 32 dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol assert self.num_threads == 128 # TODO: currently hard-coded self.s2r_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( Float32, self.num_threads, num_s2r_copy_elems ) self.sdQaccum_layout = cute.make_layout( (self.tile_m * self.tile_hdim // dQaccum_reduce_stage, dQaccum_reduce_stage) ) num_copy_elems = 128 // self.dtype.width threads_per_row = math.gcd(128, self.tile_hdim) // num_copy_elems self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d( self.dtype, threads_per_row, self.num_threads, num_copy_elems ) # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: dQ # /////////////////////////////////////////////////////////////////////////////// # We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs, # then setting kBlockKSmem to 32 will cause "Static shape_div failure". 
# We want to treat it as 64 x 48, so kBlockKSmem should be 16. mma_shape_n = self.tiled_mma.get_tile_size(1) if const_expr(self.arch // 10 in [8, 12]): sdQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, mma_shape_n) self.sdQ_layout = cute.tile_to_shape( sdQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1) ) elif const_expr(self.arch // 10 == 9): wg_d_dQ = num_wg_mma // self.AtomLayoutMdQ self.sdQ_layout = sm90_utils.make_smem_layout( self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim), major_mode_size=self.tile_hdim // wg_d_dQ, ) else: # TODO: this is hard-coded for hdim 128 self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi( self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_hdim), 1 ) @cute.jit def __call__( self, mdQaccum: cute.Tensor, mdQ: cute.Tensor, scale: cutlass.Float32, mCuSeqlensQ: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): # Get the data type and check if it is fp16 or bf16 if const_expr(mdQ.element_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") if const_expr(mdQaccum is not None): if const_expr(mdQaccum.element_type not in [cutlass.Float32]): raise TypeError("dQaccum tensor must be Float32") mdQaccum, mdQ = [assume_tensor_aligned(t) for t in (mdQaccum, mdQ)] self.tiled_mma = self._get_tiled_mma() self._setup_attributes() smem_size = max( cute.size_in_bytes(cutlass.Float32, self.sdQaccum_layout), cute.size_in_bytes(self.dtype, self.sdQ_layout), ) if const_expr(mCuSeqlensQ is not None): TileScheduler = SingleTileVarlenScheduler num_head = mdQ.shape[1] num_batch = mCuSeqlensQ.shape[0] - 1 num_block = cute.ceil_div(mdQ.shape[0], self.tile_m) else: TileScheduler = SingleTileScheduler num_head = mdQ.shape[2] num_batch = mdQ.shape[0] num_block = cute.ceil_div(mdQ.shape[1], self.tile_m) tile_sched_args = TileSchedulerArguments( 
num_block=num_block, num_head=num_head, num_batch=num_batch, num_splits=1, seqlen_k=0, headdim=mdQ.shape[2], headdim_v=0, total_q=mdQ.shape[0], tile_shape_mn=(self.tile_m, 1), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) # grid_dim: (m_block, num_head, batch_size) self.kernel( mdQaccum, mdQ, mCuSeqlensQ, mSeqUsedQ, scale, self.tiled_mma, self.dQ_swapAB, self.sdQaccum_layout, self.sdQ_layout, self.g2s_tiled_copy_dQaccum, self.s2r_tiled_copy_dQaccum, self.gmem_tiled_copy_dQ, tile_sched_params, TileScheduler, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], smem=smem_size, stream=stream, ) @cute.kernel def kernel( self, mdQaccum: cute.Tensor, mdQ: cute.Tensor, mCuSeqlensQ: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], scale: cutlass.Float32, tiled_mma: cute.TiledMma, dQ_swapAB: cutlass.Constexpr, sdQaccum_layout: cute.Layout, sdQ_layout: cute.ComposedLayout, g2s_tiled_copy_dQaccum: cute.TiledCopy, s2r_tiled_copy_dQaccum: cute.TiledCopy, gmem_tiled_copy_dQ: cute.TiledCopy, tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], ): # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer # /////////////////////////////////////////////////////////////////////////////// smem = cutlass.utils.SmemAllocator() sdQaccum = smem.allocate_tensor(cutlass.Float32, sdQaccum_layout, byte_alignment=1024) sdQaccum_flat = cute.make_tensor(sdQaccum.iterator, cute.make_layout(cute.size(sdQaccum))) if const_expr(self.arch // 10 in [8, 9, 12]): sdQ = cute.make_tensor(cute.recast_ptr(sdQaccum.iterator, dtype=self.dtype), sdQ_layout) else: # extra stage dimension sdQ = cute.make_tensor( cute.recast_ptr(sdQaccum.iterator, sdQ_layout.inner, dtype=self.dtype), sdQ_layout.outer, )[None, None, 0] sdQt = layout_utils.transpose_view(sdQ) # Thread index, block index tidx, _, _ 
= cute.arch.thread_idx() tile_scheduler = TileScheduler.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() m_block, head_idx, batch_idx, _ = work_tile.tile_idx if work_tile.is_valid_tile: # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. # /////////////////////////////////////////////////////////////////////////////// seqlen = SeqlenInfoQK.create( batch_idx, mdQ.shape[1], 0, mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=None, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=None, tile_m=self.tile_m * self.cluster_size, ) if const_expr(not seqlen.has_cu_seqlens_q): mdQ_cur = mdQ[batch_idx, None, head_idx, None] mdQaccum_cur = mdQaccum[batch_idx, head_idx, None] head_dim = mdQ.shape[3] else: padded_offset_q = seqlen.padded_offset_q mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None]) mdQaccum_cur = cute.domain_offset( (padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None] ) head_dim = mdQ.shape[2] # HACK: Compiler doesn't seem to recognize that padding # by padded_offset_q * self.tile_hdim keeps alignment # since statically divisible by 4 mdQaccum_cur_ptr = cute.make_ptr( dtype=mdQaccum_cur.element_type, value=mdQaccum_cur.iterator.toint(), mem_space=mdQaccum_cur.iterator.memspace, assumed_align=mdQaccum.iterator.alignment, ) mdQaccum_cur = cute.make_tensor(mdQaccum_cur_ptr, mdQaccum_cur.layout) gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,)) gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) seqlen_q = seqlen.seqlen_q seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) if const_expr(self.arch // 10 == 10 and self.use_2cta_instrs): # 2-CTA: remap dQaccum layout into TMEM view before writing sdQ num_reduce_threads = self.num_threads thr_mma_dsk = tiled_mma.get_slice(tidx) dQacc_shape = thr_mma_dsk.partition_shape_C((self.tile_m, self.tile_hdim)) tdQtdQ = 
thr_mma_dsk.make_fragment_C(dQacc_shape) tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout) tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32 ) tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ) thr_tmem_ld = tiled_tmem_ld.get_slice(tidx) cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) tdQcdQ = thr_mma_dsk.partition_C(cdQ) tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout) tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor) tiled_copy_accum = s2r_tiled_copy_dQaccum g2s_thr_copy = tiled_copy_accum.get_slice(tidx) # S -> R tdQrdQ_fp32 = cute.make_fragment(tdQrdQ.shape, cutlass.Float32) tdQrdQ_s2r = cute.make_tensor(tdQrdQ_fp32.iterator, tdQrdQ_fp32.shape) smem_copy_atom = sm100_utils_basic.get_smem_store_op( LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld ) r2s_tiled_copy = cute.make_tiled_copy( smem_copy_atom, layout_tv=tiled_tmem_ld.layout_dst_tv_tiled, tiler_mn=tiled_tmem_ld.tiler_mn, ) tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ)) tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype) num_stages = cute.size(tdQrdQ_fp32, mode=[1]) stage_stride = self.dQ_reduce_ncol row_groups = 2 assert num_stages % row_groups == 0 assert num_reduce_threads % row_groups == 0 stage_groups = num_stages // row_groups threads_per_row_group = num_reduce_threads // row_groups stage_loads = tuple((row_group, row_group) for row_group in range(row_groups)) stage_iters = tuple( (row_group, row_group * threads_per_row_group) for row_group in range(row_groups) ) s2r_lane = tidx % threads_per_row_group s2r_buf = tidx // threads_per_row_group gdQaccum_layout_g2s = cute.make_layout( shape=(self.tile_m * self.dQ_reduce_ncol, 1), stride=(1, 0) ) sdQaccum_g2s = g2s_thr_copy.partition_D(sdQaccum) # G -> S for stage_group in cutlass.range_constexpr(stage_groups): for stage_offset, smem_buf in stage_loads: stage_idx = stage_group + stage_offset * 
stage_groups gdQaccum_stage = cute.local_tile( gdQaccum, (self.tile_m * self.dQ_reduce_ncol,), (stage_idx,), ) gdQaccum_stage_g2s = cute.make_tensor( gdQaccum_stage.iterator, gdQaccum_layout_g2s, ) tdQgdQ = g2s_thr_copy.partition_S(gdQaccum_stage_g2s) cute.copy( g2s_thr_copy, tdQgdQ[None, None, 0], sdQaccum_g2s[None, None, smem_buf], ) cute.arch.fence_view_async_shared() cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads) # S -> R for stage_offset, lane_offset in stage_iters: stage_idx = stage_group + stage_offset * stage_groups s2r_src_tidx = s2r_lane + lane_offset s2r_thr_copy = tiled_copy_accum.get_slice(s2r_src_tidx) sdQaccum_src = s2r_thr_copy.partition_S(sdQaccum)[None, None, s2r_buf] tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage_idx, None, None] tdQrdQ_r2s_cpy = cute.make_tensor( tdQrdQ_s2r_cpy.iterator, cute.make_layout(sdQaccum_src.shape) ) cute.copy(s2r_thr_copy, sdQaccum_src, tdQrdQ_r2s_cpy) cute.arch.fence_view_async_shared() cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads) # R -> S stage_lo = stage_idx % stage_stride stage_hi = stage_idx // stage_stride tdQrdQ_r2s_cpy = cute.make_tensor( cute.recast_ptr(tdQrdQ_r2s_cpy.iterator), tdQrdQ_r2s[((None, 0), (stage_lo, stage_hi), 0, 0)].shape, ) dQ_vec = tdQrdQ_r2s_cpy.load() * scale tdQrdQ_r2s[((None, 0), (stage_lo, stage_hi), 0, 0)].store( dQ_vec.to(self.dtype) ) # R -> S cute.copy( r2s_tiled_copy, tdQrdQ_r2s[None, None, None, 0], tdQsdQ_r2s[None, None, None, 0], ) cute.arch.fence_view_async_shared() cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads) else: # Step 1: load dQaccum from gmem to smem g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_slice(tidx) tdQgdQaccum = g2s_thr_copy_dQaccum.partition_S(gdQaccum) tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum_flat) cute.copy(g2s_tiled_copy_dQaccum, tdQgdQaccum, tdQsdQaccumg2s) cute.arch.cp_async_commit_group() cute.arch.cp_async_wait_group(0) cute.arch.barrier() # Step 2: load dQ from smem to 
rmem s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_slice(tidx) tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum) tile_shape = (self.tile_m, self.tile_hdim) acc = None tiled_copy_t2r = None if const_expr(self.arch // 10 in [8, 9, 12]): acc_shape = tiled_mma.partition_shape_C( tile_shape if const_expr(not dQ_swapAB) else tile_shape[::-1] ) acc = cute.make_fragment(acc_shape, cutlass.Float32) assert cute.size(acc) == cute.size(tdQsdQaccum) else: thr_mma = tiled_mma.get_slice(0) # 1-CTA dQacc_shape = tiled_mma.partition_shape_C((self.tile_m, self.tile_hdim)) tdQtdQ = tiled_mma.make_fragment_C(dQacc_shape) tdQcdQ = thr_mma.partition_C( cute.make_identity_tensor((self.tile_m, self.tile_hdim)) ) tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32, ) tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ) thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape acc = cute.make_fragment(tdQrdQ_t2r_shape, Float32) tdQrdQaccum = cute.make_tensor(acc.iterator, cute.make_layout(tdQsdQaccum.shape)) cute.autovec_copy(tdQsdQaccum, tdQrdQaccum) # Convert tdQrdQaccum from fp32 to fp16/bf16 rdQ = cute.make_fragment_like(acc, self.dtype) rdQ.store((acc.load() * scale).to(self.dtype)) # Step 3: Copy dQ from register to smem cute.arch.barrier() # make sure all threads have finished loading dQaccum if const_expr(self.arch // 10 in [8, 9, 12]): copy_atom_r2s_dQ = utils.get_smem_store_atom( self.arch, self.dtype, transpose=self.dQ_swapAB ) tiled_copy_r2s_dQ = cute.make_tiled_copy_C(copy_atom_r2s_dQ, tiled_mma) else: # copy_atom_r2s_dQ = sm100_utils_basic.get_smem_store_op( # LayoutEnum.ROW_MAJOR, self.dtype, Float32, tiled_copy_t2r, # ) # tiled_copy_r2s_dQ = cute.make_tiled_copy_D(copy_atom_r2s_dQ, tiled_copy_t2r) thr_layout_r2s_dQ = cute.make_layout((self.num_threads, 1)) # 128 threads val_layout_r2s_dQ = cute.make_layout((1, 128 // self.dtype.width)) 
copy_atom_r2s_dQ = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=128, ) tiled_copy_r2s_dQ = cute.make_tiled_copy_tv( copy_atom_r2s_dQ, thr_layout_r2s_dQ, val_layout_r2s_dQ ) thr_copy_r2s_dQ = tiled_copy_r2s_dQ.get_slice(tidx) cdQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) if const_expr(self.arch // 10 in [8, 9, 12]): taccdQrdQ = thr_copy_r2s_dQ.retile(rdQ) else: taccdQcdQ_shape = thr_copy_r2s_dQ.partition_S(cdQ).shape taccdQrdQ = cute.make_tensor(rdQ.iterator, taccdQcdQ_shape) taccdQsdQ = thr_copy_r2s_dQ.partition_D( sdQ if const_expr(not self.dQ_swapAB) else sdQt ) cute.copy(thr_copy_r2s_dQ, taccdQrdQ, taccdQsdQ) # Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem cute.arch.barrier() # make sure all smem stores are done gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_slice(tidx) tdQgdQ = gmem_thr_copy_dQ.partition_S(gdQ) tdQsdQ = gmem_thr_copy_dQ.partition_D(sdQ) tdQrdQ = cute.make_fragment_like(tdQsdQ, self.dtype) # TODO: check OOB when reading from smem if kBlockM isn't evenly tiled cute.autovec_copy(tdQsdQ, tdQrdQ) # Step 5: Copy dQ from register to gmem tdQcdQ = gmem_thr_copy_dQ.partition_S(cdQ) tdQpdQ = utils.predicate_k(tdQcdQ, limit=head_dim) for rest_m in cutlass.range(cute.size(tdQrdQ.shape[1]), unroll_full=True): if tdQcdQ[0, rest_m, 0][0] < seqlen_q - m_block * self.tile_m: cute.copy( gmem_tiled_copy_dQ, tdQrdQ[None, rest_m, None], tdQgdQ[None, rest_m, None], pred=tdQpdQ[None, rest_m, None], ) ================================================ FILE: flash_attn/cute/flash_bwd_preprocess.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_bwd_preprocess_kernel.h # from Cutlass C++ to Cute-DSL. 
# # Computes D_i = (dO_i * O_i).sum(dim=-1), optionally adjusted for LSE gradient: # D'_i = D_i - dLSE_i # This works because in the backward pass: # dS_ij = P_ij * (dP_ij - D_i) [standard] # When LSE is differentiable, d(loss)/d(S_ij) gets an extra term dLSE_i * P_ij # (since d(LSE_i)/d(S_ij) = P_ij), giving: # dS_ij = P_ij * (dP_ij - D_i) + dLSE_i * P_ij # = P_ij * (dP_ij - (D_i - dLSE_i)) # So the main backward kernel is unchanged; we just replace D with D' = D - dLSE here. import math import operator from functools import partial from typing import Callable, Type, Optional import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute from cutlass import Float32, const_expr from cutlass.cutlass_dsl import Arch, BaseDSL from quack import copy_utils, layout_utils from flash_attn.cute import utils from flash_attn.cute.seqlen_info import SeqlenInfo from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, ) class FlashAttentionBackwardPreprocess: def __init__( self, dtype: Type[cutlass.Numeric], head_dim: int, head_dim_v: int, tile_m: int = 128, num_threads: int = 256, ): """ All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension should be a multiple of 8. 
:param head_dim: head dimension :type head_dim: int :param tile_m: m block size :type tile_m: int :param num_threads: number of threads :type num_threads: int """ self.use_pdl = BaseDSL._get_dsl().get_arch_enum() >= Arch.sm_90a self.dtype = dtype self.tile_m = tile_m # padding head_dim to a multiple of 32 as k_block_size hdim_multiple_of = 32 self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.num_threads = num_threads @staticmethod def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: """Check if the kernel can be implemented with the given parameters. :param dtype: data type :type dtype: cutlass.Numeric :param head_dim: head dimension :type head_dim: int :param tile_m: m block size :type tile_m: int :param num_threads: number of threads :type num_threads: int :return: True if the kernel can be implemented, False otherwise :rtype: bool """ if dtype not in [cutlass.Float16, cutlass.BFloat16]: return False if head_dim % 8 != 0: return False if num_threads % 32 != 0: return False if num_threads < tile_m: # For multiplying lse with log2 return False return True def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// # GMEM Tiled copy: # /////////////////////////////////////////////////////////////////////////////// # Thread layouts for copies # We want kBlockKGmem to be a power of 2 so that when we do the summing, # it's just between threads in the same warp gmem_k_block_size = ( 128 if self.head_dim_v_padded % 128 == 0 else ( 64 if self.head_dim_v_padded % 64 == 0 else (32 if self.head_dim_v_padded % 32 == 0 else 16) ) ) num_copy_elems = 128 // self.dtype.width threads_per_row = gmem_k_block_size // num_copy_elems self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d( self.dtype, threads_per_row, self.num_threads, 
num_copy_elems ) universal_copy_bits = 128 num_copy_elems_dQaccum = universal_copy_bits // Float32.width assert ( self.tile_m * self.head_dim_padded // num_copy_elems_dQaccum ) % self.num_threads == 0 self.gmem_tiled_copy_dQaccum = copy_utils.tiled_copy_1d( Float32, self.num_threads, num_copy_elems_dQaccum ) @cute.jit def __call__( self, mO: cute.Tensor, # (batch, seqlen, nheads, head_dim_v) or (total_q, nheads, head_dim_v) mdO: cute.Tensor, # same shape as mO mPdPsum: cute.Tensor, # (batch, nheads, seqlen_padded) or (nheads, total_q_padded) mLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q) mLSElog2: Optional[cute.Tensor], # same shape as mPdPsum # (batch, nheads, seqlen_padded * head_dim_v) or (nheads, total_q_padded * head_dim_v) mdQaccum: Optional[cute.Tensor], mCuSeqlensQ: Optional[cute.Tensor], # (batch + 1,) mSeqUsedQ: Optional[cute.Tensor], # (batch,) mdLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q) # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). 
stream: cuda.CUstream = None, ): # Get the data type and check if it is fp16 or bf16 if const_expr(not (mO.element_type == mdO.element_type)): raise TypeError("All tensors must have the same data type") if const_expr(mO.element_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") if const_expr(mPdPsum.element_type not in [Float32]): raise TypeError("PdPsum tensor must be Float32") if const_expr(mdQaccum is not None): if const_expr(mdQaccum.element_type not in [Float32]): raise TypeError("dQaccum tensor must be Float32") if const_expr(mLSE is not None): assert mLSElog2 is not None, "If mLSE is provided, mLSElog2 must also be provided" if const_expr(mLSE.element_type not in [Float32]): raise TypeError("LSE tensor must be Float32") if const_expr(mLSElog2.element_type not in [Float32]): raise TypeError("LSElog2 tensor must be Float32") if const_expr(mdLSE is not None): if const_expr(mdLSE.element_type not in [Float32]): raise TypeError("dLSE tensor must be Float32") self._setup_attributes() # (batch, nheads, seqlen) -> (seqlen, nheads, batch) or (total_q, nheads) -> (nheads, total_q) transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] mPdPsum = layout_utils.select(mPdPsum, transpose) if const_expr(mLSE is not None): mLSE = layout_utils.select(mLSE, transpose) mLSElog2 = layout_utils.select(mLSElog2, transpose) if const_expr(mdLSE is not None): mdLSE = layout_utils.select(mdLSE, transpose) if const_expr(mdQaccum is not None): mdQaccum = layout_utils.select(mdQaccum, transpose) if const_expr(mCuSeqlensQ is not None): TileScheduler = SingleTileVarlenScheduler num_head = mO.shape[1] num_batch = mCuSeqlensQ.shape[0] - 1 else: TileScheduler = SingleTileScheduler num_head = mO.shape[2] num_batch = mO.shape[0] tile_sched_args = TileSchedulerArguments( num_block=cute.ceil_div(mO.shape[1], self.tile_m), num_head=num_head, num_batch=num_batch, num_splits=1, seqlen_k=0, headdim=0, headdim_v=mO.shape[2], 
total_q=mO.shape[0], tile_shape_mn=(self.tile_m, 1), mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) self.kernel( mO, mdO, mPdPsum, mLSE, mLSElog2, mdQaccum, mCuSeqlensQ, mSeqUsedQ, mdLSE, self.gmem_tiled_copy_O, self.gmem_tiled_copy_dQaccum, tile_sched_params, TileScheduler, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], stream=stream, use_pdl=self.use_pdl, ) @cute.kernel def kernel( self, mO: cute.Tensor, mdO: cute.Tensor, mPdPsum: cute.Tensor, mLSE: Optional[cute.Tensor], mLSElog2: Optional[cute.Tensor], mdQaccum: Optional[cute.Tensor], mCuSeqlensQ: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], mdLSE: Optional[cute.Tensor], gmem_tiled_copy_O: cute.TiledCopy, gmem_tiled_copy_dQaccum: cute.TiledCopy, tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() tile_scheduler = TileScheduler.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() m_block, head_idx, batch_idx, _ = work_tile.tile_idx if work_tile.is_valid_tile: # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. 
# /////////////////////////////////////////////////////////////////////////////// seqlen = SeqlenInfo.create( batch_idx, mO.shape[1], mCuSeqlensQ, mSeqUsedQ, tile=self.tile_m ) mO_cur = seqlen.offset_batch(mO, batch_idx, dim=0)[None, head_idx, None] mdO_cur = seqlen.offset_batch(mdO, batch_idx, dim=0)[None, head_idx, None] mPdPsum_cur = seqlen.offset_batch(mPdPsum, batch_idx, dim=2, padded=True)[ None, head_idx ] headdim_v = mO_cur.shape[cute.rank(mO_cur) - 1] seqlen_q = seqlen.seqlen seqlen_q_rounded = cute.round_up(seqlen_q, self.tile_m) seqlen_limit = seqlen_q - m_block * self.tile_m lse = None if const_expr(mLSE is not None): mLSE_cur = seqlen.offset_batch(mLSE, batch_idx, dim=2)[None, head_idx] gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) lse = Float32.inf if tidx < seqlen_limit: lse = gLSE[tidx] blk_shape = (self.tile_m, self.head_dim_v_padded) gO = cute.local_tile(mO_cur, blk_shape, (m_block, 0)) gdO = cute.local_tile(mdO_cur, blk_shape, (m_block, 0)) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) # (CPY_Atom, CPY_M, CPY_K) tOgO = gmem_thr_copy_O.partition_S(gO) tOgdO = gmem_thr_copy_O.partition_S(gdO) cO = cute.make_identity_tensor(blk_shape) tOcO = gmem_thr_copy_O.partition_S(cO) t0OcO = gmem_thr_copy_O.get_slice(0).partition_S(cO) tOpO = None if const_expr(self.check_hdim_v_oob): tOpO = copy_utils.predicate_k(tOcO, limit=headdim_v) # Each copy will use the same predicate copy = partial(copy_utils.copy, pred=tOpO) tOrO = cute.make_rmem_tensor_like(tOgO) tOrdO = cute.make_rmem_tensor_like(tOgdO) if const_expr(self.check_hdim_v_oob): tOrO.fill(0.0) tOrdO.fill(0.0) assert tOgO.shape == tOgdO.shape for m in cutlass.range(cute.size(tOrO.shape[1]), unroll_full=True): # Instead of using tOcO, we use t0OcO and subtract the offset from the limit. # This is because the entries of t0OcO are known at compile time.
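# Worked example (hypothetical values): tile_m = 128, m_block = 1 and
# seqlen_q = 200 give seqlen_limit = 200 - 128 = 72. If this thread's first row
# is tOcO[0][0] = 8, its rows are the compile-time rows of thread 0 shifted by 8,
# so testing t0OcO[0, m, 0][0] < 72 - 8 selects the same rows as
# tOcO[0, m, 0][0] < 72 while keeping the left-hand side a constant.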
if t0OcO[0, m, 0][0] < seqlen_limit - tOcO[0][0]: copy(tOgO[None, m, None], tOrO[None, m, None]) copy(tOgdO[None, m, None], tOrdO[None, m, None]) # O and dO loads are done; signal that the next kernel can start. # Correctness is ensured by griddepcontrol_wait() in bwd_sm90 before it reads our outputs. if const_expr(self.use_pdl): cute.arch.griddepcontrol_launch_dependents() # Sum across the "k" dimension pdpsum = (tOrO.load().to(Float32) * tOrdO.load().to(Float32)).reduce( cute.ReductionOp.ADD, init_val=0.0, reduction_profile=(0, None, 1) ) threads_per_row = gmem_tiled_copy_O.layout_src_tv_tiled[0].shape[0] assert cute.arch.WARP_SIZE % threads_per_row == 0 pdpsum = utils.warp_reduce(pdpsum, operator.add, width=threads_per_row) PdP_sum = cute.make_rmem_tensor(cute.size(tOrO, mode=[1]), Float32) PdP_sum.store(pdpsum) # If dLSE is provided, compute D' = D - dLSE (see module docstring for derivation). gdLSE = None if const_expr(mdLSE is not None): mdLSE_cur = seqlen.offset_batch(mdLSE, batch_idx, dim=2)[None, head_idx] gdLSE = cute.local_tile(mdLSE_cur, (self.tile_m,), (m_block,)) # Write PdPsum from rmem -> gmem gPdPsum = cute.local_tile(mPdPsum_cur, (self.tile_m,), (m_block,)) # Only the thread corresponding to column 0 writes out the PdPsum to gmem if tOcO[0, 0, 0][1] == 0: for m in cutlass.range(cute.size(PdP_sum), unroll_full=True): row = tOcO[0, m, 0][0] PdPsum_val = 0.0 if row < seqlen_limit: PdPsum_val = PdP_sum[m] if const_expr(mdLSE is not None): PdPsum_val -= gdLSE[row] gPdPsum[row] = PdPsum_val # Clear dQaccum if const_expr(mdQaccum is not None): mdQaccum_cur = seqlen.offset_batch( mdQaccum, batch_idx, dim=2, padded=True, multiple=self.head_dim_padded )[None, head_idx] blkdQaccum_shape = (self.tile_m * self.head_dim_padded,) gdQaccum = cute.local_tile(mdQaccum_cur, blkdQaccum_shape, (m_block,)) gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_slice(tidx) tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum) zero = 
cute.make_rmem_tensor_like(tdQgdQaccum) zero.fill(0.0) cute.copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum) if const_expr(mLSE is not None): mLSElog2_cur = seqlen.offset_batch(mLSElog2, batch_idx, dim=2, padded=True)[ None, head_idx ] gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,)) LOG2_E = math.log2(math.e) if tidx < seqlen_q_rounded - m_block * self.tile_m: gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 ================================================ FILE: flash_attn/cute/flash_bwd_sm100.py ================================================ # Copyright (c) 2025, Ted Zadouri, Markus Hoehnerbach, Jay Shah, Tri Dao. import math from typing import Callable, Optional from functools import partial import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute from cutlass.cute import FastDivmodDivisor from cutlass import Float32, Int32, Int64, const_expr from cutlass.utils import LayoutEnum from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic from cutlass.pipeline import PipelineAsync import quack.activation from quack import layout_utils from flash_attn.cute import utils from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import copy_utils from flash_attn.cute import pipeline from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa SingleTileVarlenScheduler, ) from flash_attn.cute import barrier from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner from flash_attn.cute.block_sparsity import 
BlockSparseTensors from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, get_block_sparse_iteration_info_bwd, get_m_block_from_iter_bwd, produce_block_sparse_q_loads_bwd_sm100, ) class FlashAttentionBackwardSm100: arch = 100 def __init__( self, head_dim: int, head_dim_v: Optional[int] = None, is_causal: bool = False, is_local: bool = False, qhead_per_kvhead: cutlass.Constexpr[int] = 1, tile_m: int = 128, tile_n: int = 128, is_persistent: bool = False, deterministic: bool = False, cluster_size: int = 1, use_2cta_instrs: bool = False, score_mod: cutlass.Constexpr | None = None, score_mod_bwd: cutlass.Constexpr | None = None, mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, subtile_factor: cutlass.Constexpr[int] = 1, ): # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) self.check_hdim_oob = head_dim != self.tile_hdim self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.tile_m = tile_m self.tile_n = tile_n assert self.tile_hdim <= 128 or (self.tile_hdim == 192 and self.tile_hdimv == 128) assert self.tile_hdimv <= 128 self.use_2cta_instrs = bool( use_2cta_instrs and cluster_size == 2 and score_mod is None and score_mod_bwd is None and mask_mod is None ) self.cta_group_size = 2 if self.use_2cta_instrs else 1 assert self.tile_hdim != 192 or self.use_2cta_instrs, "Must use 2CTA for hdim 192" # CTA tiler self.cta_tiler = (tile_n, tile_m, self.tile_hdim) # S = K @ Q.T self.mma_tiler_kq = (self.cta_group_size * tile_n, tile_m, self.tile_hdim) # dP = V @ dO.T self.mma_tiler_vdo = (self.cta_group_size * tile_n, tile_m, self.tile_hdimv) # dV = P.T @ dO self.mma_tiler_pdo = (self.cta_group_size * tile_n, self.tile_hdimv, 
tile_m) # dK = dS.T @ Q self.mma_tiler_dsq = (self.cta_group_size * tile_n, self.tile_hdim, tile_m) # dQ = dS @ K # 2-CTA: reduction dim is cluster-wide (tile_n * cta_group_size). self.mma_tiler_dsk = (tile_m, self.tile_hdim, tile_n * self.cta_group_size) self.acc_dtype = Float32 assert cluster_size in (1, 2), "Only cluster_size=1 or 2 is supported" self.cluster_shape_mn = (cluster_size, 1) self.is_persistent = is_persistent self.is_causal = is_causal self.is_local = is_local self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False self.deterministic = deterministic # Score mod and mask mod support self.score_mod = score_mod self.score_mod_bwd = score_mod_bwd self.mask_mod = mask_mod self.has_aux_tensors = has_aux_tensors self.subtile_factor = subtile_factor # For score_mod, use vec_size=1 (like forward) to handle per-element indices if cutlass.const_expr(has_aux_tensors): self.vec_size: cutlass.Constexpr = 1 else: self.vec_size: cutlass.Constexpr = 4 self.qk_acc_dtype = Float32 # Speed optimizations; these do not affect correctness self.shuffle_LSE = False self.shuffle_dPsum = False # Reading dS from smem (instead of TMEM) for the dK MMA is generally slower, and doesn't work for 2cta self.use_smem_dS_for_mma_dK = False self.reduce_warp_ids = (0, 1, 2, 3) self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11) self.mma_warp_id = 12 self.load_warp_id = 13 self.relay_warp_id = 14 self.empty_warp_id = 15 # 16 warps -> 512 threads self.threads_per_cta = cute.arch.WARP_SIZE * len( ( *self.reduce_warp_ids, *self.compute_warp_ids, self.mma_warp_id, self.load_warp_id, self.relay_warp_id, self.empty_warp_id, ) ) # NamedBarrier self.compute_sync_barrier = cutlass.pipeline.NamedBarrier( barrier_id=int(NamedBarrierBwdSm100.Compute), num_threads=len(self.compute_warp_ids) * cute.arch.WARP_SIZE, ) # self.epilogue_sync_barrier = pipeline.NamedBarrier( # barrier_id=2, # num_threads=self.num_compute_warps * self.threads_per_warp, # ) self.reduce_sync_barrier = cutlass.pipeline.NamedBarrier(
barrier_id=int(NamedBarrierBwdSm100.dQaccReduce), num_threads=len(self.reduce_warp_ids) * cute.arch.WARP_SIZE, ) # TMEM setup self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") # self.tmem_dK_offset = 0 # self.tmem_dV_offset = self.tmem_dK_offset + self.tile_hdim # self.tmem_dQ_offset = self.tmem_dV_offset + self.tile_hdimv # self.tmem_dP_offset = self.tmem_dQ_offset # overlap with dQ # self.tmem_S_offset = self.tmem_dQ_offset + max(self.tile_m, self.tile_hdim) # self.tmem_P_offset = self.tmem_S_offset # overlap with S # self.tmem_total = self.tmem_S_offset + self.tile_n # assert self.tmem_total <= self.tmem_alloc_cols if self.use_2cta_instrs and self.tile_hdim == 192 and self.tile_hdimv == 128: assert self.tile_m == 128 assert self.tile_n == 128 self.tmem_dV_offset = 0 self.tmem_dK_offset = self.tmem_dV_offset + self.tile_hdimv self.tmem_S_offset = self.tmem_dK_offset + self.tile_hdim self.tmem_P_offset = self.tmem_S_offset # overlap with S self.tmem_dP_offset = 512 - self.tile_m self.tmem_dS_offset = self.tmem_dP_offset # overlaps with dP self.tmem_dQ_offset = 512 - self.tile_hdim // 2 else: self.tmem_S_offset = 0 self.tmem_P_offset = 0 # overlap with S self.tmem_dV_offset = self.tmem_S_offset + self.tile_n self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv self.tmem_dQ_offset = ( (self.tmem_S_offset + (self.tile_hdim // 2)) if self.use_2cta_instrs else self.tmem_dP_offset ) self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP if (not is_causal and not is_local) or deterministic: self.num_regs_reduce = 136 if self.use_2cta_instrs else 152 self.num_regs_compute = 136 self.num_regs_load = 104 if self.use_2cta_instrs else 96 - 8 self.num_regs_mma = 104 if self.use_2cta_instrs else self.num_regs_load else: self.num_regs_reduce = 136 if self.use_2cta_instrs else 136 self.num_regs_compute = 136 if self.use_2cta_instrs else 144 self.num_regs_load = 104 if self.use_2cta_instrs else 96 
- 8 self.num_regs_mma = 104 if self.use_2cta_instrs else self.num_regs_load self.num_regs_empty = 24 if const_expr(self.tile_hdim == 192): if not is_causal and not is_local: self.num_regs_reduce = 128 + 8 self.num_regs_compute = 128 + 8 self.num_regs_load = 128 - 24 self.num_regs_mma = self.num_regs_load else: self.num_regs_reduce = 128 + 8 self.num_regs_compute = 128 + 8 self.num_regs_load = 128 - 24 self.num_regs_mma = self.num_regs_load assert ( self.num_regs_reduce + self.num_regs_compute * 2 + max(self.num_regs_load, self.num_regs_mma) <= 512 ) self.buffer_align_bytes = 1024 def _setup_attributes(self): self.Q_stage = 1 if self.use_2cta_instrs else 2 self.dO_stage = 1 self.single_stage = 1 # LSE_stage = Q_stage and dPsum_stage = dO_stage self.sdKVaccum_stage = 2 # number of tma reduce adds per dQacc mma # todo: try 32/1 or 48/2 for 2cta d=192 dv=128 if self.use_2cta_instrs and self.tile_hdim == 192: self.dQ_reduce_ncol_t2r = 32 self.dQ_reduce_ncol = 24 if not self.is_causal else 32 self.sdQaccum_stage = 2 if not self.is_causal else 1 else: if self.use_2cta_instrs: self.dQ_reduce_ncol = 16 if self.deterministic else 8 self.sdQaccum_stage = 2 if self.deterministic else 4 self.dQ_reduce_ncol_t2r = 32 else: self.dQ_reduce_ncol = 32 self.sdQaccum_stage = 64 // self.dQ_reduce_ncol self.dQ_reduce_ncol_t2r = self.dQ_reduce_ncol assert (self.tile_hdim // self.cta_group_size) % self.dQ_reduce_ncol == 0 self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol self.dQaccum_reduce_stage_t2r = self.tile_hdim // self.dQ_reduce_ncol_t2r self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1 # number of tma reduce adds for dKacc and dVacc epilogue (must divide hdim_per_wg) self.dK_reduce_ncol = math.gcd(32, self.tile_hdim // 2) # CTA group for MMA operations self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE def _get_tiled_mma(self): # S.T = K @ Q.T tiled_mma_S = sm100_utils_basic.make_trivial_tiled_mma( 
self.q_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, self.acc_dtype, self.cta_group, self.mma_tiler_kq[:2], ) # dP.T = V @ dO.T tiled_mma_dP = sm100_utils_basic.make_trivial_tiled_mma( self.do_dtype, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, self.acc_dtype, self.cta_group, self.mma_tiler_vdo[:2], ) # dV += P.T @ dO --> (K, MN) major tiled_mma_dV = sm100_utils_basic.make_trivial_tiled_mma( self.do_dtype, tcgen05.OperandMajorMode.K, # P_major_mode tcgen05.OperandMajorMode.MN, # dO_major_mode self.acc_dtype, self.cta_group, self.mma_tiler_pdo[:2], a_source=tcgen05.OperandSource.TMEM, ) # dK += dS.T @ Q if const_expr(self.use_smem_dS_for_mma_dK): mma_dK_a_src = tcgen05.OperandSource.SMEM else: mma_dK_a_src = tcgen05.OperandSource.TMEM tiled_mma_dK = sm100_utils_basic.make_trivial_tiled_mma( self.do_dtype, tcgen05.OperandMajorMode.K, # dS_major_mode tcgen05.OperandMajorMode.MN, # Q_major_mode self.acc_dtype, self.cta_group, self.mma_tiler_dsq[:2], a_source=mma_dK_a_src, ) # dQ = dS @ K tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma( self.k_dtype, tcgen05.OperandMajorMode.MN, # dS_major_mode tcgen05.OperandMajorMode.MN, # Kt_major_mode self.acc_dtype, self.cta_group, self.mma_tiler_dsk[:2], ) return tiled_mma_S, tiled_mma_dP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ def _setup_smem_layout(self): # S.T = K @ Q.T sK_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_S, self.mma_tiler_kq, self.k_dtype, 1, ) self.sK_layout = cute.slice_(sK_layout, (None, None, None, 0)) self.sQ_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_S, self.mma_tiler_kq, self.q_dtype, self.Q_stage, ) # dP.T = V @ dO.T sV_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dP, self.mma_tiler_vdo, self.v_dtype, 1, ) self.sV_layout = cute.slice_(sV_layout, (None, None, None, 0)) self.sdOt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dP, self.mma_tiler_vdo, self.do_dtype, self.dO_stage, ) # dV += P.T @ dO 
tP_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dV, self.mma_tiler_pdo, self.do_dtype, 1, ) self.tP_layout = cute.slice_(tP_layout, (None, None, None, 0)) self.sdO_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dV, self.mma_tiler_pdo, self.do_dtype, self.dO_stage, ) # dK += dS.T @ Q sdSt_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dK, self.mma_tiler_dsq, self.ds_dtype, 1, ) self.sdSt_layout = cute.slice_(sdSt_layout, (None, None, None, 0)) tdS_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dK, self.mma_tiler_dsq, self.ds_dtype, 1, ) self.tdS_layout = cute.slice_(tdS_layout, (None, None, None, 0)) self.sQt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dK, self.mma_tiler_dsq, self.q_dtype, self.Q_stage, ) # dQ = dS @ K sdS_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dQ, self.mma_tiler_dsk, self.ds_dtype, 1, ) self.sdS_layout = cute.slice_(sdS_layout, (None, None, None, 0)) sKt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dQ, self.mma_tiler_dsk, self.k_dtype, 1, ) self.sKt_layout = cute.slice_(sKt_layout, (None, None, None, 0)) self.sdS_xchg_layout = cute.make_layout(shape=(self.tile_n, self.tile_m // 2)) self.sdQaccum_layout = cute.make_layout( (self.tile_m * self.dQ_reduce_ncol, self.sdQaccum_stage) ) self.sLSE_layout = cute.make_layout( shape=(self.tile_m, self.Q_stage), stride=(1, cute.round_up(self.tile_m, 64)) ) self.sdPsum_layout = cute.make_layout( shape=(self.tile_m, self.dO_stage), stride=(1, cute.round_up(self.tile_m, 64)), ) self.sdK_epi_tile = ( self.tile_n, math.gcd(128 // (self.dk_dtype.width // 8), self.tile_hdim // 2), # 64 or 32 ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] self.sdV_epi_tile = ( self.tile_n, math.gcd(128 // (self.dk_dtype.width // 8), self.tile_hdimv // 2), # 64 or 32 ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] # headdim_64 gets 1 stage self.num_epi_stages = max(1, (self.tile_hdim // 2) // 
self.sdK_epi_tile[1]) self.num_epi_stages_v = max(1, (self.tile_hdimv // 2) // self.sdV_epi_tile[1]) self.sdK_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages self.sdV_flat_epi_tile = self.tile_n * (self.tile_hdimv // 2) // self.num_epi_stages_v if const_expr(not self.dKV_postprocess): self.sdK_layout = sm100_utils_basic.make_smem_layout_epi( self.dk_dtype, LayoutEnum.ROW_MAJOR, self.sdK_epi_tile, 2, # num compute wgs ) self.sdV_layout = sm100_utils_basic.make_smem_layout_epi( self.dv_dtype, LayoutEnum.ROW_MAJOR, self.sdV_epi_tile, 2, # num compute wgs ) else: self.sdK_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2)) # self.dK_reduce_ncol same for dV self.sdV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2)) @cute.jit def __call__( self, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, mdO: cute.Tensor, mLSE: cute.Tensor, mdPsum: cute.Tensor, mdQaccum: cute.Tensor, mdK: cute.Tensor, mdV: cute.Tensor, softmax_scale: Float32, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, mdK_semaphore: Optional[cute.Tensor] = None, mdV_semaphore: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, # Block-sparse tensors (Q direction - for iterating m_blocks per n_block): blocksparse_tensors: Optional[BlockSparseTensors] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). 
stream: cuda.CUstream = None, ): self.q_dtype = mQ.element_type self.k_dtype = mK.element_type self.v_dtype = mV.element_type self.do_dtype = mdO.element_type self.lse_dtype = mLSE.element_type self.dpsum_dtype = mdPsum.element_type self.dqaccum_dtype = mdQaccum.element_type self.dk_dtype = mdK.element_type self.dv_dtype = mdV.element_type self.ds_dtype = self.q_dtype self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None self.use_tma_store = not (self.qhead_per_kvhead == 1 and mCuSeqlensK is not None) # self.use_tma_store = not self.qhead_per_kvhead == 1 self.dKV_postprocess = self.qhead_per_kvhead > 1 if const_expr(self.dKV_postprocess): assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA" mdQaccum, mdK, mdV = [assume_tensor_aligned(t) for t in (mdQaccum, mdK, mdV)] # (b, s, n, h) --> (s, h, n, b) or (t, n, h) -> (t, h, n) QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] mQ, mdO = [layout_utils.select(t, mode=QO_layout_transpose) for t in (mQ, mdO)] KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] mK, mV = [layout_utils.select(t, mode=KV_layout_transpose) for t in (mK, mV)] # (b, n, s) --> (s, n, b) or (n, t) --> (t, n) LSE_dPsum_dQaccum_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] mLSE, mdPsum, mdQaccum = [ layout_utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) ] if const_expr(not self.dKV_postprocess): layout_dKV_transpose = KV_layout_transpose else: layout_dKV_transpose = [2, 1, 0] if const_expr(mCuSeqlensK is None) else [1, 0] mdK, mdV = [layout_utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)] # (s, h, n, b) --> (h, s, n, b) or (t, h, n) -> (h, t, n) dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2]
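# Worked example (illustrative, shape tuples only; based on the shape comments above):
# layout_utils.select(t, mode=perm) places mode perm[i] of t at position i, so two
# consecutive selects compose as combined[i] = first[second[i]].
# Batched dO: [1, 3, 2, 0] followed by dO_transpose [1, 0, 2, 3] is the single permutation
#   [3, 1, 2, 0], i.e. (b, s, n, h) -> (h, s, n, b).
# Varlen dO: [0, 2, 1] followed by [1, 0, 2] is the single permutation [2, 0, 1],
#   i.e. (t, n, h) -> (h, t, n).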
mdO = layout_utils.select(mdO, mode=dO_transpose) # Transposes for 2-CTA K/Q paths (Q follows Q seqlens, K follows K seqlens) transpose_sh_q = dO_transpose transpose_sh_k = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] # (b, n, block, stage) -> (block, stage, n, b) semaphore_transpose = [2, 3, 1, 0] if const_expr(self.deterministic): assert mdQ_semaphore is not None mdQ_semaphore = layout_utils.select(mdQ_semaphore, mode=semaphore_transpose) if const_expr(self.deterministic and self.qhead_per_kvhead > 1): assert mdK_semaphore is not None assert mdV_semaphore is not None mdK_semaphore, mdV_semaphore = [ layout_utils.select(t, mode=semaphore_transpose) for t in (mdK_semaphore, mdV_semaphore) ] else: mdK_semaphore = None mdV_semaphore = None self._setup_attributes() ( self.tiled_mma_S, self.tiled_mma_dP, self.tiled_mma_dK, self.tiled_mma_dV, self.tiled_mma_dQ, ) = self._get_tiled_mma() self._setup_smem_layout() self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) self.cluster_layout_vmnk = cute.tiled_divide( cute.make_layout(self.cluster_shape_mnk), (self.tiled_mma_S.thr_id.shape,), ) self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) self.is_q_do_mcast = self.num_mcast_ctas_b > 1 if const_expr(not self.dKV_postprocess): self.mdK_layout_enum = LayoutEnum.from_tensor(mdK) self.mdV_layout_enum = LayoutEnum.from_tensor(mdV) dK_major_mode = self.mdK_layout_enum.mma_major_mode() dV_major_mode = self.mdV_layout_enum.mma_major_mode() if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdK is wrong") if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdV is wrong") if const_expr(self.use_tma_store and not self.dKV_postprocess): tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp() tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( tma_copy_op_dKV, mdK, cute.select(self.sdK_layout, mode=[0, 1]), self.sdK_epi_tile, 1, # no mcast ) tma_atom_dV, 
mdV_tma_tensor = cpasync.make_tiled_tma_atom( tma_copy_op_dKV, mdV, cute.select(self.sdV_layout, mode=[0, 1]), self.sdV_epi_tile, 1, # no mcast ) else: mdV_tma_tensor = mdV mdK_tma_tensor = mdK tma_atom_dV = None tma_atom_dK = None if const_expr(not self.dKV_postprocess): thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0)) # 128 threads val_layout_r2s_dKV = cute.make_ordered_layout( (1, 128 // self.dk_dtype.width), order=(1, 0) ) # 4 or 8 vals for 16 byte store copy_atom_r2s_dKV = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dk_dtype, num_bits_per_copy=128, ) tiled_copy_r2s_dKV = cute.make_tiled_copy_tv( copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV ) else: tiled_copy_r2s_dKV = copy_utils.tiled_copy_1d( Float32, 128, num_copy_elems=128 // Float32.width ) tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) # S.T = K @ Q.T tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, mK, cute.select(self.sK_layout, mode=[0, 1, 2]), self.mma_tiler_kq, self.tiled_mma_S, self.cluster_layout_vmnk.shape, ) Q_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( self.cluster_shape_mnk, self.tiled_mma_S.thr_id ) tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B( Q_tma_op, mQ, cute.select(self.sQ_layout, mode=[0, 1, 2]), self.mma_tiler_kq, self.tiled_mma_S, self.cluster_layout_vmnk.shape, ) # dP.T = V @ dO.T tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A( tma_load_op, mV, cute.select(self.sV_layout, mode=[0, 1, 2]), self.mma_tiler_vdo, self.tiled_mma_dP, self.cluster_layout_vmnk.shape, ) # dV = P.T @ dO dO_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( self.cluster_shape_mnk, self.tiled_mma_dV.thr_id ) tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B( dO_tma_op, mdO, cute.select(self.sdO_layout, mode=[0, 1, 2]), self.mma_tiler_pdo, self.tiled_mma_dV, self.cluster_layout_vmnk.shape, ) # ------------------------------------------------------------ # 2-CTA # 
------------------------------------------------------------ tma_atom_dOt = tma_tensor_dOt = None if const_expr(self.use_2cta_instrs): tma_atom_dOt, tma_tensor_dOt = cute.nvgpu.make_tiled_tma_atom_B( dO_tma_op, layout_utils.select(mdO, mode=transpose_sh_q), cute.select(self.sdOt_layout, mode=[0, 1, 2]), self.mma_tiler_vdo, self.tiled_mma_dP, self.cluster_layout_vmnk.shape, ) tma_atom_Qt = tma_tensor_Qt = None if const_expr(self.use_2cta_instrs): tma_atom_Qt, tma_tensor_Qt = cute.nvgpu.make_tiled_tma_atom_B( Q_tma_op, layout_utils.select(mQ, mode=transpose_sh_q), cute.select(self.sQt_layout, mode=[0, 1, 2]), self.mma_tiler_dsq, self.tiled_mma_dK, self.cluster_layout_vmnk.shape, ) tma_atom_Kt = tma_tensor_Kt = None if const_expr(self.use_2cta_instrs): Kt_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B( self.cluster_shape_mnk, self.tiled_mma_dQ.thr_id ) tma_atom_Kt, tma_tensor_Kt = cute.nvgpu.make_tiled_tma_atom_B( Kt_tma_op, layout_utils.select(mK, mode=transpose_sh_k), cute.select(self.sKt_layout, mode=[0, 1, 2]), self.mma_tiler_dsk, self.tiled_mma_dQ, self.cluster_layout_vmnk.shape, ) self.tma_copy_bytes = { name: self.cta_group_size * cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) for name, mX, layout in [ ("Q", mQ, self.sQ_layout), ("K", mK, self.sK_layout), ("V", mV, self.sV_layout), ("dO", mdO, self.sdO_layout), ] } self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8 self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8 self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8 self.tma_copy_bytes["dS"] = cute.size_in_bytes(self.ds_dtype, self.sdS_layout) self.tma_copy_bytes["sdS_xchg"] = self.tma_copy_bytes["dS"] // 2 # Half of dS for exchange # TileScheduler = SingleTileScheduler if const_expr(self.is_varlen_k): TileScheduler = SingleTileVarlenScheduler elif const_expr(self.deterministic): 
TileScheduler = SingleTileLPTBwdScheduler else: TileScheduler = SingleTileScheduler self.spt = (self.is_causal or self.is_local) and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks cute.size(mQ.shape[2]), # num_heads = num_query_heads cute.size(mK.shape[3]) if const_expr(mCuSeqlensK is None) else cute.size(mCuSeqlensK.shape[0] - 1), # num_batches 1, # num_splits cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k mQ.shape[1], # headdim mV.shape[1], # headdim_v total_q=cute.size(mK.shape[0]) # pass total_k for total_q if const_expr(mCuSeqlensK is not None) else cute.size(mK.shape[0]) * cute.size(mK.shape[3]), tile_shape_mn=self.cta_tiler[:2], # (tile_n, tile_m) cluster_shape_mn=self.cluster_shape_mnk[:2], mCuSeqlensQ=mCuSeqlensK, mSeqUsedQ=mSeqUsedK, qhead_per_kvhead_packgqa=1, # pack_gqa disabled for bwd element_size=self.k_dtype.width // 8, is_persistent=self.is_persistent, # persistent mode not tested lpt=self.spt, head_swizzle=self.deterministic, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) self.tile_scheduler_cls = TileScheduler grid_dim = TileScheduler.get_grid_shape(tile_sched_params) # Compute allocation sizes for shared buffers that are reused # sQ is reused for sdK, sdO is reused for sdV sQ_alloc_bytes = max( cute.size_in_bytes(self.q_dtype, self.sQ_layout), cute.size_in_bytes(self.dk_dtype, self.sdK_layout), ) sdO_alloc_bytes = max( cute.size_in_bytes(self.dv_dtype, self.sdV_layout), cute.size_in_bytes(self.do_dtype, self.sdO_layout), ) sdK_bytes = cute.size_in_bytes(self.dk_dtype, self.sdK_layout) sdV_bytes = cute.size_in_bytes(self.dv_dtype, self.sdV_layout) assert sdV_bytes <= sdO_alloc_bytes, "sdV doesn't fit in sdO storage allocation" assert sdK_bytes <= sQ_alloc_bytes, "sdK doesn't fit in sQ storage allocation" # 2-CTA: sdV reuses sV, sdK reuses sK sV_bytes = cute.size_in_bytes(self.v_dtype, self.sV_layout) sK_bytes = 
cute.size_in_bytes(self.k_dtype, self.sK_layout) if const_expr(self.use_2cta_instrs): assert sdV_bytes <= sV_bytes, "sdV doesn't fit in sV storage allocation (2-CTA)" assert sdK_bytes <= sK_bytes, "sdK doesn't fit in sK storage allocation (2-CTA)" if const_expr(self.use_2cta_instrs): sQt_size = cute.cosize(self.sQt_layout) if const_expr(self.tile_hdim <= 128) else 0 sdOt_size = cute.cosize(self.sdOt_layout) if const_expr(self.tile_hdim <= 128) else 0 sdS_xchg_size = ( cute.cosize(self.sdS_xchg_layout) if const_expr(self.tile_hdim <= 128) else 0 ) @cute.struct class SharedStorage: Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.sdKVaccum_stage] dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] dQ_cluster_full_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.dQaccum_reduce_stage // 2 ] dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.dQaccum_reduce_stage // 2 ] tmem_holding_buf: Int32 tmem_dealloc_mbar_ptr: cutlass.Int64 # 2-CTA Qt_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] Kt_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] dS_cluster_empty_mbar_ptr: cutlass.Int64 dS_cluster_full_mbar_ptr: cutlass.Int64 dS_cluster_leader_mbar_ptr: cutlass.Int64 dQaccum_empty_mbar_ptr: cutlass.Int64 sQ: cute.struct.Align[ cute.struct.MemRange[self.q_dtype, cute.cosize(self.sQ_layout)], self.buffer_align_bytes, ] sK: cute.struct.Align[ cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)], self.buffer_align_bytes, ] sV: 
cute.struct.Align[ cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)], self.buffer_align_bytes, ] sdO: cute.struct.Align[ cute.struct.MemRange[self.do_dtype, cute.cosize(self.sdO_layout)], self.buffer_align_bytes, ] sQt: cute.struct.Align[ cute.struct.MemRange[self.q_dtype, sQt_size], self.buffer_align_bytes, ] sdOt: cute.struct.Align[ cute.struct.MemRange[self.do_dtype, sdOt_size], self.buffer_align_bytes, ] sdS_xchg: cute.struct.Align[ cute.struct.MemRange[self.ds_dtype, sdS_xchg_size], self.buffer_align_bytes, ] sKt: cute.struct.Align[ cute.struct.MemRange[self.k_dtype, cute.cosize(self.sKt_layout)], self.buffer_align_bytes, ] sdS: cute.struct.Align[ cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)], self.buffer_align_bytes, ] sLSE: cute.struct.Align[ cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)], 128, ] sdPsum: cute.struct.Align[ cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)], 128, ] sdQaccum: cute.struct.Align[ cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)], self.buffer_align_bytes if sdS_xchg_size == 0 else 128, ] else: @cute.struct class SharedStorage: Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage] dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage] S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.single_stage] dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.sdKVaccum_stage] dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] dQ_cluster_full_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.dQaccum_reduce_stage // 2 ] dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.dQaccum_reduce_stage // 2 ] 
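# Worked example (illustrative arithmetic; assumes the 1-CTA settings chosen in
# _setup_attributes with tile_hdim = 128). Each MemRange[Int64, 2 * stages] field appears
# to reserve one "full" and one "empty" mbarrier per pipeline stage. That is the
# layout cutlass.pipeline's create(..., barrier_storage=...) consumes.
# With Q_stage = 2, dO_stage = 1, single_stage = 1, sdKVaccum_stage = 2 and
# dQaccum_reduce_stage = 128 // 32 = 4, the mbarrier fields above reserve
#   Q 4 + dO 2 + LSE 4 + dPsum 2 + S 2 + dP 2 + dS 2 + dKV 4 + dQ 2
#   + dQ_cluster_full 2 + dQ_cluster_empty 2 = 28 Int64 slots = 224 bytes.
# The dQ_cluster_* barriers are only initialized when self.cluster_reduce_dQ is true,
# and _setup_attributes currently hardcodes that flag to False.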
tmem_holding_buf: Int32 tmem_dealloc_mbar_ptr: Int64 sQ: cute.struct.Align[ cute.struct.MemRange[cute.Uint8, sQ_alloc_bytes], self.buffer_align_bytes, ] sK: cute.struct.Align[ cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)], self.buffer_align_bytes, ] sV: cute.struct.Align[ cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)], self.buffer_align_bytes, ] sdO: cute.struct.Align[ cute.struct.MemRange[cute.Uint8, sdO_alloc_bytes], self.buffer_align_bytes, ] sdS: cute.struct.Align[ cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)], 128, ] sLSE: cute.struct.Align[ cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)], 128, ] sdPsum: cute.struct.Align[ cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)], 128, ] sdQaccum: cute.struct.Align[ cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)], self.buffer_align_bytes, ] self.shared_storage = SharedStorage LOG2_E = math.log2(math.e) if const_expr(self.score_mod is None): # Without score_mod: bake scale into log2 softmax_scale_log2 = softmax_scale * LOG2_E else: # With score_mod: score_mod applied to S * softmax_scale, then use LOG2_E only softmax_scale_log2 = LOG2_E if const_expr(window_size_left is not None): window_size_left = Int32(window_size_left) if const_expr(window_size_right is not None): window_size_right = Int32(window_size_right) fastdiv_mods = None if const_expr(aux_tensors is not None): seqlen_q = cute.size(mQ.shape[0]) // ( self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1 ) seqlen_k = cute.size(mK.shape[0]) seqlen_q_divmod = FastDivmodDivisor(seqlen_q) seqlen_k_divmod = FastDivmodDivisor(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) if const_expr(self.use_2cta_instrs): assert blocksparse_tensors is None, ( "2-CTA mode does not support block sparsity. 
" "Please create kernel with use_2cta_instrs=False for block sparse attention." ) # 2-CTA: 231424 and 1-CTA: 232448 # print("SMEM: ", self.shared_storage.size_in_bytes()) if const_expr(self.use_block_sparsity or aux_tensors is not None): assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), ( "Variable sequence length is not supported yet for blocksparse or aux tensors in bwd" ) self.kernel( tma_tensor_Q, tma_tensor_Qt, tma_tensor_K, tma_tensor_Kt, tma_tensor_V, mLSE, mdPsum, tma_tensor_dO, tma_tensor_dOt, mdV, mdK, mdQaccum, mdV_tma_tensor, mdK_tma_tensor, mdQ_semaphore, mdK_semaphore, mdV_semaphore, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK, tma_atom_Q, tma_atom_Qt, tma_atom_K, tma_atom_Kt, tma_atom_V, tma_atom_dO, tma_atom_dOt, tma_atom_dV, tma_atom_dK, self.sQ_layout, self.sQt_layout, self.sK_layout, self.sKt_layout, self.sV_layout, self.sLSE_layout, self.sdPsum_layout, self.sdO_layout, self.sdOt_layout, self.sdSt_layout, self.sdS_layout, self.sdS_xchg_layout, self.sdQaccum_layout, self.sdK_layout, self.sdV_layout, self.tP_layout, self.tdS_layout, self.tiled_mma_S, self.tiled_mma_dP, self.tiled_mma_dV, self.tiled_mma_dK, self.tiled_mma_dQ, tiled_copy_r2s_dKV, softmax_scale, softmax_scale_log2, window_size_left, window_size_right, tile_sched_params, aux_tensors, fastdiv_mods, blocksparse_tensors, ).launch( grid=grid_dim, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk if cute.size(self.cluster_shape_mnk) > 1 else None, smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) @cute.kernel def kernel( self, mQ: cute.Tensor, mQt: Optional[cute.Tensor], mK: cute.Tensor, mKt: Optional[cute.Tensor], mV: cute.Tensor, mLSE: cute.Tensor, mdPsum: cute.Tensor, mdO: cute.Tensor, mdOt: Optional[cute.Tensor], mdV: cute.Tensor, mdK: cute.Tensor, mdQaccum: cute.Tensor, mdV_tma_tensor: Optional[cute.Tensor], mdK_tma_tensor: Optional[cute.Tensor], mdQ_semaphore: Optional[cute.Tensor], mdK_semaphore: 
Optional[cute.Tensor], mdV_semaphore: Optional[cute.Tensor], mCuSeqlensQ: Optional[cute.Tensor], mCuSeqlensK: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], mSeqUsedK: Optional[cute.Tensor], tma_atom_Q: cute.CopyAtom, tma_atom_Qt: Optional[cute.CopyAtom], tma_atom_K: cute.CopyAtom, tma_atom_Kt: Optional[cute.CopyAtom], tma_atom_V: cute.CopyAtom, tma_atom_dO: cute.CopyAtom, tma_atom_dOt: Optional[cute.CopyAtom], tma_atom_dV: Optional[cute.CopyAtom], tma_atom_dK: Optional[cute.CopyAtom], sQ_layout: cute.ComposedLayout, sQt_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sKt_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, sLSE_layout: cute.Layout, sdPsum_layout: cute.Layout, sdO_layout: cute.ComposedLayout, sdOt_layout: cute.ComposedLayout, sdSt_layout: cute.ComposedLayout, sdS_layout: cute.ComposedLayout, sdS_xchg_layout: cute.Layout, sdQaccum_layout: cute.Layout, sdK_layout: cute.ComposedLayout | cute.Layout, sdV_layout: cute.ComposedLayout | cute.Layout, tP_layout: cute.ComposedLayout, tdS_layout: cute.ComposedLayout, tiled_mma_S: cute.TiledMma, tiled_mma_dP: cute.TiledMma, tiled_mma_dV: cute.TiledMma, tiled_mma_dK: cute.TiledMma, tiled_mma_dQ: cute.TiledMma, tiled_copy_r2s_dKV: cute.TiledCopy, softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, window_size_left: Optional[Int32], window_size_right: Optional[Int32], tile_sched_params: ParamsBase, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) bidx, _, _ = cute.arch.block_idx() mma_tile_coord_v = bidx % self.cta_group_size is_leader_cta = mma_tile_coord_v == 0 cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) # Prefetch tma descriptor if warp_idx == self.load_warp_id: with cute.arch.elect_one(): cpasync.prefetch_descriptor(tma_atom_Q) if const_expr(tma_atom_Qt is not None): 
                    cpasync.prefetch_descriptor(tma_atom_Qt)
                cpasync.prefetch_descriptor(tma_atom_K)
                if const_expr(tma_atom_Kt is not None):
                    cpasync.prefetch_descriptor(tma_atom_Kt)
                cpasync.prefetch_descriptor(tma_atom_V)
                if const_expr(tma_atom_dOt is not None):
                    cpasync.prefetch_descriptor(tma_atom_dOt)
                cpasync.prefetch_descriptor(tma_atom_dO)
                if const_expr(tma_atom_dV is not None):
                    cpasync.prefetch_descriptor(tma_atom_dV)
                if const_expr(tma_atom_dK is not None):
                    cpasync.prefetch_descriptor(tma_atom_dK)

        cluster_layout_vmnk = cute.tiled_divide(
            cute.make_layout(self.cluster_shape_mnk),
            (tiled_mma_S.thr_id.shape,),
        )

        # Alloc
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)

        dQ_cluster_full_mbar_ptr = storage.dQ_cluster_full_mbar_ptr.data_ptr()
        dQ_cluster_empty_mbar_ptr = storage.dQ_cluster_empty_mbar_ptr.data_ptr()

        if const_expr(self.use_2cta_instrs):
            dS_cluster_full_mbar_ptr = storage.dS_cluster_full_mbar_ptr
            dS_cluster_empty_mbar_ptr = storage.dS_cluster_empty_mbar_ptr
            dS_cluster_leader_mbar_ptr = storage.dS_cluster_leader_mbar_ptr
            dQaccum_empty_mbar_ptr = storage.dQaccum_empty_mbar_ptr
        else:
            dS_cluster_full_mbar_ptr = None
            dS_cluster_empty_mbar_ptr = None
            dS_cluster_leader_mbar_ptr = None
            dQaccum_empty_mbar_ptr = None

        # Barrier initialization
        if const_expr(self.use_2cta_instrs):
            if const_expr(self.tile_hdim == 192):
                if warp_idx == 2:
                    cute.arch.mbarrier_init(
                        dQaccum_empty_mbar_ptr,
                        len(self.reduce_warp_ids),
                    )
            if warp_idx == 4:
                cute.arch.mbarrier_init(dS_cluster_full_mbar_ptr, 1)
                cute.arch.mbarrier_init(dS_cluster_empty_mbar_ptr, 1)
                cute.arch.mbarrier_init(dS_cluster_leader_mbar_ptr, 2)

        if const_expr(self.cluster_reduce_dQ):
            if warp_idx == 4:
                for i in range(self.dQaccum_reduce_stage // 2):
                    cute.arch.mbarrier_init(dQ_cluster_full_mbar_ptr + i, 1)
                    cute.arch.mbarrier_init(dQ_cluster_empty_mbar_ptr + i, 1)

        tmem_alloc_barrier = cutlass.pipeline.NamedBarrier(
            barrier_id=int(NamedBarrierBwdSm100.TmemPtr),
            num_threads=cute.arch.WARP_SIZE
            * len((self.mma_warp_id, *self.compute_warp_ids, *self.reduce_warp_ids)),
        )
        tmem = cutlass.utils.TmemAllocator(
            storage.tmem_holding_buf,
            barrier_for_retrieve=tmem_alloc_barrier,
            allocator_warp_id=self.mma_warp_id,
            is_two_cta=self.use_2cta_instrs,
            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
        )

        # UMMA producers and AsyncThread consumers
        pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])
        )
        pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) * self.cta_group_size
        )
        pipeline_S_P = cutlass.pipeline.PipelineUmmaAsync.create(
            num_stages=1,
            producer_group=pipeline_producer_group_MMA_AsyncThread,
            consumer_group=pipeline_consumer_group_MMA_AsyncThread,
            barrier_storage=storage.S_mbar_ptr.data_ptr(),
            cta_layout_vmnk=cluster_layout_vmnk,
        )
        pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create(
            num_stages=1,
            producer_group=pipeline_producer_group_MMA_AsyncThread,
            consumer_group=pipeline_consumer_group_MMA_AsyncThread,
            barrier_storage=storage.dP_mbar_ptr.data_ptr(),
            cta_layout_vmnk=cluster_layout_vmnk,
        )
        pipeline_dKV = cutlass.pipeline.PipelineUmmaAsync.create(
            num_stages=2,
            producer_group=pipeline_producer_group_MMA_AsyncThread,
            consumer_group=pipeline_consumer_group_MMA_AsyncThread,
            barrier_storage=storage.dKV_mbar_ptr.data_ptr(),
            cta_layout_vmnk=cluster_layout_vmnk,
        )
        pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread,
            len(self.reduce_warp_ids) * self.cta_group_size,
        )  # Compute
        pipeline_dQ = cutlass.pipeline.PipelineUmmaAsync.create(
            num_stages=1,
            producer_group=pipeline_producer_group_MMA_AsyncThread,
            consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ,
            barrier_storage=storage.dQ_mbar_ptr.data_ptr(),
            cta_layout_vmnk=cluster_layout_vmnk,
        )

        # AsyncThread producers and UMMA consumers
        # Only 1 thread per warp will signal
        pipeline_PdS_producer_group = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread,
            len(self.compute_warp_ids) * self.cta_group_size,
        )  # Compute
        pipeline_PdS_consumer_group = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])
        )  # MMA
        pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create(
            num_stages=1,
            producer_group=pipeline_PdS_producer_group,
            consumer_group=pipeline_PdS_consumer_group,
            barrier_storage=storage.dS_mbar_ptr.data_ptr(),
            cta_layout_vmnk=cluster_layout_vmnk,
        )

        # TMA producer and UMMA consumers
        pipeline_producer_group = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, len([self.load_warp_id])
        )
        # The arrive count equals the multicast size
        pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) * self.num_mcast_ctas_b
        )
        pipeline_consumer_group_compute = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread,
            len(self.compute_warp_ids) * 1,
        )
        pipeline_LSE = cutlass.pipeline.PipelineTmaAsync.create(
            barrier_storage=storage.LSE_mbar_ptr.data_ptr(),
            num_stages=self.Q_stage,
            producer_group=pipeline_producer_group,
            consumer_group=pipeline_consumer_group_compute,
            tx_count=self.tma_copy_bytes["LSE"],
            # cta_layout_vmnk=cluster_layout_vmnk,
            defer_sync=True,
        )
        pipeline_dPsum = cutlass.pipeline.PipelineTmaAsync.create(
            barrier_storage=storage.dPsum_mbar_ptr.data_ptr(),
            num_stages=self.dO_stage,
            producer_group=pipeline_producer_group,
            consumer_group=pipeline_consumer_group_compute,
            tx_count=self.tma_copy_bytes["dPsum"],
            # cta_layout_vmnk=cluster_layout_vmnk,
            defer_sync=True,
        )
        pipeline_Q = pipeline.PipelineTmaUmma.create(
            barrier_storage=storage.Q_mbar_ptr.data_ptr(),
            num_stages=self.Q_stage,
            producer_group=pipeline_producer_group,
            consumer_group=pipeline_consumer_group,
            tx_count=self.tma_copy_bytes["Q"],
            cta_layout_vmnk=cluster_layout_vmnk,
            defer_sync=True,
        )
        if const_expr(self.use_2cta_instrs):
            if const_expr(self.tile_hdim == 192):
                pipeline_Qt = pipeline_Q
            else:
                pipeline_Qt = pipeline.PipelineTmaUmma.create(
                    barrier_storage=storage.Qt_mbar_ptr.data_ptr(),
                    num_stages=self.Q_stage,
                    producer_group=pipeline_producer_group,
                    consumer_group=pipeline_consumer_group,
                    tx_count=self.tma_copy_bytes["Q"],
                    cta_layout_vmnk=cluster_layout_vmnk,
                    defer_sync=True,
                )
            pipeline_Kt = pipeline.PipelineTmaUmma.create(
                barrier_storage=storage.Kt_mbar_ptr.data_ptr(),
                num_stages=self.single_stage,
                producer_group=pipeline_producer_group,
                consumer_group=pipeline_consumer_group,
                tx_count=self.tma_copy_bytes["K"],
                cta_layout_vmnk=cluster_layout_vmnk,
                defer_sync=True,
            )
        else:
            pipeline_Qt = pipeline_Kt = pipeline_Q
        pipeline_dO = pipeline.PipelineTmaUmma.create(
            barrier_storage=storage.dO_mbar_ptr.data_ptr(),
            num_stages=self.dO_stage,
            producer_group=pipeline_producer_group,
            consumer_group=pipeline_consumer_group,
            tx_count=self.tma_copy_bytes["dO"],
            cta_layout_vmnk=cluster_layout_vmnk,
            defer_sync=False,
        )

        sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype)
        if const_expr(self.use_2cta_instrs and self.tile_hdim <= 128):
            sQt = storage.sQt.get_tensor(
                sQt_layout.outer, swizzle=sQt_layout.inner, dtype=self.q_dtype
            )
        else:
            sQt = cute.make_tensor(
                cute.recast_ptr(sQ.iterator, sQt_layout.inner, dtype=self.q_dtype), sQt_layout.outer
            )
        sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
        if const_expr(self.use_2cta_instrs):
            sKt = storage.sKt.get_tensor(sKt_layout.outer, swizzle=sKt_layout.inner)
        else:
            sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, sKt_layout.inner), sKt_layout.outer)
        sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
        sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner)
        sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, sdS_layout.inner), sdS_layout.outer)
        if const_expr(self.use_2cta_instrs):
            if const_expr(self.tile_hdim <= 128):
                sdS_xchg = storage.sdS_xchg.get_tensor(sdS_xchg_layout)
            else:
                sdS_xchg = storage.sdQaccum.get_tensor(sdS_xchg_layout, dtype=self.ds_dtype)
        else:
            sdS_xchg = None
        sdO = storage.sdO.get_tensor(
            sdO_layout.outer, swizzle=sdO_layout.inner, dtype=self.do_dtype
        )
        if const_expr(self.use_2cta_instrs and self.tile_hdim <= 128):
            sdOt = storage.sdOt.get_tensor(
                sdOt_layout.outer, swizzle=sdOt_layout.inner, dtype=self.do_dtype
            )
        else:
            sdOt = cute.make_tensor(
                cute.recast_ptr(sdO.iterator, sdOt_layout.inner, dtype=self.do_dtype),
                sdOt_layout.outer,
            )
        sLSE = storage.sLSE.get_tensor(sLSE_layout)
        sdPsum = storage.sdPsum.get_tensor(sdPsum_layout)
        if const_expr(self.use_2cta_instrs):
            if const_expr(not self.dKV_postprocess):
                sdV = storage.sV.get_tensor(
                    sdV_layout.outer, swizzle=sdV_layout.inner, dtype=self.dv_dtype
                )
                sdK = storage.sK.get_tensor(
                    sdK_layout.outer, swizzle=sdK_layout.inner, dtype=self.dk_dtype
                )
            else:
                sdV = storage.sV.get_tensor(sdV_layout, dtype=self.dv_dtype)
                sdK = storage.sK.get_tensor(sdK_layout, dtype=self.dk_dtype)
        elif const_expr(not self.dKV_postprocess):
            sdV = storage.sdO.get_tensor(
                sdV_layout.outer, swizzle=sdV_layout.inner, dtype=self.dv_dtype
            )
            sdK = storage.sQ.get_tensor(
                sdK_layout.outer, swizzle=sdK_layout.inner, dtype=self.dk_dtype
            )
        else:
            sdV = storage.sdO.get_tensor(sdV_layout, dtype=self.dv_dtype)
            sdK = storage.sQ.get_tensor(sdK_layout, dtype=self.dk_dtype)

        # Buffer sizing is guaranteed by max(...) in SharedStorage declarations
        # for both sQ (reused as sdK) and sdO (reused as sdV)
        sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout)

        # TMEM
        # These are placeholder TMEM tensors: strictly, tmem_ptr should be retrieved from the
        # allocator. But we always request 512 columns of TMEM, so the allocation starts at 0.
        tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16)
        # S
        thr_mma_S = tiled_mma_S.get_slice(mma_tile_coord_v)
        Sacc_shape = thr_mma_S.partition_shape_C(self.mma_tiler_kq[:2])  # (M, N)
        tStS = thr_mma_S.make_fragment_C(Sacc_shape)  # (MMA, MMA_M, MMA_N)
        tStS = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tStS.layout)
        # dP
        thr_mma_dP = tiled_mma_dP.get_slice(mma_tile_coord_v)
        dPacc_shape = thr_mma_dP.partition_shape_C(self.mma_tiler_vdo[:2])
        tdPtdP = thr_mma_dP.make_fragment_C(dPacc_shape)
        tdPtdP = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPtdP.layout)
        # dV
        thr_mma_dV = tiled_mma_dV.get_slice(mma_tile_coord_v)
        dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2])
        tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape)
        tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout)
        tP = cute.make_tensor(
            cute.recast_ptr(tmem_ptr + self.tmem_P_offset, dtype=self.do_dtype), tP_layout.outer
        )
        # dK
        thr_mma_dK = tiled_mma_dK.get_slice(mma_tile_coord_v)
        dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2])
        tdKtdK = thr_mma_dK.make_fragment_C(dkacc_shape)
        tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout)
        tdS = cute.make_tensor(
            cute.recast_ptr(tmem_ptr + self.tmem_dS_offset, dtype=self.ds_dtype), tdS_layout.outer
        )
        # dQ
        thr_mma_dQ = tiled_mma_dQ.get_slice(mma_tile_coord_v)
        dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2])
        tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape)
        tdQtdQ = cute.make_tensor(tmem_ptr + self.tmem_dQ_offset, tdQtdQ.layout)

        block_info = BlockInfo(
            self.tile_m,
            # self.tile_n,
            self.tile_n * self.cluster_shape_mnk[0],  # careful, this case is not very well-tested
            self.is_causal,
            self.is_local,
            False,  # is_split_kv
            window_size_left,
            window_size_right,
            qhead_per_kvhead_packgqa=1,
        )
        SeqlenInfoCls = partial(
            SeqlenInfoQK.create,
            seqlen_q_static=mQ.shape[0],
            seqlen_k_static=mK.shape[0],
            mCuSeqlensQ=mCuSeqlensQ,
            mCuSeqlensK=mCuSeqlensK,
            mSeqUsedQ=mSeqUsedQ,
            mSeqUsedK=mSeqUsedK,
            tile_m=self.tile_m,
            tile_n=self.tile_n * self.cluster_shape_mnk[0],
        )
        TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)

        AttentionMaskCls = partial(
            AttentionMask,
            self.tile_m,
            self.tile_n * self.cta_group_size,
            swap_AB=True,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
        )

        # EMPTY
        # (15)
        if warp_idx == self.empty_warp_id:
            cute.arch.setmaxregister_decrease(self.num_regs_empty)

        # RELAY
        # (14)
        if warp_idx == self.relay_warp_id:
            cute.arch.setmaxregister_decrease(
                self.num_regs_mma if self.use_2cta_instrs else self.num_regs_empty
            )
            if const_expr(self.use_2cta_instrs):
                self.relay(
                    dS_cluster_full_mbar_ptr,
                    dS_cluster_empty_mbar_ptr,
                    dS_cluster_leader_mbar_ptr,
                    cluster_layout_vmnk,
                    block_info,
                    SeqlenInfoCls,
                    TileSchedulerCls,
                )

        # LOAD
        # (13)
        if warp_idx == self.load_warp_id:
            cute.arch.setmaxregister_decrease(self.num_regs_load)
            self.load(
                thr_mma_S,
                thr_mma_dP,
                thr_mma_dV,
                thr_mma_dK,
                thr_mma_dQ,
                mQ,
                mK,
                mKt,
                mV,
                mdO,
                mQt,
                mdOt,
                mLSE,
                mdPsum,
                sQ,
                sK,
                sKt,
                sV,
                sdO,
                sQt,
                sdOt,
                sLSE,
                sdPsum,
                tma_atom_Q,
                tma_atom_K,
                tma_atom_Kt,
                tma_atom_V,
                tma_atom_dO,
                tma_atom_Qt,
                tma_atom_dOt,
                pipeline_Q,
                pipeline_Qt,
                pipeline_Kt,
                pipeline_dO,
                pipeline_LSE,
                pipeline_dPsum,
                cluster_layout_vmnk,
                block_info,
                SeqlenInfoCls,
                TileSchedulerCls,
                blocksparse_tensors,
                should_load_Q=True,
                should_load_dO=True,
            )

        # MMA
        # (12)
        if warp_idx == self.mma_warp_id:
            cute.arch.setmaxregister_decrease(self.num_regs_mma)

            # Alloc tmem buffer
            tmem.allocate(self.tmem_alloc_cols)
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(Float32)

            self.mma(
                tiled_mma_S,
                tiled_mma_dP,
                tiled_mma_dV,
                tiled_mma_dK,
                tiled_mma_dQ,
                sQ,
                sQt,
                sK,
                sKt,
                sV,
                sdO,
                sdOt,
                tP,
                sdSt,
                sdS,
                tdS,
                tStS,
                tdPtdP,
                tdVtdV,
                tdKtdK,
                tdQtdQ,
                dS_cluster_full_mbar_ptr,
                dS_cluster_empty_mbar_ptr,
                dS_cluster_leader_mbar_ptr,
                pipeline_Q,
                pipeline_Qt,
                pipeline_Kt,
                pipeline_dO,
                pipeline_S_P,
                pipeline_dS,
                pipeline_dKV,
                pipeline_dP,
                pipeline_dQ,
                block_info,
                SeqlenInfoCls,
                TileSchedulerCls,
                is_leader_cta,
                blocksparse_tensors,
            )

            # Dealloc the tensor memory buffer
            tmem.relinquish_alloc_permit()
            tmem_alloc_barrier.arrive_and_wait()
            tmem.free(tmem_ptr)

        # Compute
        # (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps
        if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]:
            cute.arch.setmaxregister_increase(self.num_regs_compute)  # 8 warps
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(Float32)
            self.compute_loop(
                thr_mma_S,
                thr_mma_dP,
                thr_mma_dV,
                thr_mma_dK,
                tStS,
                tdPtdP,
                tdVtdV,
                tdKtdK,
                sLSE,
                sdPsum,
                mdV,
                mdK,
                sdS,
                sdS_xchg,
                pipeline_LSE,
                pipeline_dPsum,
                pipeline_S_P,
                pipeline_dS,
                pipeline_dKV,
                pipeline_dP,
                dS_cluster_empty_mbar_ptr,
                dS_cluster_full_mbar_ptr,
                dQaccum_empty_mbar_ptr,
                softmax_scale,
                softmax_scale_log2,
                block_info,
                SeqlenInfoCls,
                AttentionMaskCls,
                TileSchedulerCls,
                sdV,
                sdK,
                mdV_tma_tensor,
                mdK_tma_tensor,
                tma_atom_dV,
                tma_atom_dK,
                tiled_copy_r2s_dKV,
                mdK_semaphore,
                mdV_semaphore,
                aux_tensors,
                fastdiv_mods,
                blocksparse_tensors,
            )
            tmem_alloc_barrier.arrive()

        # Reduce
        # (0, 1, 2, 3) - dQ
        if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]:
            cute.arch.setmaxregister_increase(self.num_regs_reduce)
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(Float32)
            self.dQacc_reduce(
                mdQaccum,
                sdQaccum,
                thr_mma_dQ,
                tdQtdQ,
                pipeline_dQ,
                dQaccum_empty_mbar_ptr,
                block_info,
                SeqlenInfoCls,
                TileSchedulerCls,
                mdQ_semaphore,
                blocksparse_tensors,
            )
            tmem_alloc_barrier.arrive()

        return

    @cute.jit
    def relay(
        self,
        dS_cluster_full_mbar_ptr: cute.Pointer,
        dS_cluster_empty_mbar_ptr: cute.Pointer,
        dS_cluster_leader_mbar_ptr: cute.Pointer,
        cluster_layout_vmnk: cute.Layout,
        block_info: BlockInfo,
        SeqlenInfoCls: Callable,
        TileSchedulerCls: Callable,
    ):
        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
        dS_cluster_phase = Int32(0)
        tile_scheduler = TileSchedulerCls()
        work_tile = tile_scheduler.initial_work_tile_info()
        while work_tile.is_valid_tile:
            n_block, head_idx, batch_idx, _ = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)
            m_block_min, m_block_max = block_info.get_m_block_min_max(
                seqlen, n_block // self.cluster_shape_mnk[0]
            )
            head_idx_kv = head_idx // self.qhead_per_kvhead
            process_tile = (
                const_expr(not self.is_local and not self.is_varlen_q)
                or m_block_min < m_block_max
            )
            if process_tile:
                num_iters = m_block_max - m_block_min
                for _ in cutlass.range(num_iters, unroll=1):
                    # Wait for dS_xchg from peer CTA
                    cute.arch.mbarrier_wait(dS_cluster_full_mbar_ptr, phase=dS_cluster_phase)
                    # Arrive on MMA leader warp
                    with cute.arch.elect_one():
                        cute.arch.mbarrier_arrive(dS_cluster_leader_mbar_ptr, Int32(0))
                    dS_cluster_phase ^= 1

            tile_scheduler.prefetch_next_work()
            tile_scheduler.advance_to_next_work()
            work_tile = tile_scheduler.get_current_work()

    @cute.jit
    def load(
        self,
        thr_mma_S: cute.core.ThrMma,
        thr_mma_dP: cute.core.ThrMma,
        thr_mma_dV: cute.core.ThrMma,
        thr_mma_dK: cute.core.ThrMma,
        thr_mma_dQ: cute.core.ThrMma,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mKt: Optional[cute.Tensor],
        mV: cute.Tensor,
        mdO: cute.Tensor,
        mQt: Optional[cute.Tensor],
        mdOt: Optional[cute.Tensor],
        mLSE: cute.Tensor,
        mdPsum: cute.Tensor,
        sQ: cute.Tensor,
        sK: cute.Tensor,
        sKt: cute.Tensor,
        sV: cute.Tensor,
        sdO: cute.Tensor,
        sQt: cute.Tensor,
        sdOt: cute.Tensor,
        sLSE: cute.Tensor,
        sdPsum: cute.Tensor,
        tma_atom_Q: cute.CopyAtom,
        tma_atom_K: cute.CopyAtom,
        tma_atom_Kt: Optional[cute.CopyAtom],
        tma_atom_V: cute.CopyAtom,
        tma_atom_dO: cute.CopyAtom,
        tma_atom_Qt: Optional[cute.CopyAtom],
        tma_atom_dOt: Optional[cute.CopyAtom],  # 2-CTA only
        pipeline_Q: PipelineAsync,
        pipeline_Qt: PipelineAsync,
        pipeline_Kt: PipelineAsync,
        pipeline_dO: PipelineAsync,
        pipeline_LSE: PipelineAsync,
        pipeline_dPsum: PipelineAsync,
        cluster_layout_vmnk: cute.Layout,
        block_info: BlockInfo,
        SeqlenInfoCls: Callable,
        TileSchedulerCls: Callable,
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
        should_load_Q: bool = True,
        should_load_dO: bool = True,
    ):
        producer_state_Q_LSE = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, self.Q_stage
        )
        producer_state_Qt = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, self.Q_stage
        )
        producer_state_Kt = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, self.single_stage
        )
        producer_state_dO_dPsum = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, self.dO_stage
        )
        producer_state_Q_Qt = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, self.Q_stage
        )
        producer_state_O_Ot = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, self.dO_stage
        )
        producer_state_LSE = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, self.Q_stage
        )
        producer_state_dPsum = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, self.dO_stage
        )

        # Compute multicast mask for Q & dO buffer full
        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
        block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
        q_do_mcast_mask = None
        if const_expr(self.is_q_do_mcast):
            q_do_mcast_mask = cpasync.create_tma_multicast_mask(
                cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
            )

        tile_scheduler = TileSchedulerCls()
        work_tile = tile_scheduler.initial_work_tile_info()
        while work_tile.is_valid_tile:
            n_block, head_idx, batch_idx, _ = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)
            m_block_min, m_block_max = block_info.get_m_block_min_max(
                seqlen, n_block // self.cluster_shape_mnk[0]
            )
            head_idx_kv = head_idx // self.qhead_per_kvhead
            n_block_cta_group = n_block // self.cta_group_size

            # GMEM tensors (varlen-aware)
            mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
            mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
            mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
            if const_expr(not seqlen.has_cu_seqlens_q):
                mdO_cur = mdO[None, None, head_idx, batch_idx]
            else:
                mdO_cur = cute.domain_offset((0, seqlen.offset_q), mdO[None, None, head_idx])
            mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[None, head_idx]
            mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[
                None, head_idx
            ]
            if const_expr(self.use_2cta_instrs):
                if const_expr(not seqlen.has_cu_seqlens_q):
                    mQt_cur = mQt[None, None, head_idx, batch_idx]
                    mdOt_cur = mdOt[None, None, head_idx, batch_idx]
                else:
                    mQt_cur = cute.domain_offset((0, seqlen.offset_q, 0), mQt)[None, None, head_idx]
                    mdOt_cur = cute.domain_offset((seqlen.offset_q, 0, 0), mdOt)[
                        None, None, head_idx
                    ]
                if const_expr(not seqlen.has_cu_seqlens_k):
                    mKt_cur = mKt[None, None, head_idx_kv, batch_idx]
                else:
                    mKt_cur = cute.domain_offset((0, seqlen.offset_k, 0), mKt)[
                        None, None, head_idx_kv
                    ]

            # (1) S.T = K @ Q.T
            gK = cute.local_tile(
                mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block_cta_group, 0)
            )
            tSgK = thr_mma_S.partition_A(gK)
            gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0))
            tSgQ = thr_mma_S.partition_B(gQ)

            gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
            gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))
            gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None))
            tdPgdO = thr_mma_dV.partition_B(gdO)

            a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
            load_K, _, _ = copy_utils.tma_get_copy_fn(
                tma_atom_K,
                block_in_cluster_coord_vmnk[2],
                a_cta_layout,
                tSgK,
                sK,
                single_stage=True,
            )
            b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
            load_Q, _, _ = copy_utils.tma_get_copy_fn(
                tma_atom_Q,
                cta_coord=block_in_cluster_coord_vmnk[1],
                cta_layout=b_cta_layout,
                src_tensor=tSgQ,
                dst_tensor=sQ,
                mcast_mask=q_do_mcast_mask,
            )
            load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q)

            # (2) dP = V @ dO.T
            gV = cute.local_tile(
                mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block_cta_group, 0)
            )
            tdPgV = thr_mma_dP.partition_A(gV)
            load_V, _, _ = copy_utils.tma_get_copy_fn(
                tma_atom_V,
                0,
                cute.make_layout(1),
                tdPgV,
                sV,
                single_stage=True,
            )
            if const_expr(tma_atom_dOt is not None):
                gdOt = cute.local_tile(
                    mdOt_cur, cute.select(self.mma_tiler_vdo, mode=[1, 2]), (None, 0)
                )
                tdPgdO = thr_mma_dP.partition_B(gdOt)
                load_dOt, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_dOt,
                    cta_coord=block_in_cluster_coord_vmnk[1],
                    cta_layout=b_cta_layout,
                    src_tensor=tdPgdO,
                    dst_tensor=sdOt,
                    mcast_mask=q_do_mcast_mask,
                )
                load_dOt = copy_utils.tma_producer_copy_fn(load_dOt, pipeline_dO)

            # (3) dV += P.T @ dO
            gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None))
            tdVgdO = thr_mma_dV.partition_B(gdO)
            load_dO, _, _ = copy_utils.tma_get_copy_fn(
                tma_atom_dO,
                cta_coord=block_in_cluster_coord_vmnk[1],
                cta_layout=b_cta_layout,
                src_tensor=tdVgdO,
                dst_tensor=sdO,
                mcast_mask=q_do_mcast_mask,
            )
            load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO)

            # (4) dK += dS.T @ Q (2-CTA: needs separate Qt load)
            if const_expr(tma_atom_Qt is not None):
                gQt = cute.local_tile(
                    mQt_cur, cute.select(self.mma_tiler_dsq, mode=[1, 2]), (0, None)
                )
                tdKgQt = thr_mma_dK.partition_B(gQt)
                load_Qt, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_Qt,
                    cta_coord=block_in_cluster_coord_vmnk[1],
                    cta_layout=b_cta_layout,
                    src_tensor=tdKgQt,
                    dst_tensor=sQt,
                    mcast_mask=q_do_mcast_mask,
                )
                load_Qt = copy_utils.tma_producer_copy_fn(load_Qt, pipeline_Qt)

            # (5) dQ = dS @ K
            if const_expr(self.use_2cta_instrs):
                gKt = cute.local_tile(
                    mKt_cur, cute.select(self.mma_tiler_dsk, mode=[1, 2]), (0, n_block_cta_group)
                )
                tdQgK = thr_mma_dQ.partition_B(gKt)
                load_Kt, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_Kt,
                    block_in_cluster_coord_vmnk[1],
                    b_cta_layout,
                    tdQgK,
                    sKt,
                    single_stage=True,
                )

            copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32)
            copy_stats = partial(cute.copy, copy_atom_stats)
            # copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SMulticastOp(), Float32)
            # sLSE = cute.logical_divide(sLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
            # gLSE = cute.logical_divide(gLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
            # sdPsum = cute.logical_divide(sdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
            # gdPsum = cute.logical_divide(gdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
            # copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask)

            # some tiles might be empty due to block sparsity
            if const_expr(self.use_block_sparsity):
                total_m_block_cnt = get_total_q_block_count_bwd(
                    blocksparse_tensors,
                    batch_idx,
                    head_idx,
                    n_block,
                    subtile_factor=self.subtile_factor,
                    m_block_max=m_block_max,
                )
                process_tile = total_m_block_cnt > Int32(0)
            else:
                process_tile = (
                    const_expr(not self.is_local and not self.is_varlen_q)
                    or m_block_min < m_block_max
                )

            if process_tile:
                if const_expr(self.use_block_sparsity):
                    producer_state_Q_LSE, producer_state_dO_dPsum = (
                        produce_block_sparse_q_loads_bwd_sm100(
                            blocksparse_tensors,
                            batch_idx,
                            head_idx,
                            n_block,
                            producer_state_Q_LSE,
                            producer_state_dO_dPsum,
                            pipeline_Q,
                            pipeline_LSE,
                            pipeline_dO,
                            pipeline_dPsum,
                            load_K,
                            load_V,
                            load_Q,
                            load_dO,
                            copy_stats,
                            gLSE,
                            sLSE,
                            gdPsum,
                            sdPsum,
                            self.tma_copy_bytes["K"],
                            self.tma_copy_bytes["V"],
                            should_load_Q=should_load_Q,
                            should_load_dO=should_load_dO,
                            subtile_factor=self.subtile_factor,
                            m_block_max=m_block_max,
                        )
                    )
                else:
                    first_m_block = m_block_min
                    if const_expr(self.use_2cta_instrs and self.tile_hdim == 192):
                        #### Prologue ####
                        assert should_load_Q and should_load_dO
                        # K & Q (for S)
                        pipeline_Q.producer_acquire(
                            producer_state_Q_Qt,
                            extra_tx_count=self.tma_copy_bytes["K"],
                        )
                        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_Qt))
                        load_Q(first_m_block, producer_state=producer_state_Q_Qt)
                        pipeline_Q.producer_commit(producer_state_Q_Qt)
                        producer_state_Q_Qt.advance()

                        # LSE
                        pipeline_LSE.producer_acquire(producer_state_LSE)
                        with cute.arch.elect_one():
                            copy_stats(
                                gLSE[None, first_m_block],
                                sLSE[None, producer_state_LSE.index],
                                mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_LSE),
                            )
                        producer_state_LSE.advance()

                        # dOt + V, for dP.T = V @ dO.T
                        pipeline_dO.producer_acquire(
                            producer_state_O_Ot,
                            extra_tx_count=self.tma_copy_bytes["V"],
                        )
                        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_O_Ot))
                        load_dOt(first_m_block, producer_state=producer_state_O_Ot)
                        pipeline_dO.producer_commit(producer_state_O_Ot)
                        producer_state_O_Ot.advance()

                        # dPsum
                        pipeline_dPsum.producer_acquire(producer_state_dPsum)
                        with cute.arch.elect_one():
                            copy_stats(
                                gdPsum[None, first_m_block],
                                sdPsum[None, producer_state_dPsum.index],
                                mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dPsum),
                            )
                        producer_state_dPsum.advance()

                        # Qt, for dK = dS.T @ Q
                        pipeline_Qt.producer_acquire(
                            producer_state_Q_Qt,
                            extra_tx_count=self.tma_copy_bytes["K"],
                        )
                        load_Qt(first_m_block, producer_state=producer_state_Q_Qt)
                        load_Kt(tma_bar_ptr=pipeline_Qt.producer_get_barrier(producer_state_Q_Qt))
                        pipeline_Qt.producer_commit(producer_state_Q_Qt)
                        producer_state_Q_Qt.advance()

                        # dO, for dV = P.T @ dO
                        pipeline_dO.producer_acquire(producer_state_O_Ot)
                        load_dO(first_m_block, producer_state=producer_state_O_Ot)
                        pipeline_dO.producer_commit(producer_state_O_Ot)
                        producer_state_O_Ot.advance()

                        #### Mainloop ####
                        # 2CTA: [lse | Q | dOt | dPsum | Qt | dO]
                        for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
                            # LSE
                            pipeline_LSE.producer_acquire(producer_state_LSE)
                            with cute.arch.elect_one():
                                copy_stats(
                                    gLSE[None, m_block],
                                    sLSE[None, producer_state_LSE.index],
                                    mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_LSE),
                                )
                            producer_state_LSE.advance()

                            # Q
                            pipeline_Q.producer_acquire(producer_state_Q_Qt)
                            load_Q(m_block, producer_state=producer_state_Q_Qt)
                            pipeline_Q.producer_commit(producer_state_Q_Qt)
                            producer_state_Q_Qt.advance()

                            # dPsum
                            pipeline_dPsum.producer_acquire(producer_state_dPsum)
                            with cute.arch.elect_one():
                                copy_stats(
                                    gdPsum[None, m_block],
                                    sdPsum[None, producer_state_dPsum.index],
                                    mbar_ptr=pipeline_dPsum.producer_get_barrier(
                                        producer_state_dPsum
                                    ),
                                )
                            producer_state_dPsum.advance()

                            # dOt, for dP.T = V @ dO.T
                            pipeline_dO.producer_acquire(producer_state_O_Ot)
                            load_dOt(m_block, producer_state=producer_state_O_Ot)
                            pipeline_dO.producer_commit(producer_state_O_Ot)
                            producer_state_O_Ot.advance()

                            # Qt, for dK = dS.T @ Q
                            pipeline_Qt.producer_acquire(producer_state_Q_Qt)
                            load_Qt(m_block, producer_state=producer_state_Q_Qt)
                            pipeline_Qt.producer_commit(producer_state_Q_Qt)
                            producer_state_Q_Qt.advance()

                            # dO, for dV = P.T @ dO
                            pipeline_dO.producer_acquire(producer_state_O_Ot)
                            load_dO(m_block, producer_state=producer_state_O_Ot)
                            pipeline_dO.producer_commit(producer_state_O_Ot)
                            producer_state_O_Ot.advance()
                    else:
                        #### Prologue ####
                        if const_expr(should_load_Q):
                            # K & Q (for S)
                            pipeline_Q.producer_acquire(
                                producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"]
                            )
                            load_K(
                                tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)
                            )
                            load_Q(first_m_block, producer_state=producer_state_Q_LSE)
                            pipeline_Q.producer_commit(producer_state_Q_LSE)
                            # LSE
                            pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                            with cute.arch.elect_one():
                                copy_stats(
                                    gLSE[None, first_m_block],
                                    sLSE[None, producer_state_Q_LSE.index],
                                    mbar_ptr=pipeline_LSE.producer_get_barrier(
                                        producer_state_Q_LSE
                                    ),
                                )
                            producer_state_Q_LSE.advance()
                        if const_expr(should_load_dO):
                            pipeline_dO.producer_acquire(
                                producer_state_dO_dPsum,
                                extra_tx_count=self.tma_copy_bytes["V"] + self.tma_copy_bytes["dO"]
                                if const_expr(tma_atom_dOt is not None)
                                else self.tma_copy_bytes["V"],
                            )
                            load_V(
                                tma_bar_ptr=pipeline_dO.producer_get_barrier(
                                    producer_state_dO_dPsum
                                )
                            )
                            load_dO(first_m_block, producer_state=producer_state_dO_dPsum)
                            if const_expr(tma_atom_dOt is not None):
                                load_dOt(first_m_block, producer_state=producer_state_dO_dPsum)
                            pipeline_dO.producer_commit(producer_state_dO_dPsum)
                            # dPsum
pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) with cute.arch.elect_one(): copy_stats( gdPsum[None, first_m_block], sdPsum[None, producer_state_dO_dPsum.index], mbar_ptr=pipeline_dPsum.producer_get_barrier( producer_state_dO_dPsum ), ) producer_state_dO_dPsum.advance() if const_expr(self.use_2cta_instrs): pipeline_Kt.producer_acquire(producer_state_Kt) load_Kt(tma_bar_ptr=pipeline_Kt.producer_get_barrier(producer_state_Kt)) pipeline_Kt.producer_commit(producer_state_Kt) producer_state_Kt.advance() #### Main Loop #### for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): if const_expr(should_load_Q): if const_expr(tma_atom_Qt is not None): pipeline_Qt.producer_acquire(producer_state_Qt) load_Qt(m_block - 1, producer_state=producer_state_Qt) pipeline_Qt.producer_commit(producer_state_Qt) producer_state_Qt.advance() # Q (for S) pipeline_Q.producer_acquire(producer_state_Q_LSE) load_Q(m_block, producer_state=producer_state_Q_LSE) pipeline_Q.producer_commit(producer_state_Q_LSE) # LSE pipeline_LSE.producer_acquire(producer_state_Q_LSE) with cute.arch.elect_one(): copy_stats( gLSE[None, m_block], sLSE[None, producer_state_Q_LSE.index], mbar_ptr=pipeline_LSE.producer_get_barrier( producer_state_Q_LSE ), ) producer_state_Q_LSE.advance() if const_expr(should_load_dO): pipeline_dO.producer_acquire( producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["dO"] if const_expr(tma_atom_dOt is not None) else 0, ) load_dO(m_block, producer_state=producer_state_dO_dPsum) if const_expr(tma_atom_dOt is not None): load_dOt(m_block, producer_state=producer_state_dO_dPsum) pipeline_dO.producer_commit(producer_state_dO_dPsum) # dPsum pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) with cute.arch.elect_one(): copy_stats( gdPsum[None, m_block], sdPsum[None, producer_state_dO_dPsum.index], mbar_ptr=pipeline_dPsum.producer_get_barrier( producer_state_dO_dPsum ), ) producer_state_dO_dPsum.advance() #### Tail #### if const_expr(should_load_Q): if 
const_expr(tma_atom_Qt is not None): pipeline_Qt.producer_acquire(producer_state_Qt) load_Qt(m_block_max - 1, producer_state=producer_state_Qt) pipeline_Qt.producer_commit(producer_state_Qt) producer_state_Qt.advance() if const_expr(self.use_2cta_instrs and self.tile_hdim == 192): pipeline_Q.producer_tail(producer_state_Q_Qt) pipeline_LSE.producer_tail(producer_state_LSE) pipeline_dO.producer_tail(producer_state_O_Ot) pipeline_dPsum.producer_tail(producer_state_dPsum) else: if const_expr(should_load_Q): pipeline_Q.producer_tail(producer_state_Q_LSE.clone()) pipeline_LSE.producer_tail(producer_state_Q_LSE) if const_expr(tma_atom_Qt is not None): pipeline_Qt.producer_tail(producer_state_Qt) if const_expr(should_load_dO): pipeline_dO.producer_tail(producer_state_dO_dPsum.clone()) pipeline_dPsum.producer_tail(producer_state_dO_dPsum) tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @cute.jit def mma( self, tiled_mma_S: cute.TiledMma, tiled_mma_dP: cute.TiledMma, tiled_mma_dV: cute.TiledMma, tiled_mma_dK: cute.TiledMma, tiled_mma_dQ: cute.TiledMma, sQ: cute.Tensor, sQt: cute.Tensor, sK: cute.Tensor, sKt: cute.Tensor, sV: cute.Tensor, sdO: cute.Tensor, sdOt: cute.Tensor, tP: cute.Tensor, sdSt: cute.Tensor, sdS: cute.Tensor, tdS: cute.Tensor, tStS: cute.Tensor, tdPtdP: cute.Tensor, tdVtdV: cute.Tensor, tdKtdK: cute.Tensor, tdQtdQ: cute.Tensor, dS_cluster_full_mbar_ptr: cute.Pointer, dS_cluster_empty_mbar_ptr: cute.Pointer, dS_cluster_leader_mbar_ptr: cute.Pointer, pipeline_Q: PipelineAsync, pipeline_Qt: PipelineAsync, pipeline_Kt: PipelineAsync, pipeline_dO: PipelineAsync, pipeline_S_P: PipelineAsync, pipeline_dS: PipelineAsync, pipeline_dKV: PipelineAsync, pipeline_dP: PipelineAsync, pipeline_dQ: PipelineAsync, block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, is_leader_cta: cutlass.Boolean, blocksparse_tensors: Optional[BlockSparseTensors] = None, ): # [2025-10-21] For 
        # reasons I don't understand, putting this partitioning in the main
        # kernel (before warp specialization) is a lot slower than putting it here.
        # Partition smem / tmem tensors
        # S = K @ Q.T
        tSrK = tiled_mma_S.make_fragment_A(sK)
        tSrQ = tiled_mma_S.make_fragment_B(sQ)
        # dP = V @ dOt.T
        tdPrV = tiled_mma_dP.make_fragment_A(sV)
        tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt)
        # dK = dS.T @ Q
        # For 2-CTA, dS (dK mma) MUST come from TMEM (cannot use SMEM)
        if const_expr(self.use_smem_dS_for_mma_dK and not self.use_2cta_instrs):
            tdKrdS = tiled_mma_dK.make_fragment_A(sdSt)  # From SMEM
        else:
            tdKrdS = tiled_mma_dK.make_fragment_A(tdS)  # From TMEM
        tdKrQ = tiled_mma_dK.make_fragment_B(sQt)
        # dQ = dS @ K
        tdQrdS = tiled_mma_dQ.make_fragment_A(sdS)
        tdQrK = tiled_mma_dQ.make_fragment_B(sKt)
        # dV = P @ dO.T
        tdVrdO = tiled_mma_dV.make_fragment_B(sdO)
        tdVrP = tiled_mma_dV.make_fragment_A(tP)

        # mma_qk_fn = partial(gemm_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, zero_init=True)
        mma_qk_fn = partial(
            gemm_ptx_w_idx,
            tiled_mma_S,
            tStS,
            tSrK,
            tSrQ,
            sA=sK,
            sB=sQ,
            zero_init=True,
            cta_group=self.cta_group_size,
        )
        # mma_dov_fn = partial(gemm_w_idx, tiled_mma_dP, tdPtdP, tdPrV, tdPrdOt, zero_init=True)
        mma_dov_fn = partial(
            gemm_ptx_w_idx,
            tiled_mma_dP,
            tdPtdP,
            tdPrV,
            tdPrdOt,
            sA=sV,
            sB=sdOt,
            zero_init=True,
            cta_group=self.cta_group_size,
        )
        # mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO)
        mma_pdo_fn = partial(
            gemm_ptx_w_idx,
            tiled_mma_dV,
            tdVtdV,
            tdVrP,
            tdVrdO,
            sA=None,
            sB=sdO,
            tA_addr=self.tmem_P_offset,
            cta_group=self.cta_group_size,
        )
        num_unroll_groups = 2 if const_expr(self.use_2cta_instrs) else 1
        mma_dsk_fn = partial(
            gemm_w_idx,
            tiled_mma_dQ,
            tdQtdQ,
            tdQrdS,
            tdQrK,
            zero_init=True,
            num_unroll_groups=num_unroll_groups,
        )
        # mma_dsk_fn = partial(
        #     gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True
        # )
        if const_expr(self.use_smem_dS_for_mma_dK and not self.use_2cta_instrs):
            mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ)
        else:
            # Need to explicitly pass in tA_addr for correctness
            mma_dsq_fn = partial(
                gemm_ptx_w_idx,
                tiled_mma_dK,
                tdKtdK,
                tdKrdS,
                tdKrQ,
                sA=None,
                sB=sQt,
                tA_addr=self.tmem_dS_offset,
                cta_group=self.cta_group_size,
            )

        pipeline_Q_consumer = pipeline_Q.make_consumer()
        consumer_state_Qt = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage
        )
        consumer_state_Q = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage
        )
        consumer_state_Kt = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, self.single_stage
        )
        consumer_state_dO = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
        )
        producer_phase_acc = Int32(1)  # For S & P, dP, dQ
        producer_phase_dQ = Int32(1)  # 2-CTA: separate phase for dQ pipeline
        consumer_state_dS = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, 1
        )
        producer_phase_dKV = Int32(1)
        cta_group = pipeline_S_P.cta_group

        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
        dS_cluster_phase = Int32(0)

        tile_scheduler = TileSchedulerCls()
        work_tile = tile_scheduler.initial_work_tile_info()
        while work_tile.is_valid_tile:
            n_block, head_idx, batch_idx, _ = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)  # must be seqlen_k
            m_block_min, m_block_max = block_info.get_m_block_min_max(
                seqlen, n_block // self.cluster_shape_mnk[0]
            )

            if const_expr(self.use_block_sparsity):
                block_iter_count = get_total_q_block_count_bwd(
                    blocksparse_tensors,
                    batch_idx,
                    head_idx,
                    n_block,
                    subtile_factor=self.subtile_factor,
                    m_block_max=m_block_max,
                )
                process_tile = block_iter_count > Int32(0)
            else:
                block_iter_count = m_block_max - m_block_min
                process_tile = (
                    const_expr(not self.is_local and not self.is_varlen_q)
                    or m_block_min < m_block_max
                )

            if const_expr(self.use_2cta_instrs and self.tile_hdim == 192):
                if is_leader_cta and process_tile:
                    accumulate_dK = False
                    accumulate_dV = False

                    # -----------------------------------------------------------
                    ###### MAIN LOOP
                    # -----------------------------------------------------------
                    # 1. S.T = K @ Q.T
                    # 2. dP.T = V @ dO.T
                    # 3. dK = dS.T @ Q
                    # 4. dV = P.T @ dO
                    # 5. dQ = dS @ K
                    main_loop_iters = m_block_max - m_block_min

                    # empty waits
                    # pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
                    # pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)

                    for _ in cutlass.range(main_loop_iters, unroll=1):
                        # 1) S.T = K @ Q.T
                        pipeline_Q.consumer_wait(consumer_state_Q)
                        pipeline_dQ.sync_object_empty.wait(
                            0, producer_phase_acc
                        )  # dQ tmem overlaps with S
                        mma_qk_fn(B_idx=consumer_state_Q.index)
                        pipeline_S_P.sync_object_full.arrive(
                            0, pipeline_S_P.producer_mask, cta_group
                        )
                        pipeline_Q.consumer_release(consumer_state_Q)
                        consumer_state_Q.advance()
                        producer_phase_acc ^= 1

                        # 2) dP.T = V @ dO.T
                        pipeline_dO.consumer_wait(consumer_state_dO)
                        pipeline_S_P.sync_object_empty.wait(
                            0, producer_phase_acc
                        )  # dP tmem overlaps with S
                        mma_dov_fn(B_idx=consumer_state_dO.index)
                        pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)
                        pipeline_dO.consumer_release(consumer_state_dO)
                        consumer_state_dO.advance()

                        # 3) dK = dS.T @ Q
                        pipeline_Q.consumer_wait(consumer_state_Q)
                        pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)  # dP -> dS
                        mma_dsq_fn(B_idx=consumer_state_Q.index, zero_init=not accumulate_dK)
                        pipeline_Q.consumer_release(consumer_state_Q)
                        consumer_state_Q.advance()
                        accumulate_dK = True

                        # 4) dV = P.T @ dO
                        # Note: if dS is written to tmem, P must be written to tmem
                        pipeline_dO.consumer_wait(consumer_state_dO)
                        mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=not accumulate_dV)
                        pipeline_dO.consumer_release(consumer_state_dO)
                        consumer_state_dO.advance()
                        accumulate_dV = True

                        # 5) dQ = dS @ K
                        pipeline_dS.consumer_wait(consumer_state_dS)
                        cute.arch.mbarrier_wait(dS_cluster_leader_mbar_ptr, phase=dS_cluster_phase)
                        mma_dsk_fn()
                        pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
                        pipeline_dS.consumer_release(consumer_state_dS)
                        consumer_state_dS.advance()
                        dS_cluster_phase ^= 1

                    # signal to the epilogue that dV is ready
                    pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV)
                    pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group)
                    # signal to the epilogue that dK is ready
                    pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV)
                    pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group)
                    producer_phase_dKV ^= 1

            elif const_expr(self.use_2cta_instrs):
                if is_leader_cta and process_tile:
                    accumulate_dK = False
                    # -----------------------------------------------------------
                    ###### Prologue
                    # -----------------------------------------------------------
                    # 1. S = Q0 @ K.T
                    # 2. dP = V @ dOt.T
                    # 3. dV = P @ dO

                    # 1) S = K @ Q
                    pipeline_Q.consumer_wait(consumer_state_Q)
                    pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
                    mma_qk_fn(B_idx=consumer_state_Q.index)
                    pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
                    pipeline_Q.consumer_release(consumer_state_Q)
                    consumer_state_Q.advance()

                    # 2) dP = V @ dOt.T
                    pipeline_dO.consumer_wait(consumer_state_dO)
                    pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)
                    mma_dov_fn(B_idx=consumer_state_dO.index)
                    pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)

                    # 3) dV = P.T @ dO
                    producer_phase_acc ^= 1
                    pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
                    mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True)
                    pipeline_dO.consumer_release(consumer_state_dO)
                    consumer_state_dO.advance()

                    pipeline_Kt.consumer_wait(consumer_state_Kt)

                    # -----------------------------------------------------------
                    ###### MAIN LOOP
                    # -----------------------------------------------------------
                    # 1. S.T = K @ Q.T
                    # 2. dK = dS.T @ Q
                    # 3. dP.T = V @ dO.T
                    # 4. dQ = dS @ K
                    # 5. dV = P.T @ dO
                    main_loop_iters = (
                        block_iter_count - 1
                        if const_expr(self.use_block_sparsity)
                        else m_block_max - m_block_min - 1
                    )
                    for _ in cutlass.range(main_loop_iters, unroll=1):
                        # (1) S.T = K @ Q.T (next)
                        pipeline_Q.consumer_wait(consumer_state_Q)
                        pipeline_dQ.sync_object_empty.wait(0, producer_phase_dQ)
                        mma_qk_fn(B_idx=consumer_state_Q.index)
                        pipeline_S_P.sync_object_full.arrive(
                            0, pipeline_S_P.producer_mask, cta_group
                        )
                        pipeline_Q.consumer_release(consumer_state_Q)
                        consumer_state_Q.advance()

                        # pipeline_dS.consumer_wait(consumer_state_dS)
                        # (2) dK += dS.T @ Q (cur)
                        pipeline_Qt.consumer_wait(consumer_state_Qt)
                        pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)  # dP -> dS
                        mma_dsq_fn(B_idx=consumer_state_Qt.index, zero_init=not accumulate_dK)
                        accumulate_dK = True
                        pipeline_Qt.consumer_release(consumer_state_Qt)
                        consumer_state_Qt.advance()

                        # (3) dP.T = V @ dO.T (next)
                        pipeline_dO.consumer_wait(consumer_state_dO)
                        mma_dov_fn(B_idx=consumer_state_dO.index)
                        pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)

                        # (4) dQ = dS @ K (cur)
                        pipeline_dS.consumer_wait(consumer_state_dS)
                        cute.arch.mbarrier_wait(dS_cluster_leader_mbar_ptr, phase=dS_cluster_phase)
                        mma_dsk_fn()
                        pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
                        pipeline_dS.consumer_release(consumer_state_dS)
                        consumer_state_dS.advance()
                        dS_cluster_phase ^= 1
                        producer_phase_dQ ^= 1

                        # (5) dV += P.T @ dO (next)
                        producer_phase_acc ^= 1
                        pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)  # S -> P
                        mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False)
                        pipeline_dO.consumer_release(consumer_state_dO)
                        consumer_state_dO.advance()

                    pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)

                    # signal to the epilogue that dV is ready
                    pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV)
                    pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group)
                    pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV)

                    # -----------------------------------------------------------
                    # Tail: Remaining dK and dQ
                    # -----------------------------------------------------------
                    # pipeline_dS.consumer_wait(consumer_state_dS)
                    # dK += dS.T @ Q
                    pipeline_Qt.consumer_wait(consumer_state_Qt)
                    pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)  # dP -> dS
                    mma_dsq_fn(B_idx=consumer_state_Qt.index, zero_init=not accumulate_dK)
                    pipeline_Qt.consumer_release(consumer_state_Qt)
                    consumer_state_Qt.advance()

                    # signal to the epilogue that dK is ready
                    pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group)
                    producer_phase_dKV ^= 1

                    # dQ = dS @ K
                    pipeline_dS.consumer_wait(consumer_state_dS)
                    cute.arch.mbarrier_wait(dS_cluster_leader_mbar_ptr, phase=dS_cluster_phase)
                    pipeline_dQ.sync_object_empty.wait(0, producer_phase_dQ)
                    mma_dsk_fn()
                    pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
                    pipeline_dS.consumer_release(consumer_state_dS)
                    pipeline_Kt.consumer_release(consumer_state_Kt)
                    consumer_state_dS.advance()
                    consumer_state_Kt.advance()
                    dS_cluster_phase ^= 1
                    producer_phase_dQ ^= 1

                    producer_phase_acc ^= 1
            else:
                if is_leader_cta and process_tile:
                    accumulate_dK = False
                    # -----------------------------------------------------------
                    ###### Prologue
                    # -----------------------------------------------------------
                    # 1. S = Q0 @ K.T
                    # 2. dP = V @ dOt.T
                    # 3. dV = P @ dO

                    # 1) S = K @ Q
                    handle_Q = pipeline_Q_consumer.wait_and_advance()
                    pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
                    mma_qk_fn(B_idx=handle_Q.index)
                    pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)

                    # 2) dP = V @ dOt.T
                    pipeline_dO.consumer_wait(consumer_state_dO)
                    pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)
                    pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc)
                    mma_dov_fn(B_idx=consumer_state_dO.index)
                    pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)

                    producer_phase_acc ^= 1
                    # 3) dV = P.T @ dO
                    pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
                    mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True)
                    pipeline_dO.consumer_release(consumer_state_dO)
                    consumer_state_dO.advance()

                    # -----------------------------------------------------------
                    ###### MAIN LOOP
                    # -----------------------------------------------------------
                    # 1. S = K @ Q.T
                    # 2. dK = dS.T @ Q
                    # 3. dQ = dS @ K
                    # 4. dP = V @ dOt.T
                    # 5. dV = P.T @ dO
                    # For block sparsity, we use block_iter_count; for dense, use m_block range
                    # MMA doesn't need actual m_block indices, just the iteration count
                    main_loop_iters = (
                        block_iter_count - 1
                        if const_expr(self.use_block_sparsity)
                        else m_block_max - m_block_min - 1
                    )

                    handle_Q_next = handle_Q
                    for _ in cutlass.range(main_loop_iters, unroll=1):
                        # (1) S.T = K @ Q.T
                        handle_Q_next = pipeline_Q_consumer.wait_and_advance()
                        mma_qk_fn(B_idx=handle_Q_next.index)
                        pipeline_S_P.sync_object_full.arrive(
                            0, pipeline_S_P.producer_mask, cta_group
                        )

                        # (2) dK += dS.T @ Q
                        pipeline_dS.consumer_wait(consumer_state_dS)
                        mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
                        accumulate_dK = True
                        handle_Q.release()

                        # (3) dQ = dS @ K
                        mma_dsk_fn()
                        pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
                        pipeline_dS.consumer_release(consumer_state_dS)
                        consumer_state_dS.advance()

                        # (4) dP = V @ dO.T
                        pipeline_dO.consumer_wait(consumer_state_dO)
                        pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc)
                        mma_dov_fn(B_idx=consumer_state_dO.index)
                        pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)

                        # (5) dV += P.T @ dO
                        producer_phase_acc ^= 1
                        pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
                        mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False)
                        pipeline_dO.consumer_release(consumer_state_dO)
                        consumer_state_dO.advance()

                        handle_Q = handle_Q_next

                    pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)

                    # signal to the epilogue that dV is ready
                    # pipeline_dKV.producer_acquire(producer_state_dKV)
                    pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV)
                    # pipeline_dKV.producer_commit(producer_state_dKV)
                    pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group)
                    # producer_state_dKV.advance()
                    # pipeline_dKV.producer_acquire(producer_state_dKV)
                    pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV)

                    # -----------------------------------------------------------
                    # Tail: Remaining dK and dQ
                    # -----------------------------------------------------------
                    # 1) dK += dS.T @ Q
                    pipeline_dS.consumer_wait(consumer_state_dS)
                    mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
                    # signal to the epilogue that dK is ready
                    pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group)
                    producer_phase_dKV ^= 1

                    # 2) dQ = dS @ K
                    mma_dsk_fn()
                    pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
                    handle_Q.release()
                    pipeline_dS.consumer_release(consumer_state_dS)
                    consumer_state_dS.advance()

                    producer_phase_acc ^= 1

            tile_scheduler.advance_to_next_work()
            work_tile = tile_scheduler.get_current_work()

        # Currently it hangs if we have this S_P.producer_tail, will need to understand why
        # pipeline_S_P.producer_tail(producer_state_S_P)
        # pipeline_dP.producer_tail(producer_state_dP)
        # pipeline_dKV.producer_tail(producer_state_dKV)
        # pipeline_dQ.producer_tail(producer_state_dQ)

    @cute.jit
    def split_wg(
        self,
        t: cute.Tensor,
        wg_idx: cutlass.Int32,
        num_wg: cutlass.Constexpr[int],
    ):
        reduced_shape = cute.product_each(t.shape)
        rank = len(reduced_shape)
        if const_expr(reduced_shape[1] > 1):
            assert rank >= 2, "Need rank >= 2 for t in split_wg"
            t = cute.logical_divide(t, (reduced_shape[0], reduced_shape[1] // num_wg))
            coord = (None, (None, wg_idx)) + (None,) * (rank - 2)
        else:
            assert rank >= 3, "Need rank >= 3 for t in split_wg"
            if const_expr(rank == 3):
                t = cute.logical_divide(
                    t, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg)
                )
                coord = (
                    None,
                    None,
                    (None, wg_idx),
                ) + (None,) * (rank - 3)
            else:
                t = cute.logical_divide(
                    t,
                    (
                        reduced_shape[0],
                        reduced_shape[1],
                        reduced_shape[2],
                        reduced_shape[3] // num_wg,
                    ),
                )
                coord = (
                    None,
                    None,
                    None,
                    (None, wg_idx),
                ) + (None,) * (rank - 4)
        return t[coord]

    @cute.jit
    def apply_score_mod(
        self,
        tSrS_t2r,
        thr_copy_t2r,
        thr_mma_S,
        batch_idx,
        head_idx,
        m_block,
        n_block,
        softmax_scale,
        seqlen_info,
        aux_tensors=None,
        fastdiv_mods=(None, None),
    ):
        """Apply forward score modification for SM100 backward pass."""
        # In bwd, S is computed as K @ Q.T so dimensions are (tile_n, tile_m)
        cS = cute.make_identity_tensor((self.tile_n, self.tile_m))
        cS = cute.domain_offset((n_block * self.tile_n, m_block * self.tile_m), cS)
        tScS = thr_mma_S.partition_C(cS)
        tScS_idx = thr_copy_t2r.partition_D(tScS)

        apply_score_mod_inner(
            tSrS_t2r,
            tScS_idx,
            self.score_mod,
            batch_idx,
            head_idx,
            softmax_scale,
            self.vec_size,
            self.qk_acc_dtype,
            aux_tensors,
            fastdiv_mods,
            seqlen_info,
            constant_q_idx=None,
            qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
            transpose_indices=True,
        )

    @cute.jit
    def apply_score_mod_bwd(
        self,
        grad_tensor,
        score_tensor,
        index_tensor,
        batch_idx,
        head_idx,
        softmax_scale,
        seqlen_info,
        aux_tensors=None,
        fastdiv_mods=(None, None),
    ):
        """Apply backward score modification (joint graph) for SM100."""
        apply_score_mod_bwd_inner(
            grad_tensor,
            score_tensor,
            index_tensor,
            self.score_mod_bwd,
            batch_idx,
            head_idx,
            softmax_scale,
            self.vec_size,
            self.qk_acc_dtype,
            aux_tensors,
            fastdiv_mods,
            seqlen_info,
            constant_q_idx=None,
            qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
            transpose_indices=True,
        )

    @cute.jit
    def compute_loop(
        self,
        thr_mma_S: cute.core.ThrMma,
        thr_mma_dP: cute.core.ThrMma,
        thr_mma_dV: cute.core.ThrMma,
        thr_mma_dK: cute.core.ThrMma,
        tStS: cute.Tensor,
        tdPtdP: cute.Tensor,
        tdVtdV: cute.Tensor,
        tdKtdK: cute.Tensor,
        sLSE: cute.Tensor,
        sdPsum: cute.Tensor,
        mdV: cute.Tensor,
        mdK: cute.Tensor,
        sdS: cute.Tensor,
        sdS_xchg: cute.Tensor,
        pipeline_LSE: PipelineAsync,
        pipeline_dPsum: PipelineAsync,
        pipeline_S_P: PipelineAsync,
        pipeline_dS: PipelineAsync,
        pipeline_dKV: PipelineAsync,
        pipeline_dP: PipelineAsync,
        dS_cluster_empty_mbar_ptr: cute.Pointer,
        dS_cluster_full_mbar_ptr: cute.Pointer,
        dQaccum_empty_mbar_ptr: cute.Pointer,
        softmax_scale: cutlass.Float32,
        softmax_scale_log2: cutlass.Float32,
        block_info: BlockInfo,
        SeqlenInfoCls: Callable,
        AttentionMaskCls: Callable,
        TileSchedulerCls: Callable,
        sdV: Optional[cute.Tensor],
        sdK: Optional[cute.Tensor],
        mdV_tma_tensor: Optional[cute.Tensor],
        mdK_tma_tensor: Optional[cute.Tensor],
        tma_atom_dV: Optional[cute.CopyAtom],
        tma_atom_dK: Optional[cute.CopyAtom],
        tiled_copy_r2s_dKV: Optional[cute.TiledCopy],
        mdK_semaphore: Optional[cute.Tensor],
        mdV_semaphore: Optional[cute.Tensor],
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
    ):
        sLSE_2D = cute.make_tensor(
            sLSE.iterator,
            cute.make_layout(
                (self.tile_m, self.tile_n, self.Q_stage),
                stride=(1, 0, cute.round_up(self.tile_m, 64)),
            ),
        )
        sdPsum_2D = cute.make_tensor(
            sdPsum.iterator,
            cute.make_layout(
                (self.tile_m, self.tile_n, self.dO_stage),
                stride=(1, 0, cute.round_up(self.tile_m, 64)),
            ),
        )
        # if const_expr(self.SdP_swapAB):
        if const_expr(True):
            sLSE_2D = layout_utils.transpose_view(sLSE_2D)
            sdPsum_2D = layout_utils.transpose_view(sdPsum_2D)

        # tix: [128...384]  8 warps
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())  # 4-11
        tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))
        # tidx = cute.arch.thread_idx()[0] - (cute.arch.WARP_SIZE * self.compute_warp_ids[0])
        dp_idx = tidx % 128
        num_wg = len(self.compute_warp_ids) // 4  # 2
        # wg_idx:
        # 0: [256...384]
        # 1: [128...256]

        tileP_f32_like = self.cta_tiler[1] // 32 * self.v_dtype.width
        # tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1)
        # tP overlap with tS
        tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
        tStP = cute.make_tensor(tStS.iterator, tStP.layout)  # Otherwise the tmem address is wrong
        tScS = thr_mma_S.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2]))
        tScP = cute.composition(tScS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
        # tdS overlap with tdP
        tdPtdS = cute.composition(tdPtdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
        tdPcdP = thr_mma_dP.partition_C(cute.make_identity_tensor(self.mma_tiler_vdo[:2]))
        tdPcdS = cute.composition(tdPcdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))

        # 2-CTA assumes: repetition should always be 32 & 16
        tmem_load_atom = cute.make_copy_atom(
            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
        )
        tmem_store_atom = cute.make_copy_atom(
            tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32
        )

        # tmem -> rmem
        thr_copy_t2r = copy_utils.make_tmem_copy(tmem_load_atom, num_wg).get_slice(tidx)
        tStS_t2r = thr_copy_t2r.partition_S(tStS)  # (((32, 32), 1), 2, 1, 1)
        tdPtdP_t2r = thr_copy_t2r.partition_S(tdPtdP)
        tScS_t2r = thr_copy_t2r.partition_D(tScS)  # ((32, 1), 2, 1, 1)
        t0ScS_t2r = thr_copy_t2r.get_slice(0).partition_D(tScS)  # ((32, 1), 2, 1, 1)
        # ((32, 1), 2, 1, 1, STAGE)
        tSsLSE = thr_copy_t2r.partition_D(thr_mma_S.partition_C(sLSE_2D))
        tSsdPsum = thr_copy_t2r.partition_D(thr_mma_dP.partition_C(sdPsum_2D))
        # rmem -> tmem
        thr_copy_r2t = copy_utils.make_tmem_copy(tmem_store_atom, num_wg).get_slice(tidx)
        tScP_r2t = thr_copy_r2t.partition_S(tScP)
        tStP_r2t = thr_copy_r2t.partition_D(tStP)
        tdPcdS_r2t = thr_copy_r2t.partition_S(tdPcdS)
        tdPtdS_r2t = thr_copy_r2t.partition_D(tdPtdS)
        # rmem -> smem
        # This part is a bit iffy, we might be making a lot of assumptions here
        copy_atom_r2s = sm100_utils_basic.get_smem_store_op(
            LayoutEnum.ROW_MAJOR, self.ds_dtype, Float32, thr_copy_t2r
        )
        thr_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, thr_copy_t2r).get_slice(tidx)
        # We assume the swizzle (i.e. layout.inner) stays the same
        sdS_epi_layout = sm100_utils_basic.make_smem_layout_epi(
            self.ds_dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_m), 1
        )
        sdS_layout = cute.slice_(sdS_epi_layout.outer, (None, None, 0))  # ((8,16), (64,2))
        # Need to group into 1 mode to be compatible w thr_copy_r2s
        sdS_layout = cute.make_layout((sdS_layout.shape,), stride=(sdS_layout.stride,))
        sdS_epi = cute.make_tensor(sdS.iterator, sdS_layout)
        tRS_sdS = thr_copy_r2s.partition_D(sdS_epi)

        if const_expr(self.use_2cta_instrs):
            sdS_xchg_epi = cute.make_tensor(
                cute.recast_ptr(sdS_xchg.iterator, sdS_epi_layout.inner), sdS_layout
            )
            tRS_sdS_xchg = thr_copy_r2s.partition_D(sdS_xchg_epi)

        cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
        dS_cluster_empty_phase = Int32(1)
        # 2-CTA: CTA 0 exchanges stage 1 (bottom half), CTA 1 exchanges stage 0 (top half)
        exchange_stage = cta_rank_in_cluster ^ 1 if const_expr(self.use_2cta_instrs) else Int32(0)

        consumer_state_S_P_dP = pipeline.make_pipeline_state(  # Our impl has shortcut for stage==1
            cutlass.pipeline.PipelineUserType.Consumer, 1
        )
        # consumer_phase_S_P_dP = Int32(0)
        producer_state_dS = pipeline.make_pipeline_state(  # Our impl has shortcut for stage==1
            cutlass.pipeline.PipelineUserType.Producer, 1
        )
        consumer_state_dKV = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, 2
        )
        consumer_state_LSE = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage
        )
        consumer_state_dPsum = pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
        )

        tile_scheduler = TileSchedulerCls()
        work_tile = tile_scheduler.initial_work_tile_info()
        while work_tile.is_valid_tile:
            n_block, head_idx, batch_idx, _ = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)
            m_block_min, m_block_max = block_info.get_m_block_min_max(
                seqlen, n_block // self.cluster_shape_mnk[0]
            )
            mask = AttentionMaskCls(seqlen)
            n_block_for_cluster = n_block // self.cta_group_size
            # TODO: condition mask_seqlen
            mask_fn = partial(
                mask.apply_mask_sm100_transposed,
                tScS_t2r=tScS_t2r,
                t0ScS_t2r=t0ScS_t2r,
                n_block=n_block_for_cluster,
                mask_seqlen=True,
                mask_causal=self.is_causal,
                mask_local=self.is_local,
                mask_mod=self.mask_mod,
                batch_idx=batch_idx,
                head_idx=head_idx,
                aux_tensors=aux_tensors,
                fastdiv_mods=fastdiv_mods,
            )

            # prefetch_LSE = not self.is_causal
            prefetch_LSE = False

            # some tiles might be empty due to block sparsity
            if const_expr(self.use_block_sparsity):
                (
                    curr_q_cnt,
                    curr_q_idx,
                    curr_full_cnt,
                    curr_full_idx,
                    loop_count,
                ) = get_block_sparse_iteration_info_bwd(
                    blocksparse_tensors,
                    batch_idx,
                    head_idx,
                    n_block,
                    subtile_factor=self.subtile_factor,
                    m_block_max=m_block_max,
                )
                process_tile = loop_count > Int32(0)
            else:
                process_tile = (
                    const_expr(not self.is_local and not self.is_varlen_q)
                    or m_block_min < m_block_max
                )
                loop_count = m_block_max - m_block_min

            # Mainloop
            # Block sparsity: iterate over sparse m_block count and derive actual m_block
            # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly.
            for iter_idx in cutlass.range(loop_count, unroll=1):
                if const_expr(self.use_block_sparsity):
                    m_block, is_full_block = get_m_block_from_iter_bwd(
                        iter_idx,
                        curr_q_cnt,
                        curr_q_idx,
                        curr_full_cnt,
                        curr_full_idx,
                        subtile_factor=self.subtile_factor,
                        m_block_max=m_block_max,
                    )
                    m_block_oob = m_block >= m_block_max
                else:
                    m_block = m_block_min + iter_idx
                    m_block_oob = False
                    is_full_block = False
                # Prefetch 1 stage of LSE
                pipeline_LSE.consumer_wait(consumer_state_LSE)
                tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32)
                if const_expr(prefetch_LSE and not self.shuffle_LSE):
                    cute.autovec_copy(tSsLSE[None, 0, 0, 0, consumer_state_LSE.index], tSrLSE_s2r)

                pipeline_S_P.consumer_wait(consumer_state_S_P_dP)
                # pipeline_S_P.sync_object_full.wait(0, consumer_phase_S_P_dP)
                #### TMEM->RMEM (Load S from TMEM)
                tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32)
                cute.copy(thr_copy_t2r, tStS_t2r, tSrS_t2r)
                if const_expr(self.tile_hdim == 192):
                    # Signal S tmem load completion using pipeline_S_P when hdim 192
                    # dP is overlapped with S
                    cute.arch.fence_view_async_tmem_load()
                    with cute.arch.elect_one():
                        pipeline_S_P.consumer_release(consumer_state_S_P_dP)
                elif const_expr(self.use_2cta_instrs and self.tile_hdim <= 128):
                    # Signal S tmem load completion using pipeline_dS when 2cta hdim 128
                    # dQ is overlapped with S
                    if iter_idx > 0:
                        cute.arch.fence_view_async_tmem_load()
                        with cute.arch.elect_one():
                            pipeline_dS.producer_commit(producer_state_dS)
                        producer_state_dS.advance()

                if const_expr(self.score_mod_bwd is not None):
                    tSrS_pre = cute.make_fragment_like(tSrS_t2r)
                    cute.autovec_copy(tSrS_t2r, tSrS_pre)

                if const_expr(self.score_mod is not None):
                    # Apply score_mod FIRST -> matches forward
                    self.apply_score_mod(
                        tSrS_t2r,
                        thr_copy_t2r,
                        thr_mma_S,
                        batch_idx,
                        head_idx,
                        m_block,
                        n_block,
                        softmax_scale,
                        seqlen,
                        aux_tensors,
                        fastdiv_mods,
                    )

                #### APPLY MASK (after score_mod, matching forward pass order)
                check_m_boundary = (m_block + 1) * self.tile_m > seqlen.seqlen_q
                mask_fn(
                    tSrS_t2r,
                    m_block=m_block,
                    is_full_block=is_full_block,
                    check_m_boundary=check_m_boundary,
                )

                num_stages = cute.size(tScS_t2r, mode=[1])

                # ---------------------------------------------
                #### P = exp(S - LSE)
                # ---------------------------------------------
                lane_idx = cute.arch.lane_idx()
                tSrP_r2t_f32 = cute.make_fragment(tScP_r2t.shape, Float32)  # 64
                tSrP_r2t = cute.recast_tensor(tSrP_r2t_f32, self.q_dtype)
                for stage in cutlass.range_constexpr(num_stages):
                    tSrS_cur = tSrS_t2r[None, stage, 0, 0]
                    tSsLSE_cur = tSsLSE[None, stage, 0, 0, consumer_state_LSE.index]
                    if const_expr(not self.shuffle_LSE):
                        if const_expr(stage > 0 or not prefetch_LSE):
                            cute.autovec_copy(tSsLSE_cur, tSrLSE_s2r)
                        tSrLSE = tSrLSE_s2r
                    else:
                        tSrLSE = tSsLSE_cur[lane_idx]
                    for v in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[0]) // 2):
                        if const_expr(not self.shuffle_LSE):
                            lse_pair = (tSrLSE[2 * v], tSrLSE[2 * v + 1])
                        else:
                            lse_pair = (
                                utils.shuffle_sync(tSrLSE, offset=2 * v),
                                utils.shuffle_sync(tSrLSE, offset=2 * v + 1),
                            )
                        tSrS_cur[2 * v], tSrS_cur[2 * v + 1] = cute.arch.fma_packed_f32x2(
                            ((tSrS_cur[2 * v], tSrS_cur[2 * v + 1])),
                            (softmax_scale_log2, softmax_scale_log2),
                            (-lse_pair[0], -lse_pair[1]),
                        )
                        tSrS_cur[2 * v] = cute.math.exp2(tSrS_cur[2 * v], fastmath=True)
                        tSrS_cur[2 * v + 1] = cute.math.exp2(tSrS_cur[2 * v + 1], fastmath=True)
                    utils.cvt_f16(tSrS_cur, tSrP_r2t[None, stage, 0, 0])
                    if const_expr(stage == 0):
                        cute.arch.fence_view_async_tmem_load()
                        # Without this barrier, we could have 1 warp writing to P in tmem while
                        # another warp is still reading S from tmem.
                        self.compute_sync_barrier.arrive_and_wait()
                    cute.copy(
                        thr_copy_r2t,
                        tSrP_r2t_f32[None, stage, None, None],
                        tStP_r2t[None, stage, None, None],
                    )

                cute.arch.fence_view_async_tmem_store()
                cute.arch.fence_view_async_shared()
                self.compute_sync_barrier.arrive_and_wait()

                if const_expr(not self.tile_hdim == 192):
                    # Signal tmem store P completion with pipeline_S_P
                    with cute.arch.elect_one():
                        pipeline_S_P.consumer_release(consumer_state_S_P_dP)
                        # pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask)
                # Normally we'd need syncwarp here since only 1 thread will signal in
                # consumer_release, but we already have the self.compute_sync_barrier before this
                pipeline_LSE.consumer_release(consumer_state_LSE)
                consumer_state_LSE.advance()

                # ---------------------------------------------
                # dS.T = P.T * (dP.T - D)
                # ---------------------------------------------
                pipeline_dPsum.consumer_wait(consumer_state_dPsum)

                pipeline_dP.consumer_wait(consumer_state_S_P_dP)
                # pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP)
                ### Now delayed to after loop
                # consumer_state_S_P_dP.advance()
                # consumer_phase_S_P_dP ^= 1

                ##### dS.T = P.T * (dP.T - Psum)
                for stage in cutlass.range_constexpr(num_stages):
                    tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32)
                    cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r)
                    cute.arch.fence_view_async_tmem_load()
                    self.compute_sync_barrier.arrive_and_wait()
                    tdPrdP_cur = tdPrdP_t2r[None, 0, 0]
                    tSrS_cur = tSrS_t2r[None, stage, 0, 0]
                    tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, consumer_state_dPsum.index]
                    if const_expr(not self.shuffle_dPsum):
                        tSrdPsum = cute.make_fragment_like(tSsdPsum_cur, Float32)
                        cute.autovec_copy(tSsdPsum_cur, tSrdPsum)
                    else:
                        tSrdPsum = tSsdPsum_cur[lane_idx]

                    for v in cutlass.range_constexpr(cute.size(tdPrdP_t2r, mode=[0]) // 2):
                        if const_expr(not self.shuffle_dPsum):
                            dPsum_pair = (tSrdPsum[2 * v], tSrdPsum[2 * v + 1])
                        else:
                            dPsum_pair = (
                                utils.shuffle_sync(tSrdPsum, offset=2 * v),
                                utils.shuffle_sync(tSrdPsum, offset=2 * v + 1),
                            )
                        tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = (
                            quack.activation.sub_packed_f32x2(
                                (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), dPsum_pair
                            )
                        )
                        tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = cute.arch.mul_packed_f32x2(
                            (tSrS_cur[2 * v], tSrS_cur[2 * v + 1]),
                            (tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]),
                        )

                    if const_expr(self.score_mod_bwd is not None):
                        tSrS_pre_cur = tSrS_pre[None, stage, 0, 0]
                        cS_bwd = cute.make_identity_tensor((self.tile_n, self.tile_m))
                        cS_bwd = cute.domain_offset(
                            (n_block * self.tile_n, m_block * self.tile_m), cS_bwd
                        )
                        tScS_bwd = thr_mma_S.partition_C(cS_bwd)
                        tScS_idx_bwd = thr_copy_t2r.partition_D(tScS_bwd)
                        tScS_idx_cur = tScS_idx_bwd[None, stage, 0, 0]
                        self.apply_score_mod_bwd(
                            tdPrdP_cur,
                            tSrS_pre_cur,
                            tScS_idx_cur,
                            batch_idx,
                            head_idx,
                            softmax_scale,
                            seqlen,
                            aux_tensors,
                            fastdiv_mods,
                        )
                        # Zero out OOB positions (kv_idx >= seqlen_k) after score_mod_bwd
                        for i in cutlass.range(cute.size(tdPrdP_cur), unroll_full=True):
                            kv_idx = tScS_idx_cur[i][0]
                            tdPrdP_cur[i] = 0.0 if kv_idx >= seqlen.seqlen_k else tdPrdP_cur[i]

                    tdPrdS_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype)
                    utils.cvt_f16(tdPrdP_cur, tdPrdS_cvt)
                    if const_expr(stage == 0):
                        pipeline_dS.producer_acquire(producer_state_dS)
                        if const_expr(self.use_2cta_instrs):
                            tdPrdS_xchg = cute.make_fragment_like(tdPrdS_cvt, self.ds_dtype)

                    # RMEM->TMEM: always write to TMEM for MMA
                    if const_expr(not self.use_smem_dS_for_mma_dK or self.use_2cta_instrs):
                        tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32)
                        cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0])

                    # RMEM->SMEM: For 2-CTA, keep exchange stage in registers, write non-exchange to sdS
                    if const_expr(self.use_2cta_instrs):
                        if exchange_stage == stage:
                            cute.autovec_copy(tdPrdS_cvt, tdPrdS_xchg)
                        else:
                            cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage])
                    else:
                        cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage])

                if const_expr(not self.use_smem_dS_for_mma_dK):
                    cute.arch.fence_view_async_tmem_store()

                if const_expr(self.use_2cta_instrs):
                    # use pipeline_dP to signal tmem store of dS
                    with cute.arch.elect_one():
                        pipeline_dP.consumer_release(consumer_state_S_P_dP)
                consumer_state_S_P_dP.advance()

                # After the loop: copy exchange registers to sdS_xchg buffer
                if const_expr(self.use_2cta_instrs):
                    # when hdim 192, sdQaccum overlapped with sdS_xchg
                    if const_expr(self.tile_hdim == 192):
                        cute.arch.mbarrier_wait(
                            dQaccum_empty_mbar_ptr, phase=producer_state_dS.phase
                        )
                    cute.autovec_copy(tdPrdS_xchg, tRS_sdS_xchg[None, 0])

                cute.arch.fence_view_async_shared()
                self.compute_sync_barrier.arrive_and_wait()

                # Normally we'd need syncwarp here since only 1 thread will signal in
                # consumer_release, but we already have the self.compute_sync_barrier before this
                pipeline_dPsum.consumer_release(consumer_state_dPsum)
                consumer_state_dPsum.advance()

                # when 2cta hdim 128, pipeline_dS also signals S tmem load completion so is deferred
                if const_expr(not (self.use_2cta_instrs and self.tile_hdim == 128)):
                    with cute.arch.elect_one():
                        pipeline_dS.producer_commit(producer_state_dS)
                    producer_state_dS.advance()

                # 2-CTA: DSMEM copy from sdS_xchg to peer's sdS buffer
                if const_expr(self.use_2cta_instrs):
                    stage_copy_bytes = const_expr(self.tma_copy_bytes["dS"] // 2)
                    stage_copy_elems = const_expr(stage_copy_bytes // (self.ds_dtype.width // 8))
                    if tidx == 0:
                        peer_cta_rank_in_cluster = cta_rank_in_cluster ^ 1
                        smem_src_ptr = sdS_xchg.iterator
                        # Destination is peer's sdS at our CTA's offset (exchange_stage position)
                        smem_dst_ptr = sdS.iterator + cta_rank_in_cluster * stage_copy_elems
                        cute.arch.mbarrier_arrive_and_expect_tx(
                            dS_cluster_full_mbar_ptr,
                            stage_copy_bytes,
                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
                        )
                        copy_utils.cpasync_bulk_s2cluster(
                            smem_src_ptr,
                            smem_dst_ptr,
                            dS_cluster_full_mbar_ptr,
                            stage_copy_bytes,
                            peer_cta_rank_in_cluster=peer_cta_rank_in_cluster,
                        )

                # Final signal for dS smem store completion
                if const_expr(self.use_2cta_instrs and self.tile_hdim
== 128): if process_tile: with cute.arch.elect_one(): pipeline_dS.producer_commit(producer_state_dS) producer_state_dS.advance() # Epilogue # Run epilogue if we processed any m_blocks for this n_block if process_tile: if const_expr(not self.use_tma_store): consumer_state_dKV = self.epilogue_dKV( dp_idx, warp_idx, batch_idx, head_idx, n_block, seqlen, thr_mma_dV, thr_mma_dK, tdVtdV, tdKtdK, mdV, mdK, pipeline_dKV, consumer_state_dKV, softmax_scale, ) else: thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx) #### STORE dV consumer_state_dKV = self.epilogue_dK_or_dV_tma( dp_idx, batch_idx, head_idx, n_block, seqlen, thr_mma_dV, tdVtdV, mdV_tma_tensor, sdV, tma_atom_dV, thr_copy_r2s_dKV, pipeline_dKV, consumer_state_dKV, None, # Don't scale int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdV_semaphore, "V", ) #### STORE dK consumer_state_dKV = self.epilogue_dK_or_dV_tma( dp_idx, batch_idx, head_idx, n_block, seqlen, thr_mma_dK, tdKtdK, mdK_tma_tensor, sdK, tma_atom_dK, thr_copy_r2s_dKV, pipeline_dKV, consumer_state_dKV, softmax_scale if const_expr(not self.dKV_postprocess) else None, int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdK_semaphore, "K", ) # Zero dK/dV for empty tiles (local attention or block sparsity) # When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile if const_expr(not self.dKV_postprocess): should_zero_dKV = False if const_expr(self.is_local or self.is_varlen_q): should_zero_dKV = m_block_min >= m_block_max if const_expr(self.use_block_sparsity): # For block sparsity, zero when no m_blocks contribute to this n_block if not process_tile: should_zero_dKV = True if should_zero_dKV: # For 2-CTA: use cluster-wide tile size (cta_group_size * tile_n) cluster_tile_n = self.tile_n * self.cta_group_size n_block_for_tile = n_block // self.cta_group_size gmem_tiled_copy_zero_dK = copy_utils.tiled_copy_2d( self.dk_dtype, math.gcd(64, self.tile_hdim), 128, # num_threads ) gmem_tiled_copy_zero_dV = 
copy_utils.tiled_copy_2d( self.dv_dtype, math.gcd(64, self.tile_hdimv), 128, # num_threads ) gmem_thr_copy_zero_dK = gmem_tiled_copy_zero_dK.get_slice(dp_idx) gmem_thr_copy_zero_dV = gmem_tiled_copy_zero_dV.get_slice(dp_idx) mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx] mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx] gdK = cute.local_tile( mdK_cur, (cluster_tile_n, self.tile_hdim), (n_block_for_tile, 0) ) gdV = cute.local_tile( mdV_cur, (cluster_tile_n, self.tile_hdimv), (n_block_for_tile, 0) ) tdKgdK = gmem_thr_copy_zero_dK.partition_D(gdK) tdVgdV = gmem_thr_copy_zero_dV.partition_D(gdV) cdK = cute.make_identity_tensor((cluster_tile_n, self.tile_hdim)) cdV = cute.make_identity_tensor((cluster_tile_n, self.tile_hdimv)) tdKcdK = gmem_thr_copy_zero_dK.partition_D(cdK) tdVcdV = gmem_thr_copy_zero_dV.partition_D(cdV) assert cute.size(tdKgdK[None, 0, 0]) == cute.size(tdVgdV[None, 0, 0]) zero = cute.make_fragment_like(tdKgdK[None, 0, 0]) zero.fill(0.0) if tidx < 128: for i in cutlass.range_constexpr(tdKgdK.shape[1]): row_idx = tdKcdK[0, i, 0][0] if row_idx < seqlen.seqlen_k - cluster_tile_n * n_block_for_tile: for j in cutlass.range_constexpr(tdKgdK.shape[2]): cute.copy(gmem_tiled_copy_zero_dK, zero, tdKgdK[None, i, j]) else: for i in cutlass.range_constexpr(tdVgdV.shape[1]): row_idx = tdVcdV[0, i, 0][0] if row_idx < seqlen.seqlen_k - cluster_tile_n * n_block_for_tile: for j in cutlass.range_constexpr(tdVgdV.shape[2]): cute.copy(gmem_tiled_copy_zero_dV, zero, tdVgdV[None, i, j]) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @cute.jit def dQacc_reduce( self, mdQaccum: cute.Tensor, sdQaccum: cute.Tensor, thr_mma_dQ: cute.core.ThrMma, tdQtdQ: cute.Tensor, pipeline_dQ: PipelineAsync, dQaccum_empty_mbar_ptr: Optional[cute.Pointer], block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, mdQ_semaphore: Optional[cute.Tensor], blocksparse_tensors: 
Optional[BlockSparseTensors] = None, ): num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids) tidx = cute.arch.thread_idx()[0] % num_reduce_threads warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx() % len(self.reduce_warp_ids)) is_tma_warp = warp_idx == 0 cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) # TMEM -> RMEM tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol_t2r)), Float32 ) thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx) tdQtdQ_t2r = thr_copy_t2r.partition_S(tdQtdQ) tdQcdQ = thr_mma_dQ.partition_C(cute.make_identity_tensor(self.mma_tiler_dsk[:2])) tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape # For 2-CTA: reduce_stage = dQaccum_reduce_stage_t2r / cta_group_size expected_reduce_stages_t2r = self.dQaccum_reduce_stage_t2r // self.cta_group_size assert cute.size(tdQrdQ_t2r_shape, mode=[1]) == expected_reduce_stages_t2r, ( "dQaccum t2r reduce stage mismatch" ) expected_reduce_stages = self.dQaccum_reduce_stage // self.cta_group_size # 2-CTA: CTA 0 -> (M/2, D) (stage 0, 1) & CTA 1 -> (M/2, D) (stage 2, 3) stage_offset = ( expected_reduce_stages * cta_rank_in_cluster if const_expr(self.use_2cta_instrs) else 0 ) thr_copy_dQaccum_r2s = copy_utils.tiled_copy_1d( self.dqaccum_dtype, num_reduce_threads, num_copy_elems=128 // self.dqaccum_dtype.width ).get_slice(tidx) tdQsdQ = thr_copy_dQaccum_r2s.partition_D(sdQaccum) read_flag = const_expr(not self.deterministic) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() dQ_consumer_state = pipeline.make_pipeline_state( cutlass.pipeline.PipelineUserType.Consumer, 1 ) dQ_tma_store_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.sdQaccum_stage ) while work_tile.is_valid_tile: n_block, head_idx, batch_idx, _ = work_tile.tile_idx n_block_cta_group = n_block // self.cta_group_size # for 2cta seqlen 
= SeqlenInfoCls(batch_idx) m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block_cta_group) if const_expr(not seqlen.has_cu_seqlens_q): mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] else: mdQaccum_cur = cute.domain_offset( (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx] ) gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,)) # (M * K / STAGE, STAGE, _) gdQaccum = cute.flat_divide( gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,) ) if const_expr(self.deterministic): mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] # delay_semaphore_release = self.is_causal and not self.tile_hdim == 192 delay_semaphore_release = not self.tile_hdim == 192 # some tiles might be empty due to block sparsity if const_expr(self.use_block_sparsity): ( curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, loop_count, ) = get_block_sparse_iteration_info_bwd( blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor=self.subtile_factor, m_block_max=m_block_max, ) process_tile = loop_count > Int32(0) else: process_tile = ( const_expr(not self.is_local and not self.is_varlen_q) or m_block_min < m_block_max ) loop_count = m_block_max - m_block_min # dQacc_reduce mainloop # Block sparsity: iterate over sparse m_block count and derive actual m_block # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. 
for iter_idx in cutlass.range(loop_count, unroll=1): if const_expr(self.use_block_sparsity): m_block, _ = get_m_block_from_iter_bwd( iter_idx, curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, subtile_factor=self.subtile_factor, m_block_max=m_block_max, ) if m_block_max > 0: m_block = cutlass.min(m_block, m_block_max - 1) else: m_block = m_block_min + iter_idx pipeline_dQ.consumer_wait(dQ_consumer_state) # TMEM -> RMEM tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32) cute.copy(thr_copy_t2r, tdQtdQ_t2r, tdQrdQ_t2r) cute.arch.fence_view_async_tmem_load() cute.arch.sync_warp() with cute.arch.elect_one(): pipeline_dQ.consumer_release(dQ_consumer_state) dQ_consumer_state.advance() gdQaccum_cur = gdQaccum[None, None, m_block] tdQrdQ_shape = ( self.dQ_reduce_ncol, self.tile_hdim // self.cta_group_size // self.dQ_reduce_ncol, ) tdQrdQ = cute.make_tensor(tdQrdQ_t2r.iterator, tdQrdQ_shape) for stage in cutlass.range_constexpr(cute.size(tdQrdQ, mode=[1])): smem_idx = dQ_tma_store_producer_state.index tdQsdQ_r2s = tdQsdQ[None, None, smem_idx] tdQrdQ_r2s = cute.make_tensor(tdQrdQ[None, stage].iterator, tdQsdQ_r2s.shape) cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_view_async_shared() # semaphore acquire if const_expr(self.deterministic and stage == 0): if const_expr(self.spt): _, n_block_max_for_m_block = block_info.get_n_block_min_max( seqlen, m_block ) lock_value = n_block_max_for_m_block - 1 - n_block_cta_group else: lock_value = n_block_cta_group barrier.wait_eq( mdQ_semaphore_cur[(m_block, None)].iterator, tidx, cta_rank_in_cluster, lock_value, ) self.reduce_sync_barrier.arrive_and_wait() # Copy from shared memory to global memory if is_tma_warp: with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, smem_idx].iterator, gdQaccum_cur[None, stage + stage_offset].iterator, self.tma_copy_bytes["dQ"] // 1, ) 
cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() dQ_tma_store_producer_state.advance() # Directly add to gmem, much slower # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block]) # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ) # for i in cutlass.range(cute.size(tdQrdQ_r2s) // 4, unroll_full=True): # copy_utils.atomic_add_fp32x4( # tdQrdQ_r2s[4 * i], # tdQrdQ_r2s[4 * i + 1], # tdQrdQ_r2s[4 * i + 2], # tdQrdQ_r2s[4 * i + 3], # utils.elem_pointer(tdQgdQ, 4 * i), # ) # semaphore release for prior m_block if const_expr(self.deterministic and stage == 0 and delay_semaphore_release): if m_block > m_block_min: barrier.arrive_inc( mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, cta_rank_in_cluster, 1, ) if const_expr(self.tile_hdim == 192): if const_expr(self.sdQaccum_stage > 1): if is_tma_warp: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() with cute.arch.elect_one(): cute.arch.mbarrier_arrive(dQaccum_empty_mbar_ptr) # semaphore release # NOTE: arrive_inc calls red_release which issues membar if const_expr(self.deterministic and not delay_semaphore_release): if const_expr(self.sdQaccum_stage > 1 and not self.tile_hdim == 192): if is_tma_warp: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() barrier.arrive_inc( mdQ_semaphore_cur[m_block, None].iterator, tidx, cta_rank_in_cluster, 1 ) if process_tile: if is_tma_warp: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() # final semaphore release if const_expr(self.deterministic and delay_semaphore_release): barrier.arrive_inc( mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, cta_rank_in_cluster, 1, ) if const_expr( self.deterministic and not self.spt and block_info.window_size_left is not None ): m_block_global_max = 
cute.ceil_div(seqlen.seqlen_q, self.tile_m) for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): barrier.arrive_inc( mdQ_semaphore_cur[(m_block, None)].iterator, tidx, cta_rank_in_cluster, 1 ) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() if const_expr(not self.deterministic): cute.arch.cp_async_bulk_wait_group(0, read=True) @cute.jit def epilogue_dKV( self, tidx: Int32, warp_idx: Int32, batch_idx: Int32, head_idx: Int32, n_block: Int32, seqlen, thr_mma_dV: cute.core.ThrMma, thr_mma_dK: cute.core.ThrMma, tdVtdV: cute.Tensor, tdKtdK: cute.Tensor, mdV: cute.Tensor, mdK: cute.Tensor, pipeline_dKV: PipelineAsync, consumer_state_dKV: cutlass.pipeline.PipelineState, softmax_scale: Float32, ): wg_idx = ( cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids)) ) // 128 num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128 assert self.qhead_per_kvhead == 1, "This epilogue path is only for MHA" mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx] mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx] tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32 ) # dV pipeline_dKV.consumer_wait(consumer_state_dKV) tiled_tmem_ld_dV = tcgen05.make_tmem_copy(tmem_load_atom, tdVtdV) thr_tmem_ld_dV = tiled_tmem_ld_dV.get_slice(tidx) tdVtdV_t2r_p = thr_tmem_ld_dV.partition_S(tdVtdV) tdVtdV_t2r = self.split_wg(tdVtdV_t2r_p, wg_idx, num_wg) cdV = cute.make_identity_tensor((self.mma_tiler_pdo[0], self.mma_tiler_pdo[1])) tdVcdV = thr_mma_dV.partition_C(cdV) tdVcdV_tensor = cute.make_tensor(tdVcdV.iterator, tdVcdV.layout) tdVcdV_t2r_p = thr_tmem_ld_dV.partition_D(tdVcdV_tensor) tdVcdV_t2r = self.split_wg(tdVcdV_t2r_p, wg_idx, num_wg) tdVrdV_t2r = cute.make_fragment(tdVcdV_t2r.shape, Float32) cute.copy(thr_tmem_ld_dV, tdVtdV_t2r, tdVrdV_t2r) cute.arch.fence_view_async_tmem_load() universal_copy_bits = 128 
atom_universal_copy = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dv_dtype, num_bits_per_copy=universal_copy_bits, ) tiled_gmem_store_dV = cute.make_tiled_copy( atom_universal_copy, layout_tv=tiled_tmem_ld_dV.layout_dst_tv_tiled, tiler_mn=tiled_tmem_ld_dV.tiler_mn, ) tdVrdV_r2s = cute.make_fragment(tdVrdV_t2r.shape, self.dv_dtype) for i in cutlass.range_constexpr(cute.size(tdVrdV_t2r, mode=[1])): dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load() tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype)) gdV = cute.local_tile(mdV_cur, (self.mma_tiler_pdo[0], self.tile_hdimv), (None, 0)) gdV_tile = gdV[None, None, n_block // self.cta_group_size] tdVgdV = thr_mma_dV.partition_C(gdV_tile) tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV) tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg) if tidx < seqlen.seqlen_k - self.tile_n * n_block: cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g) cute.arch.sync_warp() with cute.arch.elect_one(): pipeline_dKV.consumer_release(consumer_state_dKV) consumer_state_dKV.advance() # dK pipeline_dKV.consumer_wait(consumer_state_dKV) tiled_tmem_ld_dK = tcgen05.make_tmem_copy(tmem_load_atom, tdKtdK) thr_tmem_ld_dK = tiled_tmem_ld_dK.get_slice(tidx) tdKtdK_t2r_p = thr_tmem_ld_dK.partition_S(tdKtdK) tdKtdK_t2r = self.split_wg(tdKtdK_t2r_p, wg_idx, num_wg) cdK = cute.make_identity_tensor((self.mma_tiler_dsq[0], self.mma_tiler_dsq[1])) tdKcdK = thr_mma_dK.partition_C(cdK) tdKcdK_tensor = cute.make_tensor(tdKcdK.iterator, tdKcdK.layout) tdKcdK_t2r_p = thr_tmem_ld_dK.partition_D(tdKcdK_tensor) tdKcdK_t2r = self.split_wg(tdKcdK_t2r_p, wg_idx, num_wg) tdKrdK_t2r = cute.make_fragment(tdKcdK_t2r.shape, Float32) cute.copy(tiled_tmem_ld_dK, tdKtdK_t2r, tdKrdK_t2r) cute.arch.fence_view_async_tmem_load() universal_copy_bits = 128 atom_universal_copy = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dk_dtype, num_bits_per_copy=universal_copy_bits, ) tiled_gmem_store_dK = cute.make_tiled_copy( atom_universal_copy, 
layout_tv=tiled_tmem_ld_dK.layout_dst_tv_tiled, tiler_mn=tiled_tmem_ld_dK.tiler_mn, ) tdKrdK_r2s = cute.make_fragment(tdKrdK_t2r.shape, self.dk_dtype) for i in cutlass.range_constexpr(cute.size(tdKrdK_t2r, mode=[1])): dK_vec = tdKrdK_t2r[(None, i, 0, 0)].load() * softmax_scale tdKrdK_r2s[(None, i, 0, 0)].store(dK_vec.to(self.dk_dtype)) gdK = cute.local_tile(mdK_cur, (self.mma_tiler_dsq[0], self.tile_hdim), (None, 0)) gdK_tile = gdK[None, None, n_block // self.cta_group_size] tdKgdK = thr_mma_dK.partition_C(gdK_tile) tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK) tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg) if tidx < seqlen.seqlen_k - self.tile_n * n_block: cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g) cute.arch.sync_warp() with cute.arch.elect_one(): pipeline_dKV.consumer_release(consumer_state_dKV) return consumer_state_dKV @cute.jit def epilogue_dK_or_dV_tma( self, tidx: Int32, batch_idx: Int32, head_idx: Int32, n_block: Int32, seqlen, thr_mma: cute.core.ThrMma, tdKVtdKV: cute.Tensor, mdKV: cute.Tensor, sdKV: cute.Tensor, tma_atom_dKV: cute.CopyAtom, thr_copy_r2s_dKV: cute.TiledCopy, pipeline_dKV: PipelineAsync, consumer_state_dKV: cutlass.pipeline.PipelineState, scale: Optional[Float32], barrier_id: Int32, mdKV_semaphore: Optional[cute.Tensor], K_or_V: cutlass.Constexpr[str], ) -> cutlass.pipeline.PipelineState: assert K_or_V in ("K", "V") tile_hdim = self.tile_hdim if const_expr(K_or_V == "K") else self.tile_hdimv dtype = self.dk_dtype if const_expr(K_or_V == "K") else self.dv_dtype epi_tile = self.sdK_epi_tile if const_expr(K_or_V == "K") else self.sdV_epi_tile flat_epi_tile = ( self.sdK_flat_epi_tile if const_expr(K_or_V == "K") else self.sdV_flat_epi_tile ) num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids) wg_idx = (cute.arch.thread_idx()[0] % num_compute_threads) // 128 num_wg = num_compute_threads // 128 leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0 cta_group_tile_n = 
const_expr(self.tile_n * self.cta_group_size) if const_expr(not self.dKV_postprocess): sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16 else: sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32 # (8, tile_n / 128, 64 / 8) = (8, 1, 8) or (4, tile_n * 32 / (128 * 4)) = (4, 8) tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV) head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(not self.dKV_postprocess): assert not seqlen.has_cu_seqlens_k, "varlen uses non tma store path" mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) gdKV_p = cute.local_tile( mdKV_cur, (self.tile_n, tile_hdim), (n_block, 0) ) # (tile_n, hdim) - per CTA gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) # (tile_n, hdim / 2) gdKV_epi = cute.local_tile( gdKV, epi_tile, (0, None) ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) else: # n_block_group = n_block // self.cta_group_size if const_expr(not seqlen.has_cu_seqlens_k): mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) else: mdKV_cur = cute.domain_offset( (seqlen.padded_offset_k * tile_hdim,), mdKV[None, head_idx_kv] ) gdKV_p = cute.local_tile( mdKV_cur, (self.tile_n * tile_hdim,), (n_block,) ) # (tile_n * hdim) gdKV = cute.logical_divide(gdKV_p, (self.tile_n * tile_hdim // num_wg,))[ ((None, wg_idx),) ] # (tile_n * hdim / 2) gdKV_epi = cute.flat_divide( gdKV, (flat_epi_tile,) ) # (tile_n * hdim / 2 / epi_stage, epi_stage) deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1 if const_expr(deterministic_KV): mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] if const_expr(not self.dKV_postprocess): tdKVsdKV, tdKVgdKV = cpasync.tma_partition( tma_atom_dKV, 0, # no multicast cute.make_layout(1), cute.group_modes(sdKV, 0, 2), cute.group_modes(gdKV_epi, 0, 2), ) # (TMA) and (TMA, EPI_STAGE) assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV" assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV" num_epi_stages = 
cute.size(tdKVgdKV.shape[1]) if const_expr(K_or_V == "K"): assert num_epi_stages == self.num_epi_stages, "Epi stage calculation is wrong (K)" else: assert num_epi_stages == self.num_epi_stages_v, "Epi stage calculation is wrong (V)" else: num_epi_stages = ( self.num_epi_stages if const_expr(K_or_V == "K") else self.num_epi_stages_v ) tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dK_reduce_ncol)), Float32 ) read_flag = const_expr(not deterministic_KV) pipeline_dKV.consumer_wait(consumer_state_dKV) # semaphore acquire if const_expr(deterministic_KV): barrier.wait_eq( mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead ) cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) for epi_stage in cutlass.range_constexpr(num_epi_stages): # TMEM -> RMEM -- setup thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx) tdKVtdKV_t2r_p = thr_copy_t2r.partition_S(tdKVtdKV) tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): tdKVtdKV_t2r = tdKVtdKV_t2r[None, epi_stage] cdKV = cute.make_identity_tensor((cta_group_tile_n, tile_hdim)) tdKVcdKV = thr_mma.partition_C(cdKV) tdKVcdKV_t2r_p = thr_copy_t2r.partition_D(tdKVcdKV) tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0] if const_expr(num_epi_stages > 1): tdKVcdKV_t2r = tdKVcdKV_t2r[None, epi_stage] tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32) assert cute.size(tdKVrdKV_t2r) == cute.size(tdKVtdKV_t2r) // cute.arch.WARP_SIZE, ( "RMEM<->TMEM fragment size mismatch" ) # TMEM -> RMEM -- copy and fence cute.copy(thr_copy_t2r, tdKVtdKV_t2r, tdKVrdKV_t2r) cute.arch.fence_view_async_tmem_load() # RMEM -- scale and convert if const_expr(scale is not None): for i in cutlass.range(cute.size(tdKVrdKV_t2r.shape) // 2, unroll_full=True): tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = cute.arch.mul_packed_f32x2( (tdKVrdKV_t2r[2 * i], 
tdKVrdKV_t2r[2 * i + 1]), (scale, scale) ) tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, dtype) # (32 columns) tdKVrdKV.store(tdKVrdKV_t2r.load().to(dtype)) # RMEM -> SMEM -- copy, fence and barrier tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape) cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s) cute.arch.fence_view_async_shared() cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) # SMEM -> GMEM if leader_warp: if const_expr(not self.dKV_postprocess): cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage]) else: with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdKV.iterator, gdKV_epi[None, epi_stage].iterator, self.tma_copy_bytes["dKacc"], ) if const_expr(epi_stage < num_epi_stages - 1): cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=read_flag) cute.arch.barrier_arrive( barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE ) # Barrier since all warps need to wait for SMEM to be freed cute.arch.fence_view_async_shared() cute.arch.barrier( barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE ) # semaphore release # NOTE: arrive_inc calls red_release which issues membar if const_expr(deterministic_KV): if leader_warp: cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=read_flag) cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128) barrier.arrive_inc(mdKV_semaphore_cur.iterator, tidx, wg_idx, 1) cute.arch.sync_warp() with cute.arch.elect_one(): pipeline_dKV.consumer_release(consumer_state_dKV) consumer_state_dKV.advance() return consumer_state_dKV ================================================ FILE: flash_attn/cute/flash_bwd_sm120.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # SM120 (Blackwell GeForce / DGX Spark) backward pass. 
#
# SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has
# a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses
# FlashAttentionBackwardSm80 and overrides the SMEM capacity check accordingly.

import cutlass
import cutlass.utils as utils_basic

from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80


class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80):
    @staticmethod
    def can_implement(
        dtype,
        head_dim,
        head_dim_v,
        m_block_size,
        n_block_size,
        num_stages_Q,
        num_stages_dO,
        num_threads,
        is_causal,
        V_in_regs=False,
    ) -> bool:
        """Check if the kernel can be implemented on SM120.

        Same logic as SM80 but uses SM120's shared memory capacity (99 KB).
        """
        if dtype not in [cutlass.Float16, cutlass.BFloat16]:
            return False
        if head_dim % 8 != 0:
            return False
        if head_dim_v % 8 != 0:
            return False
        if n_block_size % 16 != 0:
            return False
        if num_threads % 32 != 0:
            return False
        # Shared memory usage: Q tile + dO tile + K tile + V tile
        smem_usage_Q = m_block_size * head_dim * num_stages_Q * 2
        smem_usage_dO = m_block_size * head_dim_v * num_stages_dO * 2
        smem_usage_K = n_block_size * head_dim * 2
        smem_usage_V = n_block_size * head_dim_v * 2
        smem_usage_QV = (
            (smem_usage_Q + smem_usage_V) if not V_in_regs else max(smem_usage_Q, smem_usage_V)
        )
        smem_usage = smem_usage_QV + smem_usage_dO + smem_usage_K
        # SM120 has 99 KB shared memory (vs 163 KB on SM80)
        smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120")
        if smem_usage > smem_capacity:
            return False
        return True


================================================
FILE: flash_attn/cute/flash_bwd_sm90.py
================================================
import math
from typing import Callable, Optional, Type
from functools import partial
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.utils.hopper_helpers as sm90_utils_basic
from cutlass.cute.nvgpu import cpasync, warpgroup
from cutlass.cute import FastDivmodDivisor
from
cutlass import Float32, Int32, Boolean, const_expr from cutlass.utils import LayoutEnum from quack import copy_utils from quack import layout_utils from quack import sm90_utils from quack.sm90_utils import gemm_zero_init, gemm_w_idx from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute import pipeline from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, SingleTileVarlenScheduler, ) from flash_attn.cute import barrier from flash_attn.cute.named_barrier import NamedBarrierBwd from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, produce_block_sparse_q_loads_bwd_sm90, consume_block_sparse_mma_bwd_sm90, dQaccum_store_block_sparse_bwd_sm90, ) class FlashAttentionBackwardSm90: arch = 90 def __init__( self, dtype: Type[cutlass.Numeric], head_dim: int, head_dim_v: Optional[int] = None, qhead_per_kvhead: int = 1, is_causal: bool = False, is_local: bool = False, deterministic: bool = False, tile_m: int = 64, tile_n: int = 128, Q_stage: int = 2, dO_stage: int = 2, PdS_stage: int = 2, SdP_swapAB: bool = False, dKV_swapAB: bool = False, dQ_swapAB: bool = False, AtomLayoutMSdP: int = 1, AtomLayoutNdKV: int = 2, AtomLayoutMdQ: int = 1, num_threads: int = 384, V_in_regs: bool = False, score_mod: cutlass.Constexpr | None = None, score_mod_bwd: cutlass.Constexpr | None = None, mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, subtile_factor: cutlass.Constexpr[int] = 1, dQ_single_wg: bool = False, ): self.dtype = dtype # padding head_dim to a multiple of 16 as 
k_block_size hdim_multiple_of = 16 self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) # Can save registers (and hence be faster) if we don't have to check hdim predication self.check_hdim_oob = head_dim != self.tile_hdim self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.qhead_per_kvhead = qhead_per_kvhead self.is_causal = is_causal self.is_local = is_local self.deterministic = deterministic self.tile_m = tile_m self.tile_n = tile_n self.num_threads = num_threads self.Q_stage = Q_stage self.dO_stage = dO_stage self.PdS_stage = PdS_stage assert self.dO_stage in [1, self.Q_stage] assert self.PdS_stage in [1, self.Q_stage] self.SdP_swapAB = SdP_swapAB self.dKV_swapAB = dKV_swapAB self.dQ_swapAB = dQ_swapAB self.AtomLayoutMSdP = AtomLayoutMSdP self.AtomLayoutNdKV = AtomLayoutNdKV self.AtomLayoutMdQ = AtomLayoutMdQ self.num_wg_mma = (self.num_threads // 128) - 1 self.mma_dkv_is_rs = ( AtomLayoutMSdP == 1 and AtomLayoutNdKV == self.num_wg_mma and SdP_swapAB and not dKV_swapAB ) self.V_in_regs = V_in_regs # May be overridden in __call__ for varlen inputs. if qhead_per_kvhead > 1: assert self.same_hdim_kv, "GQA backward requires head_dim == head_dim_v" assert self.num_wg_mma == 2, "GQA backward assumes 2 warp groups" # These are tuned for speed # Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share # them and then shuffle to get the value whenever we need? This can reduce register # pressure when SdP_swapAB, where each thread needs to keep statistics for (kBlockM / 4) # rows. If !SdP_swapAB, each thread only needs to keep statistics for 2 rows. 
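# Worked example (illustrative tile size, assuming the statistics are split evenly):
# with SdP_swapAB and tile_m = 64, each thread would otherwise hold 64 / 4 = 16 LSE
# values and 16 dPsum values. Splitting each set across the 8 threads that share those
# rows leaves 16 / 8 = 2 of each per thread, at the cost of a warp shuffle whenever a
# value is needed.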
        self.shuffle_LSE = self.SdP_swapAB and self.tile_hdim <= 64
        self.shuffle_dPsum = self.SdP_swapAB and self.tile_hdim <= 64

        self.buffer_align_bytes = 1024

        self.score_mod = score_mod
        self.score_mod_bwd = score_mod_bwd
        self.mask_mod = mask_mod
        self.has_aux_tensors = has_aux_tensors
        self.subtile_factor = subtile_factor
        if cutlass.const_expr(has_aux_tensors):
            self.vec_size: cutlass.Constexpr = 1
        else:
            self.vec_size: cutlass.Constexpr = 4
        self.qk_acc_dtype = Float32
        # dQ_single_wg: WG0 computes the full dQ GEMM, WG1 skips it.
        # Only valid for 2 MMA warp groups.
        # Credit: Ben Spector
        if dQ_single_wg:
            assert self.num_wg_mma == 2, "dQ_single_wg only supports 2 warp groups"
        self.num_wg_dQ = 1 if dQ_single_wg else self.num_wg_mma

    @staticmethod
    def can_implement(
        dtype,
        head_dim,
        head_dim_v,
        tile_m,
        tile_n,
        Q_stage,
        num_threads,
        V_in_regs=False,
    ) -> bool:
        if dtype not in [cutlass.Float16, cutlass.BFloat16]:
            return False
        if head_dim % 8 != 0:
            return False
        if head_dim_v % 8 != 0:
            return False
        if tile_n % 16 != 0:
            return False
        if num_threads % 32 != 0:
            return False
        if (tile_m * 2) % num_threads != 0:
            return False
        return True

    def _check_type(
        self,
        mQ_type: Type[cutlass.Numeric],
        mK_type: Type[cutlass.Numeric],
        mV_type: Type[cutlass.Numeric],
        mdO_type: Type[cutlass.Numeric],
        mLSE_type: Type[cutlass.Numeric],
        mdPsum_type: Type[cutlass.Numeric],
        mdQaccum_type: Type[cutlass.Numeric],
        mdK_type: Type[cutlass.Numeric],
        mdV_type: Type[cutlass.Numeric],
    ):
        # Check that Q, K, V and dO share one data type, and that it is fp16 or bf16
        if const_expr(not (mQ_type == mK_type == mV_type == mdO_type)):
            raise TypeError("All tensors must have the same data type")
        if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):
            raise TypeError("Only Float16 or BFloat16 is supported")
        if const_expr(mLSE_type not in [Float32]):
            raise TypeError("LSE tensor must be Float32")
        if const_expr(mdPsum_type not in [Float32]):
            raise TypeError("dPsum tensor must be Float32")
        if const_expr(mdQaccum_type not in [Float32]):
            raise TypeError("dQaccum tensor must be Float32")
        if const_expr(self.qhead_per_kvhead == 1):
            if const_expr(not (mdK_type == mdV_type == mQ_type)):
                raise TypeError("mdK and mdV tensors must have the same data type as mQ")
        else:
            if const_expr(not (mdK_type == mdV_type == Float32)):
                raise TypeError("mdKaccum and mdVaccum tensors must have the data type Float32")
        assert mQ_type == self.dtype

    def _setup_attributes(self):
        # We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory.
        # Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma.
        # The M dimension (tile_m) doesn't matter for the layout, only the K dimension does.
        wg_d_dKV = self.num_wg_mma // self.AtomLayoutNdKV
        self.sQ_layout, self.sdO_layout = [
            # Need to set major_mode_size (mms) to accommodate Q and Q.T
            sm90_utils.make_smem_layout(self.dtype, LayoutEnum.ROW_MAJOR, shape, stage, mms)
            for shape, stage, mms in [
                ((self.tile_m, self.tile_hdim), self.Q_stage, self.tile_hdim // wg_d_dKV),
                ((self.tile_m, self.tile_hdimv), self.dO_stage, self.tile_hdim // wg_d_dKV),
            ]
        ]
        wg_d_dQ = self.num_wg_dQ // self.AtomLayoutMdQ
        # Accommodate both K and K.T
        self.sK_layout = sm90_utils.make_smem_layout(
            self.dtype,
            LayoutEnum.ROW_MAJOR,
            (self.tile_n, self.tile_hdim),
            stage=None,
            major_mode_size=self.tile_hdim // wg_d_dQ,
        )
        # There's only V, no V.T, so the layout needs no special major_mode_size
        self.sV_layout = sm90_utils.make_smem_layout(
            self.dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_hdimv), None
        )
        # Accommodate both P/dS and their transposes (P.T, dS.T)
        wg_n_SdP = self.num_wg_mma // self.AtomLayoutMSdP
        wg_n_dKV = self.AtomLayoutNdKV
        self.sPdS_layout = sm90_utils.make_smem_layout(
            self.dtype,
            LayoutEnum.ROW_MAJOR,
            (self.tile_m, self.tile_n),
            stage=self.PdS_stage,
            major_mode_size=math.gcd(self.tile_n // wg_n_SdP, self.tile_n // wg_n_dKV),
        )
        self.sdQaccum_layout = cute.make_layout(
            (self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ)
        )
        # dQaccum R->S
        self.r2s_tiled_copy_dQaccum = cute.make_tiled_copy_tv(
            cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128),
            # thr_layout
            cute.make_layout((self.num_threads_per_warp_group, self.num_wg_dQ)),
            cute.make_layout(128 // Float32.width),  # val_layout
        )
        # dKVaccum for GQA epilogue - reuses sV+sK memory recast as f32
        # TODO: assert that sVaccum and sKaccum don't overflow smem

    def _get_tiled_mma(self):
        maybe_swap_mn = lambda shape, swap: (shape[1], shape[0], *shape[2:]) if swap else shape
        # S = Q @ K.T, dP = dO @ V.T
        atom_layout_SdP = (self.AtomLayoutMSdP, self.num_wg_mma // self.AtomLayoutMSdP, 1)
        tiler_mn_SdP = (self.tile_m // atom_layout_SdP[0], self.tile_n // atom_layout_SdP[1])
        tiled_mma_SdP = sm90_utils_basic.make_trivial_tiled_mma(
            self.dtype,
            self.dtype,
            warpgroup.OperandMajorMode.K,
            warpgroup.OperandMajorMode.K,
            Float32,
            atom_layout_mnk=maybe_swap_mn(atom_layout_SdP, self.SdP_swapAB),
            tiler_mn=(64, tiler_mn_SdP[1] if not self.SdP_swapAB else tiler_mn_SdP[0]),
        )
        # dV = P.T @ dO, dK = dS.T @ Q
        atom_layout_dKV = (self.AtomLayoutNdKV, self.num_wg_mma // self.AtomLayoutNdKV, 1)
        tiler_mn_dK = (self.tile_n // atom_layout_dKV[0], self.tile_hdim // atom_layout_dKV[1])
        tiler_mn_dV = (self.tile_n // atom_layout_dKV[0], self.tile_hdimv // atom_layout_dKV[1])
        tiled_mma_dK, tiled_mma_dV = [
            sm90_utils_basic.make_trivial_tiled_mma(
                self.dtype,
                self.dtype,
                warpgroup.OperandMajorMode.MN
                if not self.mma_dkv_is_rs
                else warpgroup.OperandMajorMode.K,
                warpgroup.OperandMajorMode.MN,
                Float32,
                atom_layout_mnk=maybe_swap_mn(atom_layout_dKV, self.dKV_swapAB),
                tiler_mn=(64, tiler_mn_d[1] if not self.dKV_swapAB else tiler_mn_d[0]),
                a_source=warpgroup.OperandSource.RMEM
                if self.mma_dkv_is_rs
                else warpgroup.OperandSource.SMEM,
            )
            for tiler_mn_d in (tiler_mn_dK, tiler_mn_dV)
        ]
        # dQ = dS @ K
        assert self.num_wg_dQ % self.AtomLayoutMdQ == 0
        atom_layout_dQ = (self.AtomLayoutMdQ, self.num_wg_dQ // self.AtomLayoutMdQ, 1)
        tiler_mn_dQ = (self.tile_m // atom_layout_dQ[0], self.tile_hdim // atom_layout_dQ[1])
        tiled_mma_dQ = sm90_utils_basic.make_trivial_tiled_mma(
            self.dtype,
            self.dtype,
            warpgroup.OperandMajorMode.K if not self.dQ_swapAB else warpgroup.OperandMajorMode.MN,
            warpgroup.OperandMajorMode.MN if not self.dQ_swapAB else warpgroup.OperandMajorMode.K,
            Float32,
            atom_layout_mnk=maybe_swap_mn(atom_layout_dQ, self.dQ_swapAB),
            tiler_mn=(64, tiler_mn_dQ[1] if not self.dQ_swapAB else tiler_mn_dQ[0]),
        )
        return tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ

    def _get_shared_storage_cls(self):
        sQ_struct, sK_struct, sV_struct, sdO_struct, sdQaccum_struct = [
            cute.struct.Align[cute.struct.MemRange[t, cute.cosize(layout)], self.buffer_align_bytes]
            for (layout, t) in [
                (self.sQ_layout, self.dtype),
                (self.sK_layout, self.dtype),
                (self.sV_layout, self.dtype),
                (self.sdO_layout, self.dtype),
                (self.sdQaccum_layout, Float32),
            ]
        ]

        cosize_sdS = cute.cosize(self.sPdS_layout)
        cosize_sP = cute.cosize(self.sPdS_layout) if const_expr(not self.mma_dkv_is_rs) else 0
        sLSE_struct = cute.struct.Align[
            cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.Q_stage], 128
        ]
        sdPsum_struct = cute.struct.Align[
            cute.struct.MemRange[Float32, cute.round_up(self.tile_m, 64) * self.dO_stage], 128
        ]

        @cute.struct
        class SharedStorageQKV:
            mbar_ptr_Q: cute.struct.MemRange[cutlass.Int64, self.Q_stage * 2]
            mbar_ptr_dO: cute.struct.MemRange[cutlass.Int64, self.dO_stage * 2]
            sLSE: sLSE_struct
            sdPsum: sdPsum_struct
            sQ: sQ_struct
            sV: sV_struct
            sK: sK_struct
            sdO: sdO_struct
            sP: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
            sdS: cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sdS], 1024]
            sdQaccum: sdQaccum_struct

        return SharedStorageQKV

    @cute.jit
    def __call__(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mdO: cute.Tensor,
        mLSE: cute.Tensor,
        mdPsum: cute.Tensor,
        mdQaccum: cute.Tensor,
        mdK: cute.Tensor,
        mdV: cute.Tensor,
        softmax_scale: Float32,
        mCuSeqlensQ: Optional[cute.Tensor] = None,
        mCuSeqlensK: Optional[cute.Tensor] = None,
        mSeqUsedQ: Optional[cute.Tensor] = None,
        mSeqUsedK: Optional[cute.Tensor] = None,
        softcap: Float32 | float | None = None,
        window_size_left: Int32 | int | None = None,
        window_size_right: Int32 | int | None = None,
        mdQ_semaphore: Optional[cute.Tensor] = None,
        mdK_semaphore: Optional[cute.Tensor] = None,
        mdV_semaphore: Optional[cute.Tensor] = None,
        aux_tensors: Optional[list] = None,
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
        stream: cuda.CUstream = None,
    ):
        # For GQA (qhead_per_kvhead > 1), multiple Q heads accumulate into the same dK/dV,
        # so we need the float32 accum path + postprocess.
        # For varlen_k with qhead_per_kvhead == 1, we use ragged TMA tensors.
        self.varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None
        self._check_type(
            *(
                t.element_type if t is not None else None
                for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
            )
        )

        self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None

        mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV = [
            assume_tensor_aligned(t) for t in (mQ, mK, mV, mdO, mLSE, mdPsum, mdQaccum, mdK, mdV)
        ]

        # Non-varlen inputs are (b, s, n, h), varlen inputs are (s, n, h).
        # We convert both to a seqlen-major view with head-dim second.
        # Tensors may have different ranks when Q is padded (seqused_q) but K/V are
        # unpadded (cu_seqlens_k), so the permutation is chosen per tensor.
        def _qkv_transpose(t):
            return layout_utils.select(t, [1, 3, 2, 0] if cute.rank(t.shape) == 4 else [0, 2, 1])

        mQ, mK, mV, mdO = [_qkv_transpose(t) for t in (mQ, mK, mV, mdO)]
        if const_expr(self.qhead_per_kvhead == 1):
            mdK, mdV = [_qkv_transpose(t) for t in (mdK, mdV)]
        else:
            # Accum tensors are (b, n, s*h) for non-varlen and (n, s*h) for varlen.
            accum_transpose = [2, 1, 0] if cute.rank(mdK.shape) == 3 else [1, 0]
            mdK, mdV = [layout_utils.select(t, accum_transpose) for t in (mdK, mdV)]
        # Non-varlen stats are (b, n, s), varlen stats are (n, s).
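        # Worked example of the mode permutations, assuming layout_utils.select returns the
        # listed modes in the given order (as cute.select does):
        #   Q/K/V/dO:           (b, s, n, h) --[1, 3, 2, 0]--> (s, h, n, b)
        #                       (s, n, h)    --[0, 2, 1]-----> (s, h, n)
        #   LSE/dPsum/dQaccum:  (b, n, s)    --[2, 1, 0]-----> (s, n, b)
        #                       (n, s)       --[1, 0]--------> (s, n)
        # Hence mQ.shape[0] is the sequence length, mQ.shape[1] the head dim, mQ.shape[2] the
        # number of heads and mK.shape[3] the batch size in the scheduler arguments.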
        LSE_dPsum_dQaccum_transpose = [2, 1, 0] if cute.rank(mLSE.shape) == 3 else [1, 0]
        mLSE, mdPsum, mdQaccum = [
            layout_utils.select(t, LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
        ]

        tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ = self._get_tiled_mma()

        # (batch, num_head, num_m_blocks, cluster_size) -> (num_m_blocks, cluster_size, num_head, batch)
        if const_expr(self.deterministic):
            assert mdQ_semaphore is not None
            mdQ_semaphore = layout_utils.select(mdQ_semaphore, mode=[2, 3, 1, 0])

        self.num_mma_threads = tiled_mma_SdP.size
        assert self.num_mma_threads + 128 == self.num_threads

        self.num_threads_per_warp_group = 128
        self.num_producer_threads = 32

        REG_LIMIT = 504 if self.num_wg_mma == 2 else 512
        if const_expr(self.num_wg_mma == 2):
            if const_expr(self.num_wg_dQ == 1):
                self.num_mma_regs_wg0 = 256
                self.num_mma_regs_wg1 = 224
            else:
                self.num_mma_regs_wg0 = 240
                self.num_mma_regs_wg1 = 240
            self.num_mma_regs = self.num_mma_regs_wg0  # for backward compat
            self.num_producer_regs = 24
            assert (
                self.num_mma_regs_wg0 + self.num_mma_regs_wg1 + self.num_producer_regs <= REG_LIMIT
            )
        else:  # 3 warp groups
            self.num_mma_regs_wg0 = 160
            self.num_mma_regs_wg1 = 160
            self.num_mma_regs = 160
            self.num_producer_regs = 32
            assert self.num_mma_regs_wg0 * self.num_wg_mma + self.num_producer_regs <= REG_LIMIT

        self._setup_attributes()
        SharedStorage = self._get_shared_storage_cls()

        self.tma_copy_bytes = {
            name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
            for name, mX, layout in [
                ("Q", mQ, self.sQ_layout),
                ("K", mK, self.sK_layout),
                ("V", mV, self.sV_layout),
                ("dO", mdO, self.sdO_layout),
            ]
        }
        self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
        self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
        self.tma_copy_bytes["dQ"] = (
            self.tile_m * self.tile_hdim * Float32.width // 8 // self.num_wg_dQ
        )
        self.tma_copy_bytes["dKacc"] = self.tile_n * self.tile_hdim * Float32.width // 8
        self.tma_copy_bytes["dVacc"] = self.tile_n * self.tile_hdimv * Float32.width // 8

        tma_atom_Q, tma_tensor_Q = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mQ,
            cute.select(self.sQ_layout, mode=[0, 1]),
            (self.tile_m, self.tile_hdim),
        )
        tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mK,
            cute.select(self.sK_layout, mode=[0, 1]),
            (self.tile_n, self.tile_hdim),
        )
        tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mV,
            cute.select(self.sV_layout, mode=[0, 1]),
            (self.tile_n, self.tile_hdimv),
        )
        tma_atom_dO, tma_tensor_dO = cpasync.make_tiled_tma_atom(
            cpasync.CopyBulkTensorTileG2SOp(),
            mdO,
            cute.select(self.sdO_layout, mode=[0, 1]),
            (self.tile_m, self.tile_hdimv),
        )
        if const_expr(self.qhead_per_kvhead == 1):
            mdK_tma = (
                copy_utils.create_ragged_tensor_for_tma(mdK, ragged_dim=0, ptr_shift=True)
                if self.varlen_k
                else mdK
            )
            mdV_tma = (
                copy_utils.create_ragged_tensor_for_tma(mdV, ragged_dim=0, ptr_shift=True)
                if self.varlen_k
                else mdV
            )
            tma_atom_dK, tma_tensor_dK = cpasync.make_tiled_tma_atom(
                cpasync.CopyBulkTensorTileS2GOp(),
                mdK_tma,
                cute.select(self.sK_layout, mode=[0, 1]),
                (self.tile_n, self.tile_hdim),
            )
            tma_atom_dV, tma_tensor_dV = cpasync.make_tiled_tma_atom(
                cpasync.CopyBulkTensorTileS2GOp(),
                mdV_tma,
                cute.select(self.sV_layout, mode=[0, 1]),
                (self.tile_n, self.tile_hdimv),
            )
        else:
            tma_atom_dK = tma_atom_dV = tma_tensor_dK = tma_tensor_dV = None

        if const_expr(mCuSeqlensK is not None or mSeqUsedK is not None):
            TileScheduler = SingleTileVarlenScheduler
        elif const_expr(self.deterministic):
            TileScheduler = SingleTileLPTBwdScheduler
        else:
            TileScheduler = SingleTileScheduler
        self.spt = (self.is_causal or self.is_local) and self.deterministic
        tile_sched_args = TileSchedulerArguments(
            cute.ceil_div(cute.size(mK.shape[0]), self.tile_n),
            cute.size(mQ.shape[2]),
            cute.size(mK.shape[3])
            if const_expr(mCuSeqlensK is None)
            else cute.size(mCuSeqlensK.shape[0] - 1),  # num_batch
            1,  # num_splits
            cute.size(mQ.shape[0]),  # pass seqlen_q or total_q for seqlen_k
            mQ.shape[1],  # headdim
            mV.shape[1],  # headdim_v
            total_q=cute.size(mK.shape[0])
            if const_expr(mCuSeqlensK is not None)
            else cute.size(mK.shape[0]) * cute.size(mK.shape[3]),
            tile_shape_mn=(self.tile_n, self.tile_m),  # Swapping the role of Q & K
            mCuSeqlensQ=mCuSeqlensK,
            mSeqUsedQ=mSeqUsedK,
            qhead_per_kvhead_packgqa=1,
            element_size=self.dtype.width // 8,
            is_persistent=False,
            lpt=self.spt,
            head_swizzle=self.deterministic,
        )

        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

        LOG2_E = math.log2(math.e)
        if const_expr(self.score_mod is None):
            softmax_scale_log2 = softmax_scale * LOG2_E
        else:
            softmax_scale_log2 = LOG2_E

        fastdiv_mods = None
        if const_expr(aux_tensors is not None):
            seqlen_q = cute.size(mQ.shape[0])
            seqlen_k = cute.size(mK.shape[0])
            seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
            seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
            fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)

        qhead_per_kvhead_divmod = None
        if const_expr(self.qhead_per_kvhead > 1):
            qhead_per_kvhead_divmod = FastDivmodDivisor(self.qhead_per_kvhead)

        self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)

        if const_expr(window_size_left is not None):
            window_size_left = Int32(window_size_left)
        if const_expr(window_size_right is not None):
            window_size_right = Int32(window_size_right)

        self.kernel(
            tma_tensor_Q,
            tma_tensor_K,
            tma_tensor_V,
            tma_tensor_dO,
            tma_tensor_dK if const_expr(self.qhead_per_kvhead == 1) else mdK,
            tma_tensor_dV if const_expr(self.qhead_per_kvhead == 1) else mdV,
            tma_atom_Q,
            tma_atom_K,
            tma_atom_V,
            tma_atom_dO,
            tma_atom_dK,
            tma_atom_dV,
            mLSE,
            mdPsum,
            mdQaccum,
            mCuSeqlensQ,
            mCuSeqlensK,
            mSeqUsedQ,
            mSeqUsedK,
            self.sQ_layout,
            self.sK_layout,
            self.sV_layout,
            self.sPdS_layout,
            self.sdO_layout,
            self.sdQaccum_layout,
            self.r2s_tiled_copy_dQaccum,
            tiled_mma_SdP,
            tiled_mma_dK,
            tiled_mma_dV,
            tiled_mma_dQ,
            softmax_scale_log2,
            softmax_scale,
            tile_sched_params,
            TileScheduler,
            SharedStorage,
            aux_tensors,
            fastdiv_mods,
            blocksparse_tensors,
            qhead_per_kvhead_divmod,
            mdQ_semaphore,
            window_size_left,
            window_size_right,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            stream=stream,
            min_blocks_per_mp=1,
            use_pdl=True,
        )

    @cute.kernel
    def kernel(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mdO: cute.Tensor,
        mdK: cute.Tensor,
        mdV: cute.Tensor,
        tma_atom_Q: cute.CopyAtom,
        tma_atom_K: cute.CopyAtom,
        tma_atom_V: cute.CopyAtom,
        tma_atom_dO: cute.CopyAtom,
        tma_atom_dK: cute.CopyAtom,
        tma_atom_dV: cute.CopyAtom,
        mLSE: cute.Tensor,
        mdPsum: cute.Tensor,
        mdQaccum: cute.Tensor,
        mCuSeqlensQ: Optional[cute.Tensor],
        mCuSeqlensK: Optional[cute.Tensor],
        mSeqUsedQ: Optional[cute.Tensor],
        mSeqUsedK: Optional[cute.Tensor],
        sQ_layout: cute.ComposedLayout,
        sK_layout: cute.ComposedLayout,
        sV_layout: cute.ComposedLayout,
        sPdS_layout: cute.ComposedLayout,
        sdO_layout: cute.ComposedLayout,
        sdQaccum_layout: cute.Layout,
        r2s_tiled_copy_dQaccum: cute.TiledCopy,
        tiled_mma_SdP: cute.TiledMma,
        tiled_mma_dK: cute.TiledMma,
        tiled_mma_dV: cute.TiledMma,
        tiled_mma_dQ: cute.TiledMma,
        softmax_scale_log2,
        softmax_scale,
        tile_sched_params: ParamsBase,
        TileScheduler: cutlass.Constexpr[Callable],
        SharedStorage: cutlass.Constexpr[Callable],
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
        mdQ_semaphore: Optional[cute.Tensor] = None,
        window_size_left: Optional[Int32] = None,
        window_size_right: Optional[Int32] = None,
    ):
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

        # prefetch TMA descriptors
        if warp_idx == 0:
            for atom in [tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_dO, tma_atom_dK, tma_atom_dV]:
                if const_expr(atom is not None):
                    cpasync.prefetch_descriptor(atom)

        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)

        pipeline_producer_group = cutlass.pipeline.CooperativeGroup(cutlass.pipeline.Agent.Thread)
        pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, self.num_mma_threads // cute.arch.WARP_SIZE
        )
        pipeline_Q = pipeline.PipelineTmaAsync.create(
            barrier_storage=storage.mbar_ptr_Q.data_ptr(),
            num_stages=self.Q_stage,
            producer_group=pipeline_producer_group,
            consumer_group=pipeline_consumer_group,
            tx_count=self.tma_copy_bytes["Q"] + self.tma_copy_bytes["LSE"],
            defer_sync=True,
        )
        pipeline_dO = pipeline.PipelineTmaAsync.create(
            barrier_storage=storage.mbar_ptr_dO.data_ptr(),
            num_stages=self.dO_stage,
            producer_group=pipeline_producer_group,
            consumer_group=pipeline_consumer_group,
            tx_count=self.tma_copy_bytes["dO"] + self.tma_copy_bytes["dPsum"],
            defer_sync=False,
        )

        sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
        sdO = storage.sdO.get_tensor(sdO_layout.outer, swizzle=sdO_layout.inner)
        sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
        sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
        sP = None
        if const_expr(not self.mma_dkv_is_rs):
            sP = storage.sP.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner)
        sdS = storage.sdS.get_tensor(sPdS_layout.outer, swizzle=sPdS_layout.inner)
        sLSE = storage.sLSE.get_tensor(
            cute.make_layout(
                (self.tile_m, self.Q_stage),
                stride=(1, cute.round_up(self.tile_m, 64)),
            )
        )
        sdPsum = storage.sdPsum.get_tensor(
            cute.make_layout(
                (self.tile_m, self.dO_stage),
                stride=(1, cute.round_up(self.tile_m, 64)),
            )
        )
        sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout)

        block_info = BlockInfo(
            self.tile_m,
            self.tile_n,
            self.is_causal,
            self.is_local,
            False,  # is_split_kv
            window_size_left,
            window_size_right,
            qhead_per_kvhead_packgqa=1,
        )
        SeqlenInfoCls = partial(
            SeqlenInfoQK.create,
            seqlen_q_static=mQ.shape[0],
            seqlen_k_static=mK.shape[0],
            mCuSeqlensQ=mCuSeqlensQ,
            mCuSeqlensK=mCuSeqlensK,
            mSeqUsedQ=mSeqUsedQ,
            mSeqUsedK=mSeqUsedK,
            tile_m=self.tile_m,
            tile_n=self.tile_n,
        )
        AttentionMaskCls = partial(
            AttentionMask,
            self.tile_m,
            self.tile_n,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            swap_AB=self.SdP_swapAB,
        )
        TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)

        if warp_idx < 4:
            cute.arch.setmaxregister_decrease(self.num_producer_regs)
            if warp_idx == 0:
                self.load(
                    mQ, mK, mV, mdO, mLSE, mdPsum,
                    sQ, sK, sV, sdO, sLSE, sdPsum,
                    tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_dO,
                    pipeline_Q, pipeline_dO,
                    block_info, SeqlenInfoCls, TileSchedulerCls,
                    blocksparse_tensors, qhead_per_kvhead_divmod,
                )
            if warp_idx == 1:
                self.dQaccum_store(
                    mdQaccum, sdQaccum, block_info, TileSchedulerCls, SeqlenInfoCls,
                    blocksparse_tensors, mdQ_semaphore,
                )
        else:
            tidx, _, _ = cute.arch.thread_idx()
            tidx = tidx - 128
            mma_args = (
                tiled_mma_SdP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ,
                mdK, mdV, mdQaccum,
                sQ, sK, sV, sdO, sP, sdS, sLSE, sdPsum, sdQaccum,
                pipeline_Q, pipeline_dO,
                tidx,
                tma_atom_dK, tma_atom_dV,
                r2s_tiled_copy_dQaccum,
                softmax_scale_log2, softmax_scale,
                block_info, SeqlenInfoCls, AttentionMaskCls, TileSchedulerCls,
                aux_tensors, fastdiv_mods, blocksparse_tensors, qhead_per_kvhead_divmod,
            )
            if const_expr(self.num_wg_dQ == self.num_wg_mma):
                # Both WGs compute dQ
                cute.arch.setmaxregister_increase(self.num_mma_regs_wg0)
                self.mma(*mma_args, is_dQ_wg=True)
            else:
                # WG0 computes dQ, WG1 skips it
                warp_idx_in_mma = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - 4
                if warp_idx_in_mma < 4:
                    cute.arch.setmaxregister_increase(self.num_mma_regs_wg0)
                    self.mma(*mma_args, is_dQ_wg=True)
                else:
                    cute.arch.setmaxregister_increase(self.num_mma_regs_wg1)
                    self.mma(*mma_args, is_dQ_wg=False)

    @cute.jit
    def load(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mdO: cute.Tensor,
        mLSE: cute.Tensor,
        mdPsum: cute.Tensor,
        sQ: cute.Tensor,
        sK: cute.Tensor,
        sV: cute.Tensor,
        sdO: cute.Tensor,
        sLSE: cute.Tensor,
        sdPsum: cute.Tensor,
        tma_atom_Q: cute.CopyAtom,
        tma_atom_K: cute.CopyAtom,
        tma_atom_V: cute.CopyAtom,
        tma_atom_dO: cute.CopyAtom,
        pipeline_Q: cutlass.pipeline.PipelineAsync,
        pipeline_dO: cutlass.pipeline.PipelineAsync,
        block_info: BlockInfo,
        SeqlenInfoCls: Callable,
        TileSchedulerCls: Callable,
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
    ):
        warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4

        if warp_idx_in_wg == 0:
            producer_state_Q = cutlass.pipeline.make_pipeline_state(
                cutlass.pipeline.PipelineUserType.Producer, self.Q_stage
            )
            producer_state_dO = cutlass.pipeline.make_pipeline_state(
                cutlass.pipeline.PipelineUserType.Producer, self.dO_stage
            )
            tile_scheduler = TileSchedulerCls()
            work_tile = tile_scheduler.initial_work_tile_info()
            while work_tile.is_valid_tile:
                n_block, head_idx, batch_idx, _ = work_tile.tile_idx
                seqlen = SeqlenInfoCls(batch_idx)
                head_idx_kv = (
                    head_idx
                    if const_expr(self.qhead_per_kvhead == 1)
                    else head_idx // qhead_per_kvhead_divmod
                )
                mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
                mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
                gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
                gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))

                mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
                mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[
                    None, head_idx
                ]
                mdO_cur = seqlen.offset_batch_Q(mdO, batch_idx, dim=3)[None, None, head_idx]
                mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[
                    None, head_idx
                ]
                gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
                gdO = cute.local_tile(mdO_cur, (self.tile_m, self.tile_hdimv), (None, 0))
                gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
                gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))

                load_K, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_K, 0, cute.make_layout(1), gK, sK, single_stage=True
                )
                load_V, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_V, 0, cute.make_layout(1), gV, sV, single_stage=True
                )
                load_Q, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_Q, 0, cute.make_layout(1), gQ, sQ
                )
                load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q)
                load_dO, _, _ = copy_utils.tma_get_copy_fn(
                    tma_atom_dO, 0, cute.make_layout(1), gdO, sdO
                )
                load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO)
                load_LSE = copy_utils.cpasync_bulk_get_copy_fn(gLSE, sLSE)
                load_LSE = copy_utils.tma_producer_copy_fn(load_LSE, pipeline_Q)
                load_dPsum = copy_utils.cpasync_bulk_get_copy_fn(gdPsum, sdPsum)
                load_dPsum = copy_utils.tma_producer_copy_fn(load_dPsum, pipeline_dO)

                m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
                if const_expr(not self.use_block_sparsity):
                    total_m_block_cnt = m_block_max - m_block_min
                    process_tile = (
                        const_expr(not self.is_local and not self.is_varlen_q)
                        or m_block_min < m_block_max
                    )
                else:
                    total_m_block_cnt = get_total_q_block_count_bwd(
                        blocksparse_tensors,
                        batch_idx,
                        head_idx,
                        n_block,
                        subtile_factor=self.subtile_factor,
                        m_block_max=m_block_max,
                    )
                    process_tile = total_m_block_cnt > Int32(0)

                if process_tile:
                    if const_expr(not self.use_block_sparsity):
                        first_m_block = m_block_min
                        pipeline_Q.producer_acquire(
                            producer_state_Q, extra_tx_count=self.tma_copy_bytes["K"]
                        )
                        load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q))
                        load_Q(first_m_block, producer_state=producer_state_Q)
                        # Wait for bwd preprocess to finish writing LSE and dPsum
                        cute.arch.griddepcontrol_wait()
                        load_LSE(first_m_block, producer_state=producer_state_Q)
                        producer_state_dO_cur = (
                            producer_state_dO
                            if const_expr(self.Q_stage != self.dO_stage)
                            else producer_state_Q
                        )
                        pipeline_dO.producer_acquire(
                            producer_state_dO_cur, extra_tx_count=self.tma_copy_bytes["V"]
                        )
                        load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_cur))
                        load_dO(first_m_block, producer_state=producer_state_dO_cur)
                        load_dPsum(first_m_block, producer_state=producer_state_dO_cur)
                        producer_state_Q.advance()
                        producer_state_dO.advance()

                        for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
                            pipeline_Q.producer_acquire(producer_state_Q)
                            load_Q(m_block, producer_state=producer_state_Q)
                            load_LSE(m_block, producer_state=producer_state_Q)
                            producer_state_dO_cur = (
                                producer_state_dO
                                if const_expr(self.Q_stage != self.dO_stage)
                                else producer_state_Q
                            )
                            pipeline_dO.producer_acquire(producer_state_dO_cur)
                            load_dO(m_block, producer_state=producer_state_dO_cur)
                            load_dPsum(m_block, producer_state=producer_state_dO_cur)
                            producer_state_Q.advance()
                            producer_state_dO.advance()
                    else:
                        producer_state_Q, producer_state_dO = produce_block_sparse_q_loads_bwd_sm90(
                            blocksparse_tensors,
                            batch_idx,
                            head_idx,
                            n_block,
                            producer_state_Q,
                            producer_state_dO,
                            pipeline_Q,
                            pipeline_dO,
                            load_K,
                            load_V,
                            load_Q,
                            load_dO,
                            load_LSE,
                            load_dPsum,
                            self.tma_copy_bytes["K"],
                            self.tma_copy_bytes["V"],
                            Q_stage_eq_dO_stage=(self.Q_stage == self.dO_stage),
                            subtile_factor=self.subtile_factor,
                            m_block_max=m_block_max,
                        )

                tile_scheduler.prefetch_next_work()
                tile_scheduler.advance_to_next_work()
                work_tile = tile_scheduler.get_current_work()

    @cute.jit
    def apply_score_mod(
        self,
        acc_S: cute.Tensor,
        thr_mma_SdP: cute.core.ThrMma,
        batch_idx,
        head_idx,
        m_block,
        n_block,
        softmax_scale,
        seqlen_info: SeqlenInfoQK,
        aux_tensors=None,
        fastdiv_mods=(None, None),
    ):
        # [NOTE] SdP_swapAB: swapAB transposes the tile, so use (n, m) indexing
        cS = cute.make_identity_tensor(
            (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
        )
        cS = cute.domain_offset(
            (n_block * self.tile_n, m_block * self.tile_m)
            if self.SdP_swapAB
            else (m_block * self.tile_m, n_block * self.tile_n),
            cS,
        )
        tScS = thr_mma_SdP.partition_C(cS)

        apply_score_mod_inner(
            acc_S,
            tScS,
            self.score_mod,
            batch_idx,
            head_idx,
            softmax_scale,
            self.vec_size,
            self.qk_acc_dtype,
            aux_tensors,
            fastdiv_mods,
            seqlen_info,
            constant_q_idx=None,
            qhead_per_kvhead=self.qhead_per_kvhead,
            transpose_indices=self.SdP_swapAB,
        )

    @cute.jit
    def apply_score_mod_bwd(
        self,
        grad_tensor: cute.Tensor,
        score_tensor: cute.Tensor,
        thr_mma_SdP: cute.core.ThrMma,
        batch_idx,
        head_idx,
        m_block,
        n_block,
        softmax_scale,
        seqlen_info: SeqlenInfoQK,
        aux_tensors=None,
        fastdiv_mods=(None, None),
    ):
        cS = cute.make_identity_tensor(
            (self.tile_n, self.tile_m) if self.SdP_swapAB else (self.tile_m, self.tile_n)
        )
        cS = cute.domain_offset(
            (n_block * self.tile_n, m_block * self.tile_m)
            if self.SdP_swapAB
            else (m_block * self.tile_m, n_block * self.tile_n),
            cS,
        )
        tScS = thr_mma_SdP.partition_C(cS)

        apply_score_mod_bwd_inner(
            grad_tensor,
            score_tensor,
            tScS,
            self.score_mod_bwd,
            batch_idx,
            head_idx,
            softmax_scale,
            self.vec_size,
            self.qk_acc_dtype,
            aux_tensors,
            fastdiv_mods,
            seqlen_info,
            constant_q_idx=None,
            qhead_per_kvhead=self.qhead_per_kvhead,
            transpose_indices=self.SdP_swapAB,
        )

    @cute.jit
    def mma(
        self,
        tiled_mma_SdP: cute.TiledMma,
        tiled_mma_dK: cute.TiledMma,
        tiled_mma_dV: cute.TiledMma,
        tiled_mma_dQ: cute.TiledMma,
        mdK: cute.Tensor,
        mdV: cute.Tensor,
        mdQaccum: cute.Tensor,
        sQ: cute.Tensor,
        sK: cute.Tensor,
        sV: cute.Tensor,
        sdO: cute.Tensor,
        sP: Optional[cute.Tensor],
        sdS: cute.Tensor,
        sLSE: cute.Tensor,
        sdPsum: cute.Tensor,
        sdQaccum: cute.Tensor,
        pipeline_Q: cutlass.pipeline.PipelineAsync,
        pipeline_dO: cutlass.pipeline.PipelineAsync,
        tidx: Int32,
        tma_atom_dK: cute.CopyAtom,
        tma_atom_dV: cute.CopyAtom,
        r2s_tiled_copy_dQaccum: cute.TiledCopy,
        softmax_scale_log2: Float32,
        softmax_scale: Float32,
        block_info: BlockInfo,
        SeqlenInfoCls: Callable,
        AttentionMaskCls: Callable,
        TileSchedulerCls: Callable,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
        qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None,
        is_dQ_wg: cutlass.Constexpr[bool] = True,
    ):
        warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
        warp_group_thread_layout = cute.make_layout(
            self.num_wg_mma, stride=self.num_threads_per_warp_group
        )
        thr_mma_SdP = tiled_mma_SdP.get_slice(tidx)
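        # Worked example (illustrative, assuming num_wg_mma = 2): tidx is already offset past the
        # 128 producer threads, so MMA threads 0..255 give warp_group_idx 0 or 1. The layout
        # warp_group_thread_layout = (2):(128) maps warp group 0 -> thread 0 and warp group 1 ->
        # thread 128, so each wg_mma_* slice below is taken at the first thread of its warp group.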
        wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx))
        wg_mma_dK = tiled_mma_dK.get_slice(warp_group_thread_layout(warp_group_idx))
        wg_mma_dV = tiled_mma_dV.get_slice(warp_group_thread_layout(warp_group_idx))
        wg_mma_dQ = None
        if const_expr(is_dQ_wg):
            wg_idx_dQ = warp_group_idx if const_expr(self.num_wg_dQ > 1) else 0
            wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout(wg_idx_dQ))

        # S = Q @ K.T
        shape_mnk_S = (self.tile_m, self.tile_n, self.tile_hdim)
        _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
            wg_mma_SdP, shape_mnk_S, sQ, sK, swap_AB=self.SdP_swapAB
        )
        mma_qk_fn = partial(
            gemm_zero_init, tiled_mma_SdP, shape_mnk_S[:2], tSrQ, tSrK, swap_AB=self.SdP_swapAB
        )
        # dP = dO @ V.T
        shape_mnk_dP = (self.tile_m, self.tile_n, self.tile_hdimv)
        _, tdPrdO, tdPrV = sm90_utils.partition_fragment_ABC(
            wg_mma_SdP, shape_mnk_dP, sdO, sV, swap_AB=self.SdP_swapAB
        )
        mma_dov_fn = partial(
            gemm_zero_init, tiled_mma_SdP, shape_mnk_dP[:2], tdPrdO, tdPrV, swap_AB=self.SdP_swapAB
        )
        # dV += P.T @ dO
        sPt = layout_utils.transpose_view(sP) if sP is not None else None
        sdOt = layout_utils.transpose_view(sdO)
        shape_mnk_dV = (self.tile_n, self.tile_hdimv, self.tile_m)
        acc_dV, tdVrPt, tdVrdOt = sm90_utils.partition_fragment_ABC(
            wg_mma_dV, shape_mnk_dV, sPt, sdOt, swap_AB=self.dKV_swapAB
        )
        if const_expr(not self.mma_dkv_is_rs):
            mma_pdo_fn = partial(
                gemm_w_idx, tiled_mma_dV, acc_dV, tdVrPt, tdVrdOt, swap_AB=self.dKV_swapAB
            )
        else:
            mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, acc_dV, tCrB=tdVrdOt)
        # dK += dS.T @ Q
        sdSt = layout_utils.transpose_view(sdS)
        sQt = layout_utils.transpose_view(sQ)
        shape_mnk_dK = (self.tile_n, self.tile_hdim, self.tile_m)
        acc_dK, tdKrdSt, tdKrQt = sm90_utils.partition_fragment_ABC(
            wg_mma_dK, shape_mnk_dK, sdSt, sQt, swap_AB=self.dKV_swapAB
        )
        if const_expr(not self.mma_dkv_is_rs):
            mma_dsq_fn = partial(
                gemm_w_idx, tiled_mma_dK, acc_dK, tdKrdSt, tdKrQt, swap_AB=self.dKV_swapAB
            )
        else:
            mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, acc_dK, tCrB=tdKrQt)
        # dQ = dS @ K
        sKt = layout_utils.transpose_view(sK)
        shape_mnk_dQ = (self.tile_m, self.tile_hdim, self.tile_n)
        mma_dsk_fn = None
        if const_expr(is_dQ_wg):
            _, tdQrdS, tdQrKt = sm90_utils.partition_fragment_ABC(
                wg_mma_dQ, shape_mnk_dQ, sdS, sKt, swap_AB=self.dQ_swapAB
            )
            mma_dsk_fn = partial(
                gemm_zero_init,
                tiled_mma_dQ,
                shape_mnk_dQ[:2],
                tdQrdS,
                tdQrKt,
                swap_AB=self.dQ_swapAB,
            )

        # Smem copy atom tiling for P/dS R2S
        copy_P_r2s = None
        mms_PdS = self.tile_n // (self.num_wg_mma // self.AtomLayoutMSdP)
        if const_expr(sP is not None):
            sP_cpy = sP if const_expr(not self.SdP_swapAB) else sPt
            copy_P_r2s, _, _ = copy_utils.get_smem_store_C(
                tiled_mma_SdP,
                sP_cpy,
                tidx,
                self.arch,
                transpose=self.SdP_swapAB,
                position_independent=True,
                major_mode_size=mms_PdS,
            )
        sdS_cpy = sdS if const_expr(not self.SdP_swapAB) else sdSt
        copy_dS_r2s, _, _ = copy_utils.get_smem_store_C(
            tiled_mma_SdP,
            sdS_cpy,
            tidx,
            self.arch,
            transpose=self.SdP_swapAB,
            position_independent=True,
            major_mode_size=mms_PdS,
        )

        tLSEsLSE = layout_utils.mma_partition_C_vec(
            sLSE, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB
        )
        tLSEsdPsum = layout_utils.mma_partition_C_vec(
            sdPsum, thr_mma_SdP, expand_shape=self.tile_n, is_colvec=not self.SdP_swapAB
        )
        # When shuffle=True, rows are distributed across 8 quads (4 threads each) within a warp.
        # Each thread loads only ceil(num_rows/8) values and fetches the others with a warp
        # shuffle (see _get_stat).
        shfl_copy = copy_utils.tiled_copy_1d(sLSE.element_type, num_threads=8, num_copy_elems=2)
        if const_expr(self.shuffle_LSE):
            tLSEsLSE = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsLSE)
            # ((2, 1), 1, 2) -> (((2, 1), 1), 2)
            tLSEsLSE = cute.group_modes(tLSEsLSE, 0, 2)
        if const_expr(self.shuffle_dPsum):
            tLSEsdPsum = shfl_copy.get_slice(cute.arch.lane_idx() // 4).partition_S(tLSEsdPsum)
            tLSEsdPsum = cute.group_modes(tLSEsdPsum, 0, 2)

        tdQsdQaccum = None
        if const_expr(is_dQ_wg):
            smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
            tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)

        PdS_barrier = cutlass.pipeline.NamedBarrier(
            barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads
        )

        score_mod_fn = partial(
            self.apply_score_mod,
            thr_mma_SdP=thr_mma_SdP,
            softmax_scale=softmax_scale,
            aux_tensors=aux_tensors,
            fastdiv_mods=fastdiv_mods,
        )
        score_mod_bwd_fn = partial(
            self.apply_score_mod_bwd,
            thr_mma_SdP=thr_mma_SdP,
            softmax_scale=softmax_scale,
            aux_tensors=aux_tensors,
            fastdiv_mods=fastdiv_mods,
        )

        mma_one_m_block_all = partial(
            self.mma_one_m_block,
            warp_group_idx=warp_group_idx,
            mma_qk_fn=mma_qk_fn,
            mma_dov_fn=mma_dov_fn,
            mma_pdo_fn=mma_pdo_fn,
            mma_dsq_fn=mma_dsq_fn,
            mma_dsk_fn=mma_dsk_fn,
            copy_P_r2s=copy_P_r2s,
            copy_dS_r2s=copy_dS_r2s,
            pipeline_Q=pipeline_Q,
            pipeline_dO=pipeline_dO,
            tLSEsLSE=tLSEsLSE,
            tLSEsdPsum=tLSEsdPsum,
            tdQsdQaccum=tdQsdQaccum,
            softmax_scale_log2=softmax_scale_log2,
            PdS_barrier=PdS_barrier,
            # acc_dV=acc_dV,
            # acc_dK=acc_dK,
            is_dQ_wg=is_dQ_wg,
        )

        consumer_state_Q = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage
        )
        consumer_state_dO = cutlass.pipeline.make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
        )
        tile_scheduler = TileSchedulerCls()
        work_tile = tile_scheduler.initial_work_tile_info()
        while work_tile.is_valid_tile:
            n_block, head_idx, batch_idx, _ = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)
            mask = AttentionMaskCls(seqlen)
            score_mod_fn_cur = partial(
                score_mod_fn,
                batch_idx=batch_idx,
                head_idx=head_idx,
                n_block=n_block,
                seqlen_info=seqlen,
            )
            score_mod_bwd_fn_cur = partial(
                score_mod_bwd_fn,
                batch_idx=batch_idx,
                head_idx=head_idx,
                n_block=n_block,
                seqlen_info=seqlen,
            )
            m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)

            if const_expr(not self.use_block_sparsity):
                process_tile = (
                    const_expr(not self.is_local and not self.is_varlen_q)
                    or m_block_min < m_block_max
                )
            else:
                total_m_block_cnt = get_total_q_block_count_bwd(
                    blocksparse_tensors,
                    batch_idx,
                    head_idx,
                    n_block,
                    subtile_factor=self.subtile_factor,
                    m_block_max=m_block_max,
                )
                process_tile = total_m_block_cnt > Int32(0)

            if process_tile:
                if const_expr(not self.use_block_sparsity):
                    mask_fn = partial(
                        mask.apply_mask,
                        batch_idx=batch_idx,
                        head_idx=head_idx,
                        n_block=n_block,
                        thr_mma=thr_mma_SdP,
                        mask_seqlen=True,
                        mask_causal=self.is_causal,
                        mask_local=self.is_local,
                        mask_mod=self.mask_mod,
                        aux_tensors=aux_tensors,
                        fastdiv_mods=fastdiv_mods,
                    )
                    dKV_accumulate = False
                    for m_block in cutlass.range(m_block_min, m_block_max, unroll=1):
                        consumer_state_Q, consumer_state_dO = mma_one_m_block_all(
                            m_block,
                            consumer_state_Q,
                            consumer_state_dO,
                            mask_fn=mask_fn,
                            score_mod_fn=score_mod_fn_cur,
                            score_mod_bwd_fn=score_mod_bwd_fn_cur,
                            dKV_accumulate=dKV_accumulate,
                        )
                        dKV_accumulate = True
                else:
                    consumer_state_Q, consumer_state_dO = consume_block_sparse_mma_bwd_sm90(
                        blocksparse_tensors,
                        batch_idx,
                        head_idx,
                        n_block,
                        consumer_state_Q,
                        consumer_state_dO,
                        mma_one_m_block_all,
                        mask,
                        self.mask_mod,
                        is_causal=self.is_causal,
                        is_local=self.is_local,
                        thr_mma_SdP=thr_mma_SdP,
                        score_mod_fn=score_mod_fn_cur,
                        score_mod_bwd_fn=score_mod_bwd_fn_cur,
                        subtile_factor=self.subtile_factor,
                        m_block_max=m_block_max,
                        aux_tensors=aux_tensors,
                        fastdiv_mods=fastdiv_mods,
                    )

                if const_expr(self.qhead_per_kvhead == 1):
                    acc_dK.store(acc_dK.load() * softmax_scale)
                self.epilogue_dKV(
                    acc_dV, mdV, sV,
acc_dK, mdK, sK, seqlen, tma_atom_dK, tma_atom_dV, tiled_mma_dK, tiled_mma_dV, tidx, n_block, head_idx, batch_idx, qhead_per_kvhead_divmod, ) else: # KV tile with zero Q blocks produces no dK/dV; write zeros. if const_expr(self.use_block_sparsity or self.is_local or self.is_varlen_q): acc_dK.fill(0.0) acc_dV.fill(0.0) self.epilogue_dKV( acc_dV, mdV, sV, acc_dK, mdK, sK, seqlen, tma_atom_dK, tma_atom_dV, tiled_mma_dK, tiled_mma_dV, tidx, n_block, head_idx, batch_idx, qhead_per_kvhead_divmod, ) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 4: cute.arch.cp_async_bulk_wait_group(0, read=True) @staticmethod @cute.jit def _get_stat(tSrS: cute.Tensor, row: Int32, lane: Int32, shuffle: bool) -> Float32: """Retrieve the statistic for a given accumulator row. When shuffle=False, direct register indexing. When shuffle=True, warp shuffle from the thread group that holds the value. """ if const_expr(not shuffle): return tSrS[row] # tSrS: (((2, 1), 1), 1)), distributed across 8 threads in the warp vecsize = cute.size(tSrS, mode=[0, 0]) # 2 idx0, off, idx1 = cute.idx2crd(row, (vecsize, 8, cute.shape(tSrS, mode=[0, 1]))) # register index: 0, 1, 0, 1, ..., 2, 3, 2, 3, ... 
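# Worked example of the shuffle index math below. The numbers are illustrative and
# assume cute.shape(tSrS, mode=[0, 1]) == 2. With vecsize = 2 and 8 quads, the warp
# then covers 2 * 8 * 2 = 32 rows, and each quad holds 4 values.
# cute.idx2crd is colexicographic, so row = idx0 + vecsize * off + vecsize * 8 * idx1:
#   row 11 -> (idx0, off, idx1) = (1, 5, 0): register 1 + 0 * 2 = 1,
#             lane passed as offset = 5 * 4 + lane % 4 = 20 + lane % 4 (quad 5)
#   row 17 -> (idx0, off, idx1) = (1, 0, 1): register 1 + 1 * 2 = 3,
#             lane passed as offset = 0 * 4 + lane % 4 = lane % 4 (quad 0)
# All 4 lanes of a quad load the same slice (get_slice(lane_idx // 4)), so reading
# from the lane with the same lane % 4 in the owning quad is sufficient.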
return utils.shuffle_sync(tSrS[idx0 + idx1 * vecsize], offset=off * 4 + (lane % 4)) @cute.jit def mma_one_m_block( self, m_block: Int32, consumer_state_Q: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, consumer_state_dO: cutlass.pipeline.PipelineState | pipeline.PipelineStateSimple, warp_group_idx: Int32, mma_qk_fn: Callable, mma_dov_fn: Callable, mma_pdo_fn: Callable, mma_dsq_fn: Callable, mma_dsk_fn: Callable, copy_P_r2s: Optional[Callable], copy_dS_r2s: Callable, pipeline_Q: cutlass.pipeline.PipelineAsync, pipeline_dO: cutlass.pipeline.PipelineAsync, tLSEsLSE: cute.Tensor, tLSEsdPsum: cute.Tensor, tdQsdQaccum: Optional[cute.Tensor], softmax_scale_log2: Float32, PdS_barrier: cutlass.pipeline.NamedBarrier, is_dQ_wg: cutlass.Constexpr[bool] = True, mask_fn: Optional[Callable] = None, score_mod_fn: Optional[Callable] = None, score_mod_bwd_fn: Optional[Callable] = None, dKV_accumulate: Boolean = True, ): consumer_state_dO_cur = ( consumer_state_Q if const_expr(self.Q_stage == self.dO_stage) else consumer_state_dO ) smem_idx_Q = consumer_state_Q.index smem_idx_dO = consumer_state_dO_cur.index if const_expr(self.dO_stage > 1) else 0 smem_idx_PdS = smem_idx_Q if const_expr(self.PdS_stage > 1) else 0 # (1) [GEMM 1] S = Q @ K^T pipeline_Q.consumer_wait(consumer_state_Q, pipeline_Q.consumer_try_wait(consumer_state_Q)) acc_S = mma_qk_fn(A_idx=smem_idx_Q, wg_wait=-1) # If shuffle_LSE, OOB reads are OK since sLSE is already padded tLSErLSE = copy_utils.load_s2r(tLSEsLSE[None, smem_idx_Q]) # (2) [GEMM 2] dP = dO @ V.T pipeline_dO.consumer_wait( consumer_state_dO_cur, pipeline_dO.consumer_try_wait(consumer_state_dO_cur) ) acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1) if const_expr(self.score_mod_bwd is not None): acc_S_pre = cute.make_fragment_like(acc_S) cute.autovec_copy(acc_S, acc_S_pre) if const_expr(self.score_mod is not None): score_mod_fn(acc_S, m_block=m_block) # (3) [Pointwise 1] P = exp(S - LSE) if cutlass.const_expr(mask_fn is not None): 
mask_fn(acc_S, m_block=m_block) acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB) lane_idx = cute.arch.lane_idx() for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])): lse_val = self._get_stat(tLSErLSE, r, lane_idx, shuffle=self.shuffle_LSE) for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True): acc_S_mn[r, c] = cute.math.exp2( acc_S_mn[r, c] * softmax_scale_log2 - lse_val, fastmath=True ) tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO]) # Convert P from f32 -> f16 tdVrP = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_S), self.dtype) # R2S for P if const_expr(not self.mma_dkv_is_rs): # sync to ensure P has already been used in the previous iteration before overwriting if const_expr(self.PdS_stage == 1): PdS_barrier.arrive_and_wait() copy_P_r2s(tdVrP, dst_idx=smem_idx_PdS) # (4) [Pointwise 2] dS = P*(dP-dPsum) warpgroup.wait_group(0) acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB) for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])): dpsum_val = self._get_stat(tLSErdPsum, r, lane_idx, shuffle=self.shuffle_dPsum) for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True): acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - dpsum_val) if const_expr(self.score_mod_bwd is not None): score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block) # Convert dS from f32 -> f16 tdKrdS = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_dP), self.dtype) # If there's double buffering on dS, we don't need to sync here. # Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ. # But because both WGs have to sync at the end of the loop and double buffering, # this race condition is not possible. # This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and # (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs. 
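# Truth table for the sync condition below. It reduces to
# (not mma_dkv_is_rs) or (PdS_stage == 1):
#   mma_dkv_is_rs=False, any PdS_stage -> sync (P must be fully written to smem)
#   mma_dkv_is_rs=True,  PdS_stage == 1 -> sync (the single dS buffer may still be
#                                          read by the previous iteration's MMA)
#   mma_dkv_is_rs=True,  PdS_stage > 1  -> no sync (dS is double-buffered)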
if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)): cute.arch.fence_view_async_shared() PdS_barrier.arrive_and_wait() # R2S for dS copy_dS_r2s(tdKrdS, dst_idx=smem_idx_PdS) # (5) [GEMM 3] dV += P.T @ dO if const_expr(not self.mma_dkv_is_rs): mma_pdo_fn( A_idx=smem_idx_PdS, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1 ) else: mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1) # smem fence to make sure sdS is written before it's read by WGMMA cute.arch.fence_view_async_shared() PdS_barrier.arrive_and_wait() if const_expr(is_dQ_wg): # (6) [GEMM 4] dQ = dS @ K acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1) pipeline_dO.consumer_release(consumer_state_dO_cur) # release dO as dV mma is done # (7) [GEMM 5] dK += dS.T @ Q if const_expr(not self.mma_dkv_is_rs): mma_dsq_fn( A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 ) else: mma_dsq_fn(tCrA=tdKrdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) # dQ R2S: wait for dQaccum_store to free the smem buffer, then write dQ to smem # When dQ_single_wg, only WG0 enters here so warp_group_idx == 0 cute.arch.barrier( barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, ) tdQrdQaccum_flat = cute.make_tensor( acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape) ) cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum) cute.arch.fence_view_async_shared() cute.arch.barrier_arrive( barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, ) warpgroup.wait_group(0) pipeline_Q.consumer_release(consumer_state_Q) else: # dQ_single_wg: WG1 skips dQ, only does dV wait + dK # (7) [GEMM 5] dK += dS.T @ Q if const_expr(not self.mma_dkv_is_rs): mma_dsq_fn( A_idx=smem_idx_PdS, B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1 ) else: mma_dsq_fn(tCrA=tdKrdS, 
B_idx=smem_idx_Q, zero_init=not dKV_accumulate, wg_wait=1) pipeline_dO.consumer_release(consumer_state_dO_cur) warpgroup.wait_group(0) pipeline_Q.consumer_release(consumer_state_Q) consumer_state_Q.advance() consumer_state_dO.advance() return consumer_state_Q, consumer_state_dO @cute.jit def epilogue_dKV( self, acc_dV: cute.Tensor, mdV: cute.Tensor, sV: cute.Tensor, acc_dK: cute.Tensor, mdK: cute.Tensor, sK: cute.Tensor, seqlen: SeqlenInfoQK, tma_atom_dK: cute.CopyAtom, tma_atom_dV: cute.CopyAtom, tiled_mma_dK: cute.TiledMma, tiled_mma_dV: cute.TiledMma, tidx: Int32, n_block: Int32, head_idx: Int32, batch_idx: Int32, qhead_per_kvhead_divmod: Optional[FastDivmodDivisor] = None, ): epi_barrier = cutlass.pipeline.NamedBarrier( barrier_id=int(NamedBarrierBwd.Epilogue), num_threads=self.num_mma_threads ) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if const_expr(self.qhead_per_kvhead == 1): mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3, ragged=self.varlen_k)[ None, None, head_idx ] mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3, ragged=self.varlen_k)[ None, None, head_idx ] gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0)) gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0)) store_dK, _, _ = copy_utils.tma_get_copy_fn( tma_atom_dK, 0, cute.make_layout(1), sK, gdK, single_stage=True ) store_dV, _, _ = copy_utils.tma_get_copy_fn( tma_atom_dV, 0, cute.make_layout(1), sV, gdV, single_stage=True ) sdV = sV if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sV) sdK = sK if const_expr(not self.dKV_swapAB) else layout_utils.transpose_view(sK) copy_dV_r2s, _, _ = copy_utils.get_smem_store_C( tiled_mma_dV, sdV, tidx, self.arch, transpose=self.dKV_swapAB, position_independent=True, ) copy_dK_r2s, _, _ = copy_utils.get_smem_store_C( tiled_mma_dK, sdK, tidx, self.arch, transpose=self.dKV_swapAB, position_independent=True, ) cute.arch.cp_async_bulk_wait_group(1, read=True) 
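# Illustrative timeline of the bulk-store groups in this branch. It assumes, as in
# the code below, that only warp 4 issues and commits them: one dV group, then one
# dK group per tile. The epi_barrier after each wait makes the other warps wait for
# warp 4. Entering the epilogue for tile t, up to two groups may still be reading
# smem: [dV(t-1), dK(t-1)].
#   wait_group(1, read=True) -> dV(t-1) has finished reading sV; sV may be overwritten
#   store dV(t), commit      -> pending [dK(t-1), dV(t)]
#   wait_group(1, read=True) -> dK(t-1) has finished reading sK; sK may be overwritten
#   store dK(t), commit      -> pending [dV(t), dK(t)]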
epi_barrier.arrive_and_wait() copy_dV_r2s(acc_dV, dst_idx=None) cute.arch.fence_view_async_shared() epi_barrier.arrive_and_wait() if warp_idx == 4: store_dV() cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(1, read=True) epi_barrier.arrive_and_wait() copy_dK_r2s(acc_dK, dst_idx=None) cute.arch.fence_view_async_shared() epi_barrier.arrive_and_wait() if warp_idx == 4: store_dK() cute.arch.cp_async_bulk_commit_group() else: sdKaccum_shape0 = self.tile_n * self.tile_hdim // self.num_wg_mma sdVaccum_shape0 = self.tile_n * self.tile_hdimv // self.num_wg_mma sdKaccum_layout = cute.make_layout((sdKaccum_shape0, self.num_wg_mma)) sdVaccum_layout = cute.make_layout((sdVaccum_shape0, self.num_wg_mma)) head_idx_kv = head_idx // qhead_per_kvhead_divmod mdKaccum_cur = seqlen.offset_batch_K( mdK, batch_idx, dim=2, padded=True, multiple=self.tile_hdim )[None, head_idx_kv] mdVaccum_cur = seqlen.offset_batch_K( mdV, batch_idx, dim=2, padded=True, multiple=self.tile_hdimv )[None, head_idx_kv] gdKaccum_ = cute.local_tile(mdKaccum_cur, (self.tile_n * self.tile_hdim,), (n_block,)) gdKaccum = cute.flat_divide(gdKaccum_, (sdKaccum_shape0,)) gdVaccum_ = cute.local_tile(mdVaccum_cur, (self.tile_n * self.tile_hdimv,), (n_block,)) gdVaccum = cute.flat_divide(gdVaccum_, (sdVaccum_shape0,)) # These two overlap each other sVaccum_ptr = cute.recast_ptr(sV.iterator, dtype=Float32) sdKaccum = cute.make_tensor(sVaccum_ptr, sdKaccum_layout) sdVaccum = cute.make_tensor(sVaccum_ptr, sdVaccum_layout) tiled_copy_dKVaccum_r2s = cute.make_tiled_copy_tv( cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128), cute.make_layout((self.num_threads_per_warp_group, self.num_wg_mma)), cute.make_layout(128 // Float32.width), ) thr_copy_dKVaccum_r2s = tiled_copy_dKVaccum_r2s.get_slice(tidx) tdKsdKaccum = thr_copy_dKVaccum_r2s.partition_D(sdKaccum) tdVsdVaccum = thr_copy_dKVaccum_r2s.partition_D(sdVaccum) cute.arch.cp_async_bulk_wait_group(0, read=True) 
epi_barrier.arrive_and_wait() tdKrdKaccum_flat = cute.make_tensor(acc_dK.iterator, tdKsdKaccum.shape) cute.autovec_copy(tdKrdKaccum_flat, tdKsdKaccum) cute.arch.fence_view_async_shared() epi_barrier.arrive_and_wait() if warp_idx == 4: with cute.arch.elect_one(): for wg_idx in cutlass.range_constexpr(self.num_wg_mma): copy_utils.cpasync_reduce_bulk_add_f32( sdKaccum[None, wg_idx].iterator, gdKaccum[None, wg_idx].iterator, self.tma_copy_bytes["dKacc"] // self.num_wg_mma, ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) epi_barrier.arrive_and_wait() tdVrdVaccum_flat = cute.make_tensor(acc_dV.iterator, tdVsdVaccum.shape) cute.autovec_copy(tdVrdVaccum_flat, tdVsdVaccum) cute.arch.fence_view_async_shared() epi_barrier.arrive_and_wait() if warp_idx == 4: with cute.arch.elect_one(): for wg_idx in cutlass.range_constexpr(self.num_wg_mma): copy_utils.cpasync_reduce_bulk_add_f32( sdVaccum[None, wg_idx].iterator, gdVaccum[None, wg_idx].iterator, self.tma_copy_bytes["dVacc"] // self.num_wg_mma, ) cute.arch.cp_async_bulk_commit_group() @cute.jit def dQaccum_store( self, mdQaccum: cute.Tensor, sdQaccum: cute.Tensor, block_info: BlockInfo, TileSchedulerCls: cutlass.Constexpr[Callable], SeqlenInfoCls: cutlass.Constexpr[Callable], blocksparse_tensors: Optional[BlockSparseTensors] = None, mdQ_semaphore: Optional[cute.Tensor] = None, ): tidx, _, _ = cute.arch.thread_idx() # warp-local thread index (dQaccum_store runs on warp 1, global tidx 32-63) warp_local_tidx = tidx % cute.arch.WARP_SIZE read_flag = const_expr(not self.deterministic) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: n_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) if const_expr(not seqlen.has_cu_seqlens_q): mdQaccum_cur = mdQaccum[None, head_idx, batch_idx] else: mdQaccum_cur = cute.domain_offset( (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx] ) # 
((M * K / num_wg_dQ, num_wg_dQ), num_m_blocks) gdQaccum = cute.local_tile( mdQaccum_cur, ( cute.make_layout( (self.tile_m * self.tile_hdim // self.num_wg_dQ, self.num_wg_dQ) ), ), (None,), ) if const_expr(mdQ_semaphore is not None): # mdQ_semaphore is (num_m_blocks, cluster_size, num_head, batch) after transpose mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block) if const_expr(not self.use_block_sparsity): process_tile = ( const_expr(not self.is_local and not self.is_varlen_q) or m_block_min < m_block_max ) loop_count = m_block_max - m_block_min else: total_block_cnt = get_total_q_block_count_bwd( blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor=self.subtile_factor, m_block_max=m_block_max, ) process_tile = total_block_cnt > Int32(0) if process_tile: if const_expr(not self.use_block_sparsity): for iter_idx in cutlass.range(loop_count, unroll=1): m_block = m_block_min + iter_idx m_block_safe = m_block num_dQ_chunks = self.num_wg_dQ for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks): if const_expr(not self.deterministic): # If deterministic, we already waited at the end of the prev iter cute.arch.cp_async_bulk_wait_group( num_dQ_chunks - 1 - warp_group_idx, read=read_flag ) cute.arch.barrier_arrive( barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, ) # Semaphore acquire: wait for prior n_blocks to finish writing this m_block if const_expr(self.deterministic): if const_expr(self.spt): _, n_block_max_for_m_block = block_info.get_n_block_min_max( seqlen, m_block_safe ) lock_value = n_block_max_for_m_block - 1 - n_block else: lock_value = n_block barrier.wait_eq( mdQ_semaphore_cur[(m_block_safe, None)].iterator, warp_local_tidx, 0, # flag_offset lock_value, ) for warp_group_idx in cutlass.range_constexpr(num_dQ_chunks): cute.arch.barrier( 
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE, ) with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, warp_group_idx].iterator, gdQaccum[(None, warp_group_idx), m_block_safe].iterator, self.tma_copy_bytes["dQ"], ) cute.arch.cp_async_bulk_commit_group() # Semaphore release: signal that this n_block is done with this m_block if const_expr(self.deterministic): cute.arch.cp_async_bulk_wait_group(0, read=read_flag) barrier.arrive_inc( mdQ_semaphore_cur[(m_block_safe, None)].iterator, warp_local_tidx, 0, # flag_offset 1, ) else: assert not self.deterministic, ( "Deterministic not implemented for block-sparse backward" ) dQaccum_store_block_sparse_bwd_sm90( blocksparse_tensors, batch_idx, head_idx, n_block, sdQaccum, gdQaccum, subtile_factor=self.subtile_factor, m_block_max=m_block_max, num_mma_warp_groups=self.num_wg_mma, num_threads_per_warp_group=self.num_threads_per_warp_group, tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"], ) # For local masking + deterministic (non-spt): signal remaining m_blocks # that this n_block won't visit, so they don't deadlock waiting. if const_expr( self.deterministic and not self.spt and block_info.window_size_left is not None ): m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): barrier.arrive_inc( mdQ_semaphore_cur[(m_block, None)].iterator, warp_local_tidx, 0, # flag_offset 1, ) tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() if const_expr(not self.deterministic): cute.arch.cp_async_bulk_wait_group(0, read=True) ================================================ FILE: flash_attn/cute/flash_fwd.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
# A reimplementation of # https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm80.h # and https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_kernel_sm90.h # from Cutlass C++ to Cute-DSL. # Built on Cute-DSL example: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py import math from types import SimpleNamespace from typing import Type, Callable, Optional, List from functools import partial import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute from cutlass import Constexpr, Float32, Int32, const_expr, Boolean from cutlass.cute.nvgpu import cpasync, warp import cutlass.utils as utils_basic from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import BaseDSL from quack import copy_utils from quack import layout_utils from flash_attn.cute import ampere_helpers as sm80_utils from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask from flash_attn.cute.softmax import Softmax from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute.named_barrier import NamedBarrierFwd from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: def __init__( self, dtype: Type[cutlass.Numeric], head_dim: int, head_dim_v: Optional[int] = None, qhead_per_kvhead: int = 1, is_causal: bool = False, is_local: bool = False, pack_gqa: bool = True, tile_m: int = 128, tile_n: int = 128, num_stages: int = 1, num_threads: int = 128, Q_in_regs: bool = False, score_mod: Optional[cutlass.Constexpr] = None, mask_mod: Optional[cutlass.Constexpr] = None, has_aux_tensors: bool = False, q_subtile_factor: int | None = None, ): """Initializes 
the configuration for a flash attention kernel. All contiguous dimensions must be at least 16 bytes aligned, which means that the head dimension should be a multiple of 8. :param head_dim: head dimension :type head_dim: int :param tile_m: m block size :type tile_m: int :param tile_n: n block size :type tile_n: int :param num_threads: number of threads :type num_threads: int :param is_causal: is causal :param score_mod: A callable that takes the attention scores and applies a modification. Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Any`` :param mask_mod: A callable that takes the attention scores and returns a boolean representing whether that score should be masked. Callable signature: ``mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors) -> Boolean`` """ self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) # Can save registers (and hence be faster) if we don't have to check hdim predication self.check_hdim_oob = head_dim != self.tile_hdim self.check_hdim_v_oob = head_dim_v != self.tile_hdimv self.qhead_per_kvhead = qhead_per_kvhead self.is_causal = is_causal self.is_local = is_local self.pack_gqa = pack_gqa self.tile_m = tile_m self.tile_n = tile_n self.num_threads = num_threads self.num_stages = num_stages self.q_subtile_factor = q_subtile_factor self.Q_in_regs = Q_in_regs self.score_mod = score_mod self.mask_mod = mask_mod self.qk_acc_dtype = Float32 self.vec_size: cutlass.Constexpr = getattr( score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 ) if self.vec_size > 2: raise ValueError( f"score_mod vec_size {self.vec_size} not supported on Sm80/90/120 " "due to accumulator thread 
ownership pattern." ) self.arch = BaseDSL._get_dsl().get_arch_enum() @staticmethod def can_implement( dtype, head_dim, head_dim_v, tile_m, tile_n, num_stages, num_threads, is_causal, Q_in_regs=False, ) -> bool: """Check if the kernel can be implemented with the given parameters. :param dtype: data type :type dtype: cutlass.Numeric :param head_dim: head dimension :type head_dim: int :param tile_m: m block size :type tile_m: int :param tile_n: n block size :type tile_n: int :param num_threads: number of threads :type num_threads: int :param is_causal: is causal :type is_causal: bool :return: True if the kernel can be implemented, False otherwise :rtype: bool """ if dtype not in [cutlass.Float16, cutlass.BFloat16]: return False if head_dim % 8 != 0: return False if head_dim_v % 8 != 0: return False if tile_n % 16 != 0: return False if num_threads % 32 != 0: return False # Check if block size setting is out of shared memory capacity # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size smem_usage_Q = tile_m * head_dim * 2 smem_usage_K = tile_n * head_dim * num_stages * 2 smem_usage_V = tile_n * head_dim_v * num_stages * 2 smem_usage_QV = ( (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) ) smem_usage = smem_usage_QV + smem_usage_K # TODO: sm86 and sm89 smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_80") if smem_usage > smem_capacity: return False # Check if twice the block size is divisible by the number of threads if (tile_m * 2) % num_threads != 0: return False return True def _check_type( self, mQ_type: Type[cutlass.Numeric], mK_type: Type[cutlass.Numeric], mV_type: Type[cutlass.Numeric], mO_type: Type[cutlass.Numeric], mLSE_type: Type[cutlass.Numeric] | None, mCuSeqlensQ_type: Type[cutlass.Numeric] | None, mCuSeqlensK_type: Type[cutlass.Numeric] | None, mSeqUsedQ_type: Type[cutlass.Numeric] | None, mSeqUsedK_type: Type[cutlass.Numeric] | None, ): # Get the data type and check if it is 
fp16 or bf16 if const_expr(not (mQ_type == mK_type == mV_type == mO_type)): raise TypeError("All tensors must have the same data type") if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") if const_expr(mLSE_type not in [None, Float32]): raise TypeError("LSE tensor must be Float32") if const_expr(mCuSeqlensQ_type not in [None, Int32]): raise TypeError("cu_seqlens_q tensor must be Int32") if const_expr(mCuSeqlensK_type not in [None, Int32]): raise TypeError("cu_seqlens_k tensor must be Int32") if const_expr(mSeqUsedQ_type not in [None, Int32]): raise TypeError("seqused_q tensor must be Int32") if const_expr(mSeqUsedK_type not in [None, Int32]): raise TypeError("seqused_k tensor must be Int32") assert mQ_type == self.dtype def _setup_attributes(self): # /////////////////////////////////////////////////////////////////////////////// # Shared memory layout: Q/K/V # /////////////////////////////////////////////////////////////////////////////// sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom = ( self._get_smem_layout_atom() ) self.sQ_layout = cute.tile_to_shape( sQ_layout_atom, (self.tile_m, self.tile_hdim), (0, 1), ) self.sK_layout = cute.tile_to_shape( sK_layout_atom, (self.tile_n, self.tile_hdim, self.num_stages), (0, 1, 2), ) self.sV_layout = cute.tile_to_shape( sV_layout_atom, (self.tile_n, self.tile_hdimv, self.num_stages), (0, 1, 2), ) self.sO_layout = cute.tile_to_shape( sO_layout_atom, (self.tile_m, self.tile_hdimv), (0, 1), ) if const_expr(sP_layout_atom is not None): self.sP_layout = cute.tile_to_shape( sP_layout_atom, (self.tile_m, self.tile_n), (0, 1), ) else: self.sP_layout = None # /////////////////////////////////////////////////////////////////////////////// # GMEM Tiled copy: # /////////////////////////////////////////////////////////////////////////////// # Thread layouts for copies universal_copy_bits = 128 async_copy_elems = universal_copy_bits // 
self.dtype.width # atom_async_copy: async copy atom for QKV load atom_async_copy = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), self.dtype, num_bits_per_copy=universal_copy_bits, ) # atom_universal_copy: universal copy atom for O store atom_universal_copy = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=universal_copy_bits, ) # tQ_layout and tK_layout: thread layout for QK load tQK_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems assert self.num_Q_load_threads % tQK_shape_dim_1 == 0, ( "num_Q_load_threads must be divisible by tQK_shape_dim_1" ) assert self.num_producer_threads % tQK_shape_dim_1 == 0, ( "num_producer_threads must be divisible by tQK_shape_dim_1" ) tQ_layout = cute.make_ordered_layout( (self.num_Q_load_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), ) tK_layout = cute.make_ordered_layout( (self.num_producer_threads // tQK_shape_dim_1, tQK_shape_dim_1), order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we load Q assert self.tile_m % tQ_layout.shape[0] == 0 tV_shape_dim_1 = sV_layout_atom.outer.shape[1] // async_copy_elems tV_layout = cute.make_ordered_layout( (self.num_producer_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), ) # TODO: need a different layout for O if O dtype is not the same as V dtype # tO_layout: thread layout for O store tO_layout = cute.make_ordered_layout( (self.num_epilogue_threads // tV_shape_dim_1, tV_shape_dim_1), order=(1, 0), ) # So that we don't have to check if we overshoot kBlockM when we store O assert self.tile_m % tO_layout.shape[0] == 0 # Value layouts for copies vQKV_layout = cute.make_layout((1, async_copy_elems)) vO_layout = vQKV_layout self.gmem_tiled_copy_Q = cute.make_tiled_copy_tv(atom_async_copy, tQ_layout, vQKV_layout) self.gmem_tiled_copy_K = cute.make_tiled_copy_tv(atom_async_copy, tK_layout, vQKV_layout) self.gmem_tiled_copy_V = cute.make_tiled_copy_tv(atom_async_copy, tV_layout,
vQKV_layout) # gmem_tiled_copy_O: tiled copy for O store self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) def _get_smem_layout_atom(self): raise NotImplementedError() def _get_tiled_mma(self): raise NotImplementedError() def _get_shared_storage_cls(self): raise NotImplementedError() @cute.jit def __call__( self, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], softmax_scale: Float32, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): """Configures and launches the flash attention kernel. mQ/mK/mV/mO have the same data type (fp16 or bf16) and the same layout: (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ raise NotImplementedError() @cute.jit def epilogue( self, acc_O: cute.Tensor, lse: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], sO: cute.Tensor, seqlen: SeqlenInfoQK, gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: Optional[cute.CopyAtom], tiled_mma: cute.TiledMma, tidx: Int32, m_block: Int32, head_idx: Int32, batch_idx: Int32, ): # store acc_O rO = cute.make_fragment_like(acc_O, self.dtype) rO.store(acc_O.load().to(self.dtype)) # Make sure all threads have finished reading V cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads ) smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype) smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) taccOrO = smem_thr_copy_O.retile(rO) taccOsO = smem_thr_copy_O.partition_D(sO) # taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO) # copy acc O from rmem to smem with the smem copy atom cute.copy(smem_copy_atom_O, taccOrO, taccOsO) cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv)) pack_gqa = PackGQA( self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead ) # Write
LSE from rmem -> gmem if const_expr(mLSE is not None): if const_expr(not seqlen.has_cu_seqlens_q): mLSE_cur = mLSE[None, head_idx, batch_idx] else: offset = seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) if const_expr(not self.pack_gqa): gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) gLSE_expanded_layout = cute.append( gLSE.layout, cute.make_layout((self.tile_hdimv,), stride=(0,)) ) gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout) thr_mma = tiled_mma.get_slice(tidx) taccOgLSE = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(gLSE_expanded)) assert cute.size(taccOgLSE, mode=[0]) == cute.size(lse) taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO)) t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO)) # Only the thread corresponding to column 0 writes out the lse to gmem if taccOcO[0][1] == 0: for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])): if ( t0accOcO[m, 0][0] < seqlen.seqlen_q - m_block * self.tile_m - taccOcO[0][0] ): taccOgLSE[m, 0] = lse[m] else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx] # thr_mma = tiled_mma.get_slice(tidx) # taccOgO = thr_mma.partition_C(gO) # cute.autovec_copy(rO, taccOgO) # sync to make sure all smem stores are done if const_expr(self.use_tma_O): # ensure smem writes are visible to TMA cute.arch.fence_view_async_shared() cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, ) gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) store_O, _, _ = copy_utils.tma_get_copy_fn( tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True ) warp_idx = 
cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 4: cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, ) store_O() cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(0, read=True) else: cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads, ) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) tOsO = gmem_thr_copy_O.partition_S(sO) tOrO = cute.make_fragment_like(tOsO, self.dtype) # load acc O from smem to rmem for wider vectorization cute.autovec_copy(tOsO, tOrO) if const_expr(not self.pack_gqa): gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) tOgO = gmem_thr_copy_O.partition_D(gO) tOcO = gmem_thr_copy_O.partition_S(cO) t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) # copy acc O from rmem to gmem for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): if ( t0OcO[0, rest_m, 0][0] < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0] ): cute.copy( gmem_tiled_copy_O, tOrO[None, rest_m, None], tOgO[None, rest_m, None], pred=tOpO[None, rest_m, None] if const_expr(self.check_hdim_v_oob) else None, ) else: pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q) @cute.jit def advance_pipeline(self, pipeline_index): return pipeline_index + 1 if pipeline_index < self.num_stages - 1 else 0 @cute.jit def load_Q( self, gmem_thr_copy: cute.TiledCopy, gQ: cute.Tensor, sQ: cute.Tensor, block: Int32, seqlen: Int32, headdim: Int32, ): tQsQ, tQgQ = gmem_thr_copy.partition_D(sQ), gmem_thr_copy.partition_S(gQ) cQ = cute.make_identity_tensor((self.tile_m, self.tile_hdim)) tQcQ = gmem_thr_copy.partition_S(cQ) t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) tQpQ = utils.predicate_k(tQcQ, limit=headdim) for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): # Instead of using tQcQ, we use t0QcQ
and subtract the offset from the limit # (seqlen - block * kBlockM). This is because the entries of t0QcQ are known at compile time. if t0QcQ[0, m, 0][0] < seqlen - block * self.tile_m - tQcQ[0][0]: cute.copy( gmem_thr_copy, tQgQ[None, m, None], tQsQ[None, m, None], pred=tQpQ[None, m, None] if const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs @cute.jit def load_K( self, gmem_tiled_copy: cute.TiledCopy, tKgK: cute.Tensor, tKsK: cute.Tensor, tKcK: cute.Tensor, t0KcK: cute.Tensor, tKpK: cute.Tensor, block: Int32, smem_pipe_write: Int32, seqlen: Int32, need_predicates: cutlass.Constexpr, ): # Do we need to check if we overshoot kBlockN when we load K? is_even_n_smem_k = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0 if const_expr(need_predicates or not is_even_n_smem_k): # Instead of using tKcK, we use t0KcK and subtract the offset from the limit # (seqlen - block * kBlockN). This is because the entries of t0KcK are known at compile time. if const_expr(is_even_n_smem_k): seqlen_limit = seqlen - block * self.tile_n else: if const_expr(not need_predicates): seqlen_limit = self.tile_n else: seqlen_limit = cutlass.min(seqlen - block * self.tile_n, self.tile_n) seqlen_limit -= tKcK[0][0] for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): if t0KcK[0, n, 0][0] < seqlen_limit: cute.copy( gmem_tiled_copy, tKgK[None, n, None, block], tKsK[ None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 ], pred=tKpK[None, n, None] if const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
else: cute.copy( gmem_tiled_copy, tKgK[None, None, None, block], tKsK[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], pred=tKpK if const_expr(self.check_hdim_oob) else None, ) @cute.jit def load_V( self, gmem_tiled_copy: cute.TiledCopy, tVgV: cute.Tensor, tVsV: cute.Tensor, tVcV: cute.Tensor, t0VcV: cute.Tensor, tVpV: cute.Tensor, block: Int32, smem_pipe_write: Int32, seqlen: Int32, need_predicates: cutlass.Constexpr, ): # Do we need to check if we overshoot kBlockN when we load V? is_even_n_smem_v = self.tile_n % gmem_tiled_copy.tiler_mn[0].shape == 0 if const_expr(need_predicates or not is_even_n_smem_v): for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])): # If kBlockN doesn't evenly divide the tiled copy, only the last `n` needs to be checked if ( is_even_n_smem_v or n < cute.size(tVsV.shape[1]) - 1 or tVcV[0, n, 0][0] < self.tile_n ): predicate = tVpV[None, n, None] if const_expr(self.check_hdim_v_oob) else None if const_expr(need_predicates): seqlen_limit = seqlen - block * self.tile_n - tVcV[0][0] predicate_n = t0VcV[0, n, 0][0] < seqlen_limit predicate = cute.make_fragment_like(tVpV[None, 0, None]) for k in cutlass.range_constexpr(cute.size(predicate.shape[1])): for i in cutlass.range_constexpr(cute.size(predicate.shape[0])): predicate[i, k] = ( tVpV[i, n, k] if const_expr(self.check_hdim_v_oob) else True ) and predicate_n cute.copy( gmem_tiled_copy, tVgV[None, n, None, block], tVsV[ None, n, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0 ], pred=predicate, ) else: cute.copy( gmem_tiled_copy, tVgV[None, None, None, block], tVsV[None, None, None, smem_pipe_write if const_expr(self.num_stages > 1) else 0], pred=tVpV if const_expr(self.check_hdim_v_oob) else None, ) class FlashAttentionForwardSm80(FlashAttentionForwardBase): def _get_smem_layout_atom(self): sQ_layout_atom = sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdim) sK_layout_atom = sQ_layout_atom sV_layout_atom = 
sm80_utils.get_smem_layout_atom(self.dtype, self.tile_hdimv) sO_layout_atom = sV_layout_atom sP_layout_atom = None return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom def _get_tiled_mma(self): tiled_mma_qk = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), (self.num_threads // 32, 1, 1), permutation_mnk=(self.num_threads // 32 * 16, 16, 16), ) tiled_mma_pv = cute.make_tiled_mma( warp.MmaF16BF16Op(self.dtype, Float32, (16, 8, 16)), (self.num_threads // 32, 1, 1), permutation_mnk=(self.num_threads // 32 * 16, 16, 16), ) return tiled_mma_qk, tiled_mma_pv def _get_shared_storage_cls(self): sQ_struct, sK_struct, sV_struct = [ cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], 1024] for layout in (self.sQ_layout, self.sK_layout, self.sV_layout) ] cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout)) sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024] @cute.struct class SharedStorageQKV: sV: sV_struct sQ: sQ_struct sK: sK_struct @cute.struct class SharedStorageSharedQV: sQ: sQV_struct sK: sK_struct return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV @cute.jit def __call__( self, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], softmax_scale: Float32, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, mPageTable: Optional[cute.Tensor] = None, window_size_left: Optional[Int32] = None, window_size_right: Optional[Int32] = None, learnable_sink: Optional[cute.Tensor] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors=None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): """Configures and launches the flash attention kernel. 
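
Before launch, `__call__` reorders tensor modes with `cute.select(layout, mode=...)`: `[1, 3, 2, 0]` for 4D inputs and `[0, 2, 1]` for 3D varlen inputs. As a minimal sketch, assuming a layout can be modeled as a plain (shape, stride) tuple pair, the hypothetical helpers `row_major_strides`, `select_modes` and `offset` below mimic that reordering. They are illustrations only, not CuTe DSL functions.

```python
from typing import Sequence, Tuple


def row_major_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    # Last mode gets stride 1 (head_dim contiguous), matching (_, _, _, 1).
    strides = []
    running = 1
    for extent in reversed(shape):
        strides.append(running)
        running *= extent
    return tuple(reversed(strides))


def select_modes(shape, stride, modes):
    # Reorder modes without moving data: same buffer, new (shape, stride) view.
    return tuple(shape[m] for m in modes), tuple(stride[m] for m in modes)


def offset(coord, stride):
    # Linear element offset of a coordinate under a stride tuple.
    return sum(c * s for c, s in zip(coord, stride))


# 4D non-varlen input: (batch, seqlen, num_head, head_dim)
shape_4d = (2, 128, 8, 64)
stride_4d = row_major_strides(shape_4d)
# Kernel view: (seqlen, head_dim, num_head, batch)
kernel_shape, kernel_stride = select_modes(shape_4d, stride_4d, [1, 3, 2, 0])

# 3D varlen input: (total_q, num_head, head_dim) -> (total_q, head_dim, num_head)
shape_3d = (300, 8, 64)
varlen_shape, varlen_stride = select_modes(shape_3d, row_major_strides(shape_3d), [0, 2, 1])
```

Only the pairing of extents and strides changes, so the permutation costs nothing at runtime; head_dim keeps stride 1 and simply moves to mode 1 of the kernel view.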
mQ/mK/mV/mO have the same data type (fp16 or bf16) and the same layout: (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1). When mCuSeqlensQ (resp. mCuSeqlensK) is given, mQ/mO (resp. mK/mV) are packed as (total, num_head, head_dim):(_, _, 1) instead. """ assert learnable_sink is None, "Learnable sink is not supported in this kernel" self._check_type( *(t.element_type if t is not None else None for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)) ) tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma() self.num_mma_threads = tiled_mma_pv.size self.num_producer_threads = self.num_threads self.num_Q_load_threads = self.num_threads self.num_epilogue_threads = self.num_threads # self.use_tma_O = self.arch >= 90 and mCuSeqlensQ is None self.use_tma_O = self.arch >= Arch.sm_90 self._setup_attributes() SharedStorage = self._get_shared_storage_cls() mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)] # Layout permutation: 4D non-varlen vs 3D varlen QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] mQ, mO = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=QO_layout_transpose)) for t in (mQ, mO) ] mK, mV = [ cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose)) for t in (mK, mV) ] if const_expr(mLSE is not None): LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) # Varlen tile scheduler when cu_seqlens_q or seqused_q is given, simple grid otherwise if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler else: TileScheduler = SingleTileScheduler num_batch = ( mCuSeqlensQ.shape[0] - 1 if const_expr(mCuSeqlensQ is not None) else mQ.shape[3] ) tile_sched_args = TileSchedulerArguments( num_block=cute.ceil_div(mQ.shape[0], self.tile_m), num_head=cute.size(mQ.shape[2]), num_batch=num_batch, num_splits=1, seqlen_k=0, headdim=mQ.shape[1], headdim_v=mV.shape[1],
total_q=cute.size(mQ.shape[0]) if const_expr(mCuSeqlensQ is not None) else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), tile_shape_mn=(self.tile_m, self.tile_n), qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, mCuSeqlensQ=mCuSeqlensQ, mSeqUsedQ=mSeqUsedQ, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod) fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors) self.kernel( mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK, softmax_scale_log2, softmax_scale, window_size_left, window_size_right, self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout, self.sP_layout, self.gmem_tiled_copy_Q, self.gmem_tiled_copy_K, self.gmem_tiled_copy_V, self.gmem_tiled_copy_O, tiled_mma_qk, tiled_mma_pv, SharedStorage, tile_sched_params, TileScheduler, aux_tensors, fastdiv_mods, ).launch( grid=grid_dim, block=[self.num_threads, 1, 1], smem=SharedStorage.size_in_bytes(), stream=stream, ) @cute.kernel def kernel( self, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], mCuSeqlensQ: Optional[cute.Tensor], mCuSeqlensK: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], mSeqUsedK: Optional[cute.Tensor], softmax_scale_log2: Float32, softmax_scale: Optional[Float32], window_size_left: Optional[Int32], window_size_right: Optional[Int32], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, sV_layout: cute.ComposedLayout, sO_layout: cute.ComposedLayout, sP_layout: cute.ComposedLayout | None, gmem_tiled_copy_Q: cute.TiledCopy, gmem_tiled_copy_K: cute.TiledCopy, gmem_tiled_copy_V: cute.TiledCopy, gmem_tiled_copy_O: cute.TiledCopy, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, SharedStorage: cutlass.Constexpr, tile_sched_params, 
TileScheduler: cutlass.Constexpr[Callable], aux_tensors=None, fastdiv_mods=None, ): # Thread index, block index tidx, _, _ = cute.arch.thread_idx() tile_scheduler = TileScheduler.create(tile_sched_params) work_tile = tile_scheduler.initial_work_tile_info() m_block, num_head, batch_size, _ = work_tile.tile_idx block_info = BlockInfo( self.tile_m, self.tile_n, self.is_causal, self.is_local, False, # is_split_kv window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) seqlen = SeqlenInfoQK.create( batch_idx=batch_size, seqlen_q_static=mQ.shape[0], seqlen_k_static=mK.shape[0], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, ) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # For varlen, wasted grid tiles (where batch_idx >= num_batch) will have # seqlen_q=seqlen_k=0 and n_block_max=0. Clamp to 0 so we don't use a # negative block index for K/V loads; the load/store predicates already # guard all memory accesses when seqlen is 0. n_block = cutlass.max(n_block_max - 1, 0) # /////////////////////////////////////////////////////////////////////////////// # Get the appropriate tiles for this thread block. 
# /////////////////////////////////////////////////////////////////////////////// blkQ_shape = (self.tile_m, self.tile_hdim) blkK_shape = (self.tile_n, self.tile_hdim) blkV_shape = (self.tile_n, self.tile_hdimv) num_head_kv = num_head // self.qhead_per_kvhead if const_expr(not seqlen.has_cu_seqlens_q): mQ_cur = mQ[None, None, num_head, batch_size] else: mQ_cur = cute.domain_offset((seqlen.offset_q, 0), mQ[None, None, num_head]) if const_expr(not seqlen.has_cu_seqlens_k): mK_cur = mK[None, None, num_head_kv, batch_size] mV_cur = mV[None, None, num_head_kv, batch_size] else: mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, num_head_kv]) mV_cur = cute.domain_offset((seqlen.offset_k, 0), mV[None, None, num_head_kv]) gQ = cute.local_tile(mQ_cur, blkQ_shape, (m_block, 0)) gK = cute.local_tile(mK_cur, blkK_shape, (None, 0)) gV = cute.local_tile(mV_cur, blkV_shape, (None, 0)) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer # /////////////////////////////////////////////////////////////////////////////// smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) sQ = storage.sQ.get_tensor(sQ_layout) sK = storage.sK.get_tensor(sK_layout) if const_expr(not self.Q_in_regs): sV = storage.sV.get_tensor(sV_layout) else: sV = cute.make_tensor(cute.recast_ptr(sQ.iterator, dtype=self.dtype), sV_layout) # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma sVt = layout_utils.transpose_view(sV) gmem_thr_copy_K = gmem_tiled_copy_K.get_slice(tidx) gmem_thr_copy_V = gmem_tiled_copy_V.get_slice(tidx) # (CPY_Atom, CPY_N, CPY_K, n_block) tKsK, tKgK = gmem_thr_copy_K.partition_D(sK), gmem_thr_copy_K.partition_S(gK) # (CPY_Atom, CPY_N, CPY_K, n_block) tVsV, tVgV = gmem_thr_copy_V.partition_D(sV), gmem_thr_copy_V.partition_S(gV) # /////////////////////////////////////////////////////////////////////////////// # Tile MMA compute thread partitions and allocate accumulators # 
/////////////////////////////////////////////////////////////////////////////// thr_mma_qk = tiled_mma_qk.get_slice(tidx) thr_mma_pv = tiled_mma_pv.get_slice(tidx) tSrQ = thr_mma_qk.make_fragment_A(thr_mma_qk.partition_A(sQ)) tSrK = thr_mma_qk.make_fragment_B(thr_mma_qk.partition_B(sK[None, None, 0])) tOrVt = thr_mma_pv.make_fragment_B(thr_mma_pv.partition_B(sVt[None, None, 0])) acc_shape_O = thr_mma_pv.partition_shape_C((self.tile_m, self.tile_hdimv)) acc_O = cute.make_fragment(acc_shape_O, Float32) acc_O.fill(0.0) # /////////////////////////////////////////////////////////////////////////////// # Smem copy atom tiling # /////////////////////////////////////////////////////////////////////////////// smem_copy_atom_QK = cute.make_copy_atom( warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), self.dtype, ) smem_copy_atom_V = cute.make_copy_atom( warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), self.dtype, ) smem_thr_copy_Q = utils.make_tiled_copy_A(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) smem_thr_copy_K = utils.make_tiled_copy_B(smem_copy_atom_QK, tiled_mma_qk).get_slice(tidx) smem_thr_copy_V = utils.make_tiled_copy_B(smem_copy_atom_V, tiled_mma_pv).get_slice(tidx) tSsQ = smem_thr_copy_Q.partition_S(sQ) tSsK = smem_thr_copy_K.partition_S(sK) tOsVt = smem_thr_copy_V.partition_S(sVt) # /////////////////////////////////////////////////////////////////////////////// # Predicate: Mark indices that need to copy when problem_shape isn't a multiple # of tile_shape # /////////////////////////////////////////////////////////////////////////////// # Construct identity layout for KV cK = cute.make_identity_tensor((self.tile_n, self.tile_hdim)) tKcK = gmem_thr_copy_K.partition_S(cK) t0KcK = gmem_thr_copy_K.get_slice(0).partition_S(cK) if const_expr(self.tile_hdim == self.tile_hdimv): tVcV = tKcK t0VcV = t0KcK else: cV = cute.make_identity_tensor((self.tile_n, self.tile_hdimv)) tVcV = gmem_thr_copy_V.partition_S(cV) t0VcV = 
gmem_thr_copy_V.get_slice(0).partition_S(cV) # Allocate predicate tensors only along k; for the m and n dimensions we # use "if" checks instead. # This reduces register pressure and gives a 2-3% performance gain. tKpK = utils.predicate_k(tKcK, limit=mK.shape[1]) if const_expr(self.same_hdim_kv): tVpV = tKpK else: tVpV = utils.predicate_k(tVcV, limit=mV.shape[1]) # shape: (atom_v_m * rest_m) softmax = Softmax.create( softmax_scale_log2, num_rows=acc_O.shape[0][0] * acc_O.shape[1], softmax_scale=softmax_scale, ) softmax.reset() # group parameters for compute_one_n_block mma_params = SimpleNamespace( thr_mma_qk=thr_mma_qk, thr_mma_pv=thr_mma_pv, tSrQ=tSrQ, tSrK=tSrK, tOrVt=tOrVt, acc_O=acc_O, ) smem_copy_params = SimpleNamespace( smem_thr_copy_Q=smem_thr_copy_Q, smem_thr_copy_K=smem_thr_copy_K, smem_thr_copy_V=smem_thr_copy_V, tSsQ=tSsQ, tSsK=tSsK, tOsVt=tOsVt, ) load_K = partial( self.load_K, gmem_tiled_copy_K, tKgK, tKsK, tKcK, t0KcK, tKpK, seqlen=seqlen.seqlen_k ) load_V = partial( self.load_V, gmem_tiled_copy_V, tVgV, tVsV, tVcV, t0VcV, tVpV, seqlen=seqlen.seqlen_k ) compute_one_n_block = partial( self.compute_one_n_block, mma_params=mma_params, smem_copy_params=smem_copy_params, softmax=softmax, load_K=load_K, load_V=load_V, score_mod=self.score_mod, batch_idx=batch_size, head_idx=num_head, m_block=m_block, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) # /////////////////////////////////////////////////////////////////////////////// # Prologue # /////////////////////////////////////////////////////////////////////////////// # Start async loads of the last mn-tile, where we take care of the mn residue gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx) self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q, headdim=mQ.shape[1]) cute.arch.cp_async_commit_group() def preprocess_Q(): cute.arch.cp_async_wait_group(self.num_stages * 2 - 1) if const_expr(self.Q_in_regs): cute.arch.barrier() tSrQ_copy_view =
smem_thr_copy_Q.retile(tSrQ) cute.copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view) # If Q_in_regs, we load Q, then load 1 stage of K, then (optionally) rotate Q and # read from smem_q to registers, then load V. # If !Q_in_regs, we load Q, then all stages of K and the first num_stages - 1 stages of V, then (optionally) rotate Q. if const_expr(self.Q_in_regs): load_K(n_block, smem_pipe_write=0, need_predicates=True) cute.arch.cp_async_commit_group() preprocess_Q() cute.arch.barrier() # Make sure all threads have read smem_q before loading V for stage in cutlass.range_constexpr(self.num_stages): if const_expr(not self.Q_in_regs or stage > 0): if stage == 0 or n_block - stage >= 0: load_K(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) cute.arch.cp_async_commit_group() if const_expr(stage < self.num_stages - 1): if stage == 0 or n_block - stage >= 0: load_V(n_block - stage, smem_pipe_write=stage, need_predicates=stage == 0) cute.arch.cp_async_commit_group() if const_expr(not self.Q_in_regs): preprocess_Q() # /////////////////////////////////////////////////////////////////////////////// # Mainloop # /////////////////////////////////////////////////////////////////////////////// # Process the first n-block. # For performance reasons, we separate the iterations into two kinds: # those that need masking on S, and those that don't. # We need masking on S for the very last block when the length of K and V is not a multiple of tile_n # (n_block runs downward, so that block is processed first). # We also need masking on S for the last several blocks if the attention is causal (or local).
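
The mainloop below walks the KV blocks from the highest index down to 0 and splits them into a seqlen-masked first block, an optional run of causal/local-masked blocks, and unmasked blocks. A minimal sketch of that visit order, assuming `n_block_min == 0` (the code carries a `# TODO: local`); `mainloop_schedule` is a hypothetical helper, and its `n_block_min_causal_local_mask` argument stands in for the value returned by `block_info.get_n_block_min_causal_local_mask`.

```python
from typing import List, Optional, Tuple


def mainloop_schedule(
    n_block_max: int, n_block_min_causal_local_mask: Optional[int] = None
) -> List[Tuple[int, bool]]:
    # Returns (n_block, mask_seqlen) pairs in visit order.
    # Hypothetical helper mirroring the loop structure; assumes n_block_min == 0.
    schedule = []
    # Clamp so empty varlen tiles (n_block_max == 0) still use block index 0.
    n_block = max(n_block_max - 1, 0)
    schedule.append((n_block, True))  # first iteration: seqlen masking
    if n_block_min_causal_local_mask is not None:  # causal / local only
        for n_tile in range(n_block_max - 1 - n_block_min_causal_local_mask):
            n_block = n_block_max - 2 - n_tile
            schedule.append((n_block, True))
    # Remaining blocks walk down to 0 without seqlen masking.
    for n_tile in range(n_block):
        schedule.append((n_block - n_tile - 1, False))
    return schedule
```

For example, `mainloop_schedule(4, 2)` yields `[(3, True), (2, True), (1, False), (0, False)]`: the causal loop leaves `n_block` at the first block that needs no causal mask, and the final loop picks up from there.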
mask = AttentionMask( self.tile_m, self.tile_n, seqlen, window_size_left, window_size_right, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) mask_fn = partial( mask.apply_mask, batch_idx=batch_size, head_idx=num_head, m_block=m_block, thr_mma=thr_mma_qk, mask_causal=self.is_causal, mask_local=self.is_local, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods if const_expr(self.mask_mod is not None) else None, ) # First iteration with seqlen masking smem_pipe_read = Int32(0) smem_pipe_write = Int32(self.num_stages - 1) compute_one_n_block( n_block, smem_pipe_read, smem_pipe_write, is_first_n_block=True, seqlen=seqlen, mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # Next couple of iterations with causal masking if const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) for n_tile in cutlass.range(n_block_max - 1 - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 2 - n_tile compute_one_n_block( n_block, smem_pipe_read, smem_pipe_write, seqlen=seqlen, mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True), ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # The remaining iterations have no masking for n_tile in cutlass.range(n_block, unroll=1): compute_one_n_block( n_block - n_tile - 1, smem_pipe_read, smem_pipe_write, seqlen=seqlen, is_first_n_block=False, mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False) ) smem_pipe_read = self.advance_pipeline(smem_pipe_read) smem_pipe_write = self.advance_pipeline(smem_pipe_write) # TODO: local # normalize acc_O by row_sum and calculate the lse row_scale = softmax.finalize() softmax.rescale_O(acc_O, row_scale) # 
/////////////////////////////////////////////////////////////////////////////// # Epilogue # /////////////////////////////////////////////////////////////////////////////// # reuse sQ's data iterator sO = cute.make_tensor(sQ.iterator, sO_layout) self.epilogue( acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, gmem_tiled_copy_O, None, tiled_mma_pv, tidx, m_block, num_head, batch_size, ) @cute.jit def compute_one_n_block( self, n_block: Int32, smem_pipe_read: Int32, smem_pipe_write: Int32, mma_params: SimpleNamespace, smem_copy_params: SimpleNamespace, softmax: Softmax, load_K: Callable, load_V: Callable, score_mod: Callable | None, batch_idx: cutlass.Int32, head_idx: cutlass.Int32, m_block: cutlass.Int32, seqlen: SeqlenInfoQK, aux_tensors=None, fastdiv_mods=None, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, check_inf: cutlass.Constexpr = True, ): """Compute one n_block of S/O. This function provides different variants for processing the first n block versus subsequent blocks. 
""" def sync(): cute.arch.cp_async_wait_group(self.num_stages * 2 - 2) cute.arch.barrier() acc_shape_S = mma_params.thr_mma_qk.partition_shape_C((self.tile_m, self.tile_n)) acc_S = cute.make_fragment(acc_shape_S, Float32) acc_S.fill(0.0) # wait for smem tile QK before mma calculation for S sync() # need predicates for the first tile def load_V_next(): if self.num_stages == 1 or n_block - self.num_stages + 1 >= 0: load_V( n_block - self.num_stages + 1, smem_pipe_write, need_predicates=is_first_n_block and self.num_stages == 1, ) cute.arch.cp_async_commit_group() load_V_next() sm80_utils.gemm( mma_params.thr_mma_qk, acc_S, mma_params.tSrQ, mma_params.tSrK, smem_copy_params.tSsQ, smem_copy_params.tSsK[ None, None, None, smem_pipe_read if const_expr(self.num_stages > 1) else 0 ], smem_copy_params.smem_thr_copy_Q, smem_copy_params.smem_thr_copy_K, # hook_fn=load_V_next, A_in_regs=self.Q_in_regs, ) if const_expr(score_mod is not None): self.apply_score_mod( mma_params.thr_mma_qk, batch_idx, head_idx, m_block, acc_S, n_block, seqlen, softmax_scale=softmax.softmax_scale, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) smem_pipe_write = self.advance_pipeline(smem_pipe_write) def load_K_next(): if n_block - self.num_stages >= 0: load_K(n_block - self.num_stages, smem_pipe_write, need_predicates=False) cute.arch.cp_async_commit_group() # wait for smem tile V for O if const_expr(self.num_stages == 1): sync() load_K_next() if const_expr(mask_fn is not None): mask_fn(acc_S, n_block=n_block) row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) softmax.rescale_O(mma_params.acc_O, row_scale) rP = cute.make_fragment_like(acc_S, self.dtype) rP.store(acc_S.load().to(self.dtype)) tOrP = layout_utils.reshape_acc_to_frgA(rP) if const_expr(self.num_stages > 1): sync() load_K_next() sm80_utils.gemm_rs( mma_params.thr_mma_pv, mma_params.acc_O, tOrP, mma_params.tOrVt, smem_copy_params.tOsVt[ None, None, None, smem_pipe_read if 
const_expr(self.num_stages > 1) else 0 ], smem_copy_params.smem_thr_copy_V, # hook_fn=load_K_next, ) # if const_expr(self.num_stages > 1): # load_K_next() # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility def __getattr__(name): if name == "FlashAttentionForwardSm90": from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 return FlashAttentionForwardSm90 raise AttributeError(f"module {__name__!r} has no attribute {name!r}") ================================================ FILE: flash_attn/cute/flash_fwd_combine.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # A reimplementation of https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_fwd_combine_kernel.h # from Cutlass C++ to Cute-DSL. import math from typing import Type, Optional from functools import partial import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync from cutlass import Float32, Int32, Boolean, const_expr from flash_attn.cute import utils from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute.seqlen_info import SeqlenInfo from cutlass.cute import FastDivmodDivisor class FlashAttentionForwardCombine: def __init__( self, dtype: Type[cutlass.Numeric], dtype_partial: Type[cutlass.Numeric], head_dim: int, tile_m: int = 8, k_block_size: int = 64, log_max_splits: int = 4, num_threads: int = 256, stages: int = 4, ): """ Forward combine kernel for split attention computation. 
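
Each split produces a partial output O_i that is normalized over its own keys, together with that split's log-sum-exp lse_i. Assuming the standard split-KV merge, the combined result is lse = log(sum_i exp(lse_i)) and O = sum_i exp(lse_i - lse) * O_i. The pure-Python check below uses hypothetical helpers `attend_chunk` and `combine_splits`; it is a sketch of the math, not this kernel's code path.

```python
import math
from typing import List, Sequence, Tuple


def attend_chunk(q, keys, values, scale) -> Tuple[List[float], float]:
    # Softmax attention of one query over one KV chunk: (normalized output, lse).
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [
        sum(w * v[d] for w, v in zip(weights, values)) / total
        for d in range(len(values[0]))
    ]
    return out, m + math.log(total)


def combine_splits(partials: Sequence[Tuple[Sequence[float], float]]) -> Tuple[List[float], float]:
    # Merge per-split (O_i, lse_i) pairs; subtracting the max lse avoids overflow.
    lse_max = max(lse_i for _, lse_i in partials)
    lse = lse_max + math.log(sum(math.exp(lse_i - lse_max) for _, lse_i in partials))
    dim = len(partials[0][0])
    out = [sum(math.exp(lse_i - lse) * o_i[d] for o_i, lse_i in partials) for d in range(dim)]
    return out, lse
```

The sketch does not handle empty splits (lse_i = -inf), which would need special handling in a real implementation.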
:param dtype: output data type :param dtype_partial: partial accumulation data type :param head_dim: head dimension :param tile_m: m block size :param k_block_size: k block size :param log_max_splits: log2 of maximum splits :param num_threads: number of threads :param stages: number of pipeline stages """ self.dtype = dtype self.dtype_partial = dtype_partial self.head_dim = head_dim self.tile_m = tile_m self.k_block_size = k_block_size self.max_splits = 1 << log_max_splits self.num_threads = num_threads self.is_even_k = head_dim % k_block_size == 0 self.stages = stages @staticmethod def can_implement( dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, num_threads, ) -> bool: """Check if the kernel can be implemented with the given parameters.""" if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]: return False if dtype_partial not in [cutlass.Float16, cutlass.BFloat16, Float32]: return False if head_dim % 8 != 0: return False if num_threads % 32 != 0: return False if tile_m % 8 != 0: return False max_splits = 1 << log_max_splits if max_splits > 256: return False if (tile_m * max_splits) % num_threads != 0: return False return True def _setup_attributes(self): # GMEM copy setup for O partial universal_copy_bits = 128 async_copy_elems = universal_copy_bits // self.dtype_partial.width assert self.k_block_size % async_copy_elems == 0 k_block_gmem = ( 128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32) ) gmem_threads_per_row = k_block_gmem // async_copy_elems assert self.num_threads % gmem_threads_per_row == 0 # Async copy atom for O partial load atom_async_copy_partial = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), self.dtype_partial, num_bits_per_copy=universal_copy_bits, ) tOpartial_layout = cute.make_ordered_layout( (self.num_threads // gmem_threads_per_row, gmem_threads_per_row), order=(1, 0), )
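
The thread layout for the O-partial loads follows from a few integer divisions in `_setup_attributes`. The hypothetical helper below repeats that arithmetic outside the DSL (it is not part of this module) to show how the numbers change with the partial dtype: a 128-bit copy moves 4 fp32 values or 8 fp16/bf16 values per thread.

```python
from typing import Tuple


def o_partial_copy_config(
    dtype_partial_bits: int, k_block_size: int, num_threads: int
) -> Tuple[int, int, int]:
    # Returns (elements per copy, threads per row, rows per pass).
    universal_copy_bits = 128
    async_copy_elems = universal_copy_bits // dtype_partial_bits
    assert k_block_size % async_copy_elems == 0
    # Widest gmem row chunk (128, 64 or 32 elements) that divides k_block_size.
    if k_block_size % 128 == 0:
        k_block_gmem = 128
    elif k_block_size % 64 == 0:
        k_block_gmem = 64
    else:
        k_block_gmem = 32
    gmem_threads_per_row = k_block_gmem // async_copy_elems
    assert num_threads % gmem_threads_per_row == 0
    return async_copy_elems, gmem_threads_per_row, num_threads // gmem_threads_per_row
```

With the constructor defaults (k_block_size=64, num_threads=256) and fp32 partials this gives (4, 16, 16), i.e. `tOpartial_layout` is a 16 x 16 thread grid; fp16 partials would give (8, 8, 32).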
vOpartial_layout = cute.make_layout((1, async_copy_elems)) # async_copy_elems vals per load (4 for fp32 partials) self.gmem_tiled_copy_O_partial = cute.make_tiled_copy_tv( atom_async_copy_partial, tOpartial_layout, vOpartial_layout ) # GMEM copy setup for final O (use universal copy for store) atom_universal_copy = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), self.dtype, num_bits_per_copy=async_copy_elems * self.dtype.width, ) self.gmem_tiled_copy_O = cute.make_tiled_copy_tv( atom_universal_copy, tOpartial_layout, vOpartial_layout, # same number of vals per store as per load ) # LSE copy setup with async copy (alignment = 1) lse_copy_bits = Float32.width # 1 element per copy, width is in bits m_block_smem = ( 128 if self.tile_m % 128 == 0 else ( 64 if self.tile_m % 64 == 0 else (32 if self.tile_m % 32 == 0 else (16 if self.tile_m % 16 == 0 else 8)) ) ) gmem_threads_per_row_lse = m_block_smem assert self.num_threads % gmem_threads_per_row_lse == 0 # Async copy atom for LSE load atom_async_copy_lse = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), Float32, num_bits_per_copy=lse_copy_bits, ) tLSE_layout = cute.make_ordered_layout( (self.num_threads // gmem_threads_per_row_lse, gmem_threads_per_row_lse), order=(1, 0), ) vLSE_layout = cute.make_layout(1) self.gmem_tiled_copy_LSE = cute.make_tiled_copy_tv( atom_async_copy_lse, tLSE_layout, vLSE_layout ) # /////////////////////////////////////////////////////////////////////////////// # Shared memory # /////////////////////////////////////////////////////////////////////////////// # Shared memory to register copy for LSE self.smem_threads_per_col_lse = self.num_threads // m_block_smem assert 32 % self.smem_threads_per_col_lse == 0 # Must divide warp size s2r_layout_atom_lse = cute.make_ordered_layout( (self.smem_threads_per_col_lse, self.num_threads // self.smem_threads_per_col_lse), order=(0, 1), ) self.s2r_tiled_copy_LSE = cute.make_tiled_copy_tv( cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), Float32), s2r_layout_atom_lse,
cute.make_layout(1), ) # LSE shared memory layout with swizzling to avoid bank conflicts # This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts if const_expr(m_block_smem == 8): smem_lse_swizzle = cute.make_swizzle(5, 0, 5) elif const_expr(m_block_smem == 16): smem_lse_swizzle = cute.make_swizzle(4, 0, 4) else: smem_lse_swizzle = cute.make_swizzle(3, 2, 3) smem_layout_atom_lse = cute.make_composed_layout( smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0)) ) self.smem_layout_lse = cute.tile_to_shape( smem_layout_atom_lse, (self.max_splits, self.tile_m), (0, 1) ) # O partial shared memory layout (simple layout for pipeline stages) self.smem_layout_o = cute.make_ordered_layout( (self.tile_m, self.k_block_size, self.stages), order=(1, 0, 2) ) @cute.jit def __call__( self, mO_partial: cute.Tensor, mLSE_partial: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor] = None, cu_seqlens: Optional[cute.Tensor] = None, seqused: Optional[cute.Tensor] = None, num_splits_dynamic_ptr: Optional[cute.Tensor] = None, varlen_batch_idx: Optional[cute.Tensor] = None, semaphore_to_reset: Optional[cute.Tensor] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). 
stream: cuda.CUstream = None, ): # Type checking if const_expr(not (mO_partial.element_type == self.dtype_partial)): raise TypeError("O partial tensor must match dtype_partial") if const_expr(not (mO.element_type == self.dtype)): raise TypeError("O tensor must match dtype") if const_expr(mLSE_partial.element_type not in [Float32]): raise TypeError("LSE partial tensor must be Float32") if const_expr(mLSE is not None and mLSE.element_type not in [Float32]): raise TypeError("LSE tensor must be Float32") # Shape validation - input tensors are in user format, need to be converted to kernel format if const_expr(len(mO_partial.shape) not in [4, 5]): raise ValueError( "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)" ) if const_expr(len(mLSE_partial.shape) not in [3, 4]): raise ValueError( "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)" ) if const_expr(len(mO.shape) not in [3, 4]): raise ValueError( "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)" ) if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]): raise ValueError( "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)" ) mO_partial, mO = [assume_tensor_aligned(t) for t in (mO_partial, mO)] # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b) # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h) O_partial_layout_transpose = ( [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2] ) # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h) mO_partial = cute.make_tensor( mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose) ) O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1] mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, 
mode=O_layout_transpose)) # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b) # or (num_splits, total_q, h) -> (total_q, num_splits, h) LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2] mLSE_partial = cute.make_tensor( mLSE_partial.iterator, cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose), ) # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h) LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1] mLSE = ( cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None ) # Determine if we have variable length sequences varlen = const_expr(cu_seqlens is not None or seqused is not None) self._setup_attributes() @cute.struct class SharedStorage: sLSE: cute.struct.Align[ cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128 ] sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.tile_m], 128] sO: cute.struct.Align[ cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128 ] smem_size = SharedStorage.size_in_bytes() # Grid dimensions: (ceil_div(seqlen * num_head, tile_m), ceil_div(head_dim, k_block_size), batch) seqlen = mO_partial.shape[0] num_head = mO_partial.shape[3] batch_size = ( mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1) ) # Create FastDivmodDivisor objects for efficient division seqlen_divmod = FastDivmodDivisor(seqlen) head_divmod = FastDivmodDivisor(num_head) grid_dim = ( cute.ceil_div(seqlen * num_head, self.tile_m), cute.ceil_div(self.head_dim, self.k_block_size), batch_size, ) self.kernel( mO_partial, mLSE_partial, mO, mLSE, cu_seqlens, seqused, num_splits_dynamic_ptr, varlen_batch_idx, semaphore_to_reset, SharedStorage, self.smem_layout_lse, self.smem_layout_o, self.gmem_tiled_copy_O_partial, self.gmem_tiled_copy_O, self.gmem_tiled_copy_LSE, self.s2r_tiled_copy_LSE, seqlen_divmod, head_divmod, varlen,
).launch( grid=grid_dim, block=[self.num_threads, 1, 1], smem=smem_size, stream=stream, ) @cute.kernel def kernel( self, mO_partial: cute.Tensor, mLSE_partial: cute.Tensor, mO: cute.Tensor, mLSE: Optional[cute.Tensor], cu_seqlens: Optional[cute.Tensor], seqused: Optional[cute.Tensor], num_splits_dynamic_ptr: Optional[cute.Tensor], varlen_batch_idx: Optional[cute.Tensor], semaphore_to_reset: Optional[cute.Tensor], SharedStorage: cutlass.Constexpr, smem_layout_lse: cute.Layout | cute.ComposedLayout, smem_layout_o: cute.Layout, gmem_tiled_copy_O_partial: cute.TiledCopy, gmem_tiled_copy_O: cute.TiledCopy, gmem_tiled_copy_LSE: cute.TiledCopy, s2r_tiled_copy_LSE: cute.TiledCopy, seqlen_divmod: FastDivmodDivisor, head_divmod: FastDivmodDivisor, varlen: cutlass.Constexpr[bool], ): # Thread and block indices tidx, _, _ = cute.arch.thread_idx() m_block, k_block, maybe_virtual_batch = cute.arch.block_idx() # Map virtual batch index to real batch index (for persistent tile schedulers) batch_idx = ( varlen_batch_idx[maybe_virtual_batch] if const_expr(varlen_batch_idx is not None) else maybe_virtual_batch ) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer # /////////////////////////////////////////////////////////////////////////////// smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) sLSE = storage.sLSE.get_tensor(smem_layout_lse) sMaxValidSplit = storage.sMaxValidSplit.get_tensor((self.tile_m,)) sO = storage.sO.get_tensor(smem_layout_o) # Handle semaphore reset — wait for dependent grids first if const_expr(semaphore_to_reset is not None): if ( tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and k_block == cute.arch.grid_dim()[1] - 1 and maybe_virtual_batch == cute.arch.grid_dim()[2] - 1 ): cute.arch.griddepcontrol_wait() semaphore_to_reset[0] = 0 # Get number of splits (use maybe_virtual_batch for per-batch-slot splits) num_splits = ( num_splits_dynamic_ptr[maybe_virtual_batch] if 
const_expr(num_splits_dynamic_ptr is not None) else mLSE_partial.shape[1] ) # Handle variable length sequences using SeqlenInfo seqlen_info = SeqlenInfo.create( batch_idx=batch_idx, seqlen_static=mO_partial.shape[0], cu_seqlens=cu_seqlens, seqused=seqused, # Don't need to pass in tile size since we won't use offset_padded ) seqlen, offset = seqlen_info.seqlen, seqlen_info.offset # Extract number of heads (head index will be determined dynamically) num_head = mO_partial.shape[3] max_idx = seqlen * num_head # Early exit for single split if dynamic if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and ( const_expr(not varlen) or m_block * self.tile_m < max_idx ): # Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial) cute.arch.griddepcontrol_wait() # =============================== # Step 1: Load LSE_partial from gmem to shared memory # =============================== mLSE_partial_cur = seqlen_info.offset_batch(mLSE_partial, batch_idx, dim=3) mLSE_partial_copy = cute.tiled_divide(mLSE_partial_cur, (1,)) gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_slice(tidx) tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE) # Create identity tensor for coordinate tracking cLSE = cute.make_identity_tensor((self.max_splits, self.tile_m)) tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE) # Load LSE partial values for m in cutlass.range(cute.size(tLSEcLSE, mode=[2]), unroll_full=True): mi = tLSEcLSE[0, 0, m][1] # Get m coordinate idx = m_block * self.tile_m + mi if idx < max_idx: # Calculate actual sequence position and head using FastDivmodDivisor if const_expr(not varlen): head_idx, m_idx = divmod(idx, seqlen_divmod) else: head_idx = idx // seqlen m_idx = idx - head_idx * seqlen mLSE_partial_cur_copy = mLSE_partial_copy[None, m_idx, None, head_idx] for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True): si = tLSEcLSE[0, s, 0][0] # Get split coordinate if si < num_splits: cute.copy( gmem_thr_copy_LSE, 
mLSE_partial_cur_copy[None, si], tLSEsLSE[None, s, m], ) else: tLSEsLSE[None, s, m].fill(-Float32.inf) # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem cute.arch.cp_async_commit_group() # =============================== # Step 2: Load O_partial for pipeline stages # =============================== gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_slice(tidx) cO = cute.make_identity_tensor((self.tile_m, self.k_block_size)) tOcO = gmem_thr_copy_O_partial.partition_D(cO) tOsO_partial = gmem_thr_copy_O_partial.partition_D(sO) mO_partial_cur = seqlen_info.offset_batch(mO_partial, batch_idx, dim=4) # Precompute these values to avoid recomputing them in the loop num_rows = const_expr(cute.size(tOcO, mode=[1])) tOmidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) tOhidx = cute.make_rmem_tensor(num_rows, cutlass.Int32) tOrOptr = cute.make_rmem_tensor(num_rows, cutlass.Int64) for m in cutlass.range(num_rows, unroll_full=True): mi = tOcO[0, m, 0][0] # m coordinate idx = m_block * self.tile_m + mi if const_expr(not varlen): tOhidx[m], tOmidx[m] = divmod(idx, seqlen_divmod) else: tOhidx[m] = idx // seqlen tOmidx[m] = idx - tOhidx[m] * seqlen tOrOptr[m] = utils.elem_pointer( mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m]) ).toint() if idx >= max_idx: tOhidx[m] = -1 tOpO = None if const_expr(not self.is_even_k): tOpO = cute.make_rmem_tensor(cute.size(tOcO, mode=[2]), Boolean) for k in cutlass.range(cute.size(tOpO), unroll_full=True): tOpO[k] = tOcO[0, 0, k][1] < mO_partial.shape[1] - k_block * self.k_block_size # if cute.arch.thread_idx()[0] == 0 and k_block == 1: cute.print_tensor(tOpO) load_O_partial = partial( self.load_O_partial, gmem_tiled_copy_O_partial, tOrOptr, tOsO_partial, tOhidx, tOpO, tOcO, mO_partial_cur.layout, ) # Load first few stages of O_partial for stage in cutlass.range(self.stages - 1, unroll_full=True): if stage < num_splits: load_O_partial(stage, stage) cute.arch.cp_async_commit_group() # 
=============================== # Step 3: Load and transpose LSE from smem to registers # =============================== # Wait for LSE and initial O partial stages to complete cute.arch.cp_async_wait_group(self.stages - 1) cute.arch.sync_threads() # if cute.arch.thread_idx()[0] == 0: # # cute.print_tensor(sLSE) # for i in range(64): # cute.printf("sLSE[%d, 0] = %f", i, sLSE[i, 0]) # cute.arch.sync_threads() s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_slice(tidx) ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE) ts2rrLSE = cute.make_rmem_tensor_like(ts2rsLSE) cute.copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE) # =============================== # Step 4: Compute final LSE along split dimension # =============================== lse_sum = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Float32) ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE) # We compute the max valid split for each row to short-circuit the computation later max_valid_split = cute.make_rmem_tensor(cute.size(ts2rrLSE, mode=[2]), Int32) assert cute.size(ts2rrLSE, mode=[0]) == 1 # Compute max, scales, and final LSE for each row for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): # Find max LSE value across splits threads_per_col = const_expr(self.smem_threads_per_col_lse) lse_max = cute.arch.warp_reduction_max( ts2rrLSE[None, None, m] .load() .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), threads_in_group=threads_per_col, ) # if cute.arch.thread_idx()[0] == 0: cute.printf(lse_max) # Find max valid split index max_valid_idx = -1 for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): if ts2rrLSE[0, s, m] != -Float32.inf: max_valid_idx = ts2rcLSE[0, s, 0][0] # Get split coordinate # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx) max_valid_split[m] = cute.arch.warp_reduction_max( max_valid_idx, threads_in_group=threads_per_col ) # Compute exp scales and sum lse_max_cur = ( 0.0 if lse_max == -Float32.inf else lse_max ) # In case all 
local LSEs are -inf LOG2_E = math.log2(math.e) lse_sum_cur = 0.0 for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True): scale = cute.math.exp2( ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True ) lse_sum_cur += scale ts2rrLSE[0, s, m] = scale # Store scale for later use lse_sum_cur = cute.arch.warp_reduction_sum( lse_sum_cur, threads_in_group=threads_per_col ) lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max # Normalize scales inv_sum = ( 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur ) ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum) # Store the scales exp(lse - lse_logsum) back to smem cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE) # Store max valid split to smem for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes mi = ts2rcLSE[0, 0, m][1] if mi < self.tile_m: sMaxValidSplit[mi] = max_valid_split[m] # =============================== # Step 5: Store final LSE to gmem # =============================== if const_expr(mLSE is not None): if const_expr(cu_seqlens is None): mLSE_cur = mLSE[None, None, batch_idx] else: mLSE_cur = cute.domain_offset((offset, 0), mLSE) if k_block == 0: # Only first k_block writes LSE when mLSE is provided for m in cutlass.range(cute.size(ts2rrLSE, mode=[2]), unroll_full=True): if ts2rcLSE[0, 0, m][0] == 0: # Only thread responsible for s=0 writes mi = ts2rcLSE[0, 0, m][1] idx = m_block * self.tile_m + mi if idx < max_idx: if const_expr(not varlen): head_idx, m_idx = divmod(idx, seqlen_divmod) else: head_idx = idx // seqlen m_idx = idx - head_idx * seqlen mLSE_cur[m_idx, head_idx] = lse_sum[m] # =============================== # Step 6: Read O_partial and accumulate final O # =============================== cute.arch.sync_threads() # Get max valid split for this thread thr_max_valid_split = sMaxValidSplit[tOcO[0, 0, 0][0]] for m in 
cutlass.range(1, cute.size(tOcO, mode=[1]), unroll_full=True): thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[tOcO[0, m, 0][0]]) tOrO_partial = cute.make_rmem_tensor_like(tOsO_partial[None, None, None, 0]) tOrO = cute.make_rmem_tensor_like(tOrO_partial, Float32) tOrO.fill(0.0) stage_load = self.stages - 1 stage_compute = 0 # Main accumulation loop for s in cutlass.range(thr_max_valid_split + 1, unroll=4): # Get scales for this split scale = cute.make_rmem_tensor(num_rows, Float32) for m in cutlass.range(num_rows, unroll_full=True): scale[m] = sLSE[s, tOcO[0, m, 0][0]] # Get scale from smem # Load next stage if needed split_to_load = s + self.stages - 1 if split_to_load <= thr_max_valid_split: load_O_partial(split_to_load, stage_load) cute.arch.cp_async_commit_group() stage_load = 0 if stage_load == self.stages - 1 else stage_load + 1 # Wait for the current stage to be ready cute.arch.cp_async_wait_group(self.stages - 1) # We don't need __syncthreads() because each thread is just reading its own data from smem # Copy from smem to registers cute.autovec_copy(tOsO_partial[None, None, None, stage_compute], tOrO_partial) stage_compute = 0 if stage_compute == self.stages - 1 else stage_compute + 1 # Accumulate scaled partial results for m in cutlass.range(num_rows, unroll_full=True): if tOhidx[m] >= 0 and scale[m] > 0.0: tOrO[None, m, None].store( tOrO[None, m, None].load() + scale[m] * tOrO_partial[None, m, None].load().to(Float32) ) # =============================== # Step 7: Write final O to gmem # =============================== rO = cute.make_rmem_tensor_like(tOrO, self.dtype) rO.store(tOrO.load().to(self.dtype)) mO_cur = seqlen_info.offset_batch(mO, batch_idx, dim=3) if const_expr(cu_seqlens is None): mO_cur = mO[None, None, None, batch_idx] else: mO_cur = cute.domain_offset((offset, 0, 0), mO) mO_cur = utils.domain_offset_aligned((0, k_block * self.k_block_size, 0), mO_cur) elems_per_store = 
const_expr(cute.size(gmem_tiled_copy_O.layout_tv_tiled[1])) # mO_cur_copy = cute.tiled_divide(mO_cur, (1, elems_per_store,)) gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) # Write final results for m in cutlass.range(num_rows, unroll_full=True): if tOhidx[m] >= 0: mO_cur_copy = cute.tiled_divide( mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,) ) for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): k_idx = tOcO[0, 0, k][1] // elems_per_store if const_expr(self.is_even_k) or tOpO[k]: cute.copy(gmem_thr_copy_O, rO[None, m, k], mO_cur_copy[None, k_idx]) @cute.jit def load_O_partial( self, gmem_tiled_copy_O_partial: cute.TiledCopy, tOrOptr: cute.Tensor, tOsO_partial: cute.Tensor, tOhidx: cute.Tensor, tOpO: Optional[cute.Tensor], tOcO: cute.Tensor, mO_cur_partial_layout: cute.Layout, split: Int32, stage: Int32, ) -> None: elems_per_load = const_expr(cute.size(gmem_tiled_copy_O_partial.layout_tv_tiled[1])) tOsO_partial_cur = tOsO_partial[None, None, None, stage] for m in cutlass.range(cute.size(tOcO, [1]), unroll_full=True): if tOhidx[m] >= 0: o_gmem_ptr = cute.make_ptr( tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16 ) mO_partial_cur = cute.make_tensor( o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)) ) mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,)) for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True): k_idx = tOcO[0, 0, k][1] // elems_per_load if const_expr(tOpO is None) or tOpO[k]: cute.copy( gmem_tiled_copy_O_partial, mO_partial_cur_copy[None, k_idx, split], tOsO_partial_cur[None, m, k], ) ================================================ FILE: flash_attn/cute/flash_fwd_sm100.py ================================================ # Supported features: # - BF16 & FP16 dtype # - noncausal & causal attention # - MHA, GQA, MQA # - hdim 64, 96, 128, (192, 128). 
# - varlen # - sliding window # - split-kv # Unsupported features that will be added later: # - page size != 128 # - more hdim (192, 256) # Based on the cutlass example and cute-dsl example: # https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py import math from typing import Type, Tuple, Callable, Optional, Literal from functools import partial import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute from cutlass import Float32, Int32, Int64, Boolean, const_expr from cutlass.cute.nvgpu import cpasync import cutlass.cute.nvgpu.tcgen05 as tcgen05 import cutlass.utils.blackwell_helpers as sm100_utils_basic from cutlass import pipeline from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import BaseDSL from quack import copy_utils, layout_utils from flash_attn.cute.paged_kv import PagedKVManager from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils import flash_attn.cute.pipeline as pipeline_custom from flash_attn.cute.mask import AttentionMask from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.block_sparsity import BlockSparseTensors from flash_attn.cute.block_sparse_utils import ( get_total_block_count, produce_block_sparse_loads_sm100, softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler 
import ( TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ) class FlashAttentionForwardSm100: def __init__( self, # dtype: Type[cutlass.Numeric], head_dim: int, head_dim_v: Optional[int] = None, qhead_per_kvhead: cutlass.Constexpr[int] = 1, is_causal: bool = False, is_local: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, q_subtile_factor: int | None = None, m_block_size: int = 128, n_block_size: int = 128, q_stage: cutlass.Constexpr[int] = 2, is_persistent: bool = True, score_mod: cutlass.Constexpr | None = None, mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, paged_kv_non_tma: bool = False, is_varlen_q: bool = False, use_2cta_instrs: bool = False, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of) head_dim_v = head_dim_v if head_dim_v is not None else head_dim self.same_hdim_kv = head_dim == head_dim_v self.head_dim_v_padded = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of) self.same_hdim_kv_padded = self.head_dim_padded == self.head_dim_v_padded self.check_hdim_oob = head_dim != self.head_dim_padded self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.m_block_size = m_block_size self.n_block_size = n_block_size self.q_stage = q_stage assert self.q_stage in [1, 2] self.use_2cta_instrs = use_2cta_instrs # If split_P_arrive, the softmax warps write some columns of P first, signal to the MMA warp # to begin the P @ V MMA, then write the rest of P and signal again. This allows some overlap # between computing the last couple of columns of P and the P @ V MMA.
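# Illustrative arithmetic for the split point computed below (assuming the two n_block_size values shown):
#   n_block_size = 128: 128 // 4 * 3 = 96, already a multiple of 32 -> split_P_arrive = 96
#   n_block_size = 64:  64 // 4 * 3 = 48, rounded down to a multiple of 32 -> split_P_arrive = 32
# In both cases split_P_arrive < n_block_size, so the MMA warp is signalled once the first
# split_P_arrive columns of P are written, before the remaining columns are done.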
self.split_P_arrive = n_block_size // 4 * 3 self.split_P_arrive = int(self.split_P_arrive / 32) * 32 # multiple of 32 assert self.split_P_arrive % 32 == 0 assert self.split_P_arrive < self.n_block_size self.arch = BaseDSL._get_dsl().get_arch_enum() assert self.arch >= Arch.sm_100 and self.arch <= Arch.sm_110f, "Only SM 10.x and 11.x are supported" self.cta_group_size = 2 if self.use_2cta_instrs else 1 # cta_tiler M includes only 1 CTA, the scheduler will take into account the cluster shape self.cta_tiler = (self.q_stage * m_block_size, n_block_size, self.head_dim_padded) # With 2CTA, the MMA tiler M covers both CTAs, so it's cta_group_size * m_block_size. # Each CTA owns m_block_size rows; the 2CTA MMA instruction spans both. self.mma_tiler_qk = (self.cta_group_size * m_block_size, n_block_size, self.head_dim_padded) self.mma_tiler_pv = (self.cta_group_size * m_block_size, self.head_dim_v_padded, n_block_size) self.qk_acc_dtype = Float32 self.pv_acc_dtype = Float32 self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1) self.is_persistent = is_persistent self.is_causal = is_causal self.is_local = is_local self.is_varlen_q = is_varlen_q self.use_correction_warps_for_epi = is_varlen_q self.qhead_per_kvhead = qhead_per_kvhead self.is_split_kv = is_split_kv self.pack_gqa = pack_gqa self.q_subtile_factor = q_subtile_factor if pack_gqa: assert m_block_size % self.qhead_per_kvhead == 0, ( "For PackGQA, m_block_size must be divisible by qhead_per_kvhead" ) assert not (self.is_split_kv and self.head_dim_v_padded >= 192), ( "SplitKV is not supported for hdim >= 192" ) self.score_mod = score_mod self.mask_mod = mask_mod self.vec_size: cutlass.Constexpr = getattr( score_mod, "__vec_size__", 1 if cutlass.const_expr(has_aux_tensors) else 2 ) # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f # self.enable_ex2_emu 
= self.head_dim_padded <= 128 and not is_sm103 self.enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103 self.s0_s1_barrier = False self.overlap_sO_sQ = ( (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or (self.head_dim_v_padded >= 128 and self.is_split_kv) ) if self.overlap_sO_sQ: self.is_persistent = False assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), ( "Paged KV does not support irregular head dim" ) self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) self.mma_warp_id = 12 self.epilogue_warp_ids = (13,) self.load_warp_ids = (14,) self.empty_warp_ids = (15,) self.tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") self.threads_per_cta = cute.arch.WARP_SIZE * len( ( *self.softmax0_warp_ids, *self.softmax1_warp_ids, *self.correction_warp_ids, self.mma_warp_id, *self.load_warp_ids, *self.epilogue_warp_ids, *self.empty_warp_ids, ) ) if self.q_stage == 1: if not self.use_tma_KV: self.empty_warp_ids = self.empty_warp_ids + self.load_warp_ids self.load_warp_ids = self.softmax1_warp_ids else: self.empty_warp_ids = self.empty_warp_ids + self.softmax1_warp_ids self.softmax1_warp_ids = () elif not self.use_tma_KV: self.load_warp_ids = (14, 15) self.empty_warp_ids = () if self.use_correction_warps_for_epi: self.empty_warp_ids = self.empty_warp_ids + self.epilogue_warp_ids self.epilogue_warp_ids = self.correction_warp_ids elif self.is_varlen_q: # fallback self.epilogue_warp_ids = (13, 14) self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded for i in range(self.q_stage) ] # e.g., 256, 384 self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded assert self.tmem_total <= self.tmem_alloc_cols self.tmem_s_to_p_offset = self.n_block_size 
// 2 self.tmem_p_offset = [ self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2) ] # e.g., 64, 192 # vec buffer for row_max & row_sum self.tmem_vec_offset = self.tmem_s_offset if self.head_dim_padded < 96: self.num_regs_softmax = 200 if not paged_kv_non_tma else 184 self.num_regs_correction = 64 self.num_regs_other = 48 if not paged_kv_non_tma else 80 else: # self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 if not self.enable_ex2_emu: self.num_regs_softmax = 192 if not paged_kv_non_tma else 184 else: # self.num_regs_softmax = 200 if not paged_kv_non_tma else 184 self.num_regs_softmax = 192 if not paged_kv_non_tma else 184 # self.num_regs_softmax = 176 # self.num_regs_correction = 96 # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 if not self.enable_ex2_emu: self.num_regs_correction = 80 if not paged_kv_non_tma else 64 else: # self.num_regs_correction = 64 self.num_regs_correction = 80 if not paged_kv_non_tma else 64 # self.num_regs_other = 32 # self.num_regs_other = 64 # self.num_regs_other = 80 self.num_regs_other = 48 if not paged_kv_non_tma else 80 # self.num_regs_other = 96 if self.is_causal or self.is_local else 80 # self.num_regs_other = 64 if self.is_causal or self.is_local else 80 self.buffer_align_bytes = 1024 def _setup_attributes(self): """Set up configurations and parameters for the FMHA kernel operation.
This method initializes and configures various attributes required for the execution of the fused multi-head attention kernel, mainly those for the pipeline stages: - Sets up staging parameters for Q, K, V inputs and accumulator data - Configures pipeline stages for softmax, correction, and epilogue operations """ smem_size_q = self.q_stage * self.m_block_size * self.head_dim_padded * self.q_dtype.width // 8 smem_size_o = self.q_stage * self.m_block_size * self.head_dim_v_padded * self.o_dtype.width // 8 smem_size_q_o = smem_size_q + smem_size_o if not self.overlap_sO_sQ else max(smem_size_q, smem_size_o) smem_size_k_per_stage = self.n_block_size * self.head_dim_padded * self.k_dtype.width // 8 smem_size_v_per_stage = self.n_block_size * self.head_dim_v_padded * self.v_dtype.width // 8 smem_size_kv_per_stage = max(smem_size_k_per_stage, smem_size_v_per_stage) // self.cta_group_size kv_stage = (224 * 1024 - smem_size_q_o) // smem_size_kv_per_stage if self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and kv_stage == 2: # For hdim 192,128, we can fit 3 stages if we use uneven_kv_smem kv_stage = 3 self.kv_stage = kv_stage # print("kv_stage", self.kv_stage) self.s_stage = 2 assert self.s_stage >= self.q_stage # For hdim 192,128 1CTA, we don't have enough smem to store all 3 stages of KV: # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q. # Instead we store smem as [smem_large, smem_small, smem_large], where smem_large is # 128 x 192 and smem_small is 128 x 128. We set the stride between the stages to be # 128 * 160, so that indexing the 0th and 2nd stages will get the right address, # but for the 1st stage we need to add or subtract (depending on phase) 128 x 32, i.e. half of the 128 x 64 size difference (uneven_kv_smem_offset).
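# Illustrative smem budget for hdim (192, 128) with 1CTA, assuming the defaults
# m_block_size = n_block_size = 128, q_stage = 2, and a 16-bit dtype:
#   Q (overlapped with O, since overlap_sO_sQ): 2 * 128 * 192 * 2 B = 96 KB
#   K stage (large):                           128 * 192 * 2 B     = 48 KB
#   V stage (small):                           128 * 128 * 2 B     = 32 KB
#   kv_stage = (224 KB - 96 KB) // 48 KB = 2, bumped to 3 above
#   uneven layout: 48 + 32 + 48 = 128 KB, and 96 + 128 = 224 KB fits the budget exactly
#   uneven_kv_smem_offset = 128 * (192 - 128) // 2 = 4096 elements (8 KB), and 4096 % 1024 == 0
# With 2CTA the per-stage KV size is halved (24 KB), giving kv_stage = 5, so the uneven layout is not used.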
self.uneven_kv_smem = ( self.head_dim_padded == 192 and self.head_dim_v_padded == 128 and self.kv_stage == 3 ) self.uneven_kv_smem_offset = ( self.n_block_size * (self.head_dim_padded - self.head_dim_v_padded) // 2 if self.uneven_kv_smem else 0 ) assert self.uneven_kv_smem_offset % 1024 == 0 @cute.jit def __call__( self, mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q mLSE: Optional[cute.Tensor], softmax_scale: Float32, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, mPageTable: Optional[cute.Tensor] = None, # (b_k, max_num_pages_per_seq) window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). stream: cuda.CUstream = None, ): """Execute the Fused Multi-Head Attention operation on the provided tensors. This method prepares the input tensors for processing, validates their shapes and types, configures the computation parameters, and launches the CUDA kernel. The method handles: 1. Tensor layout transformations for specific memory access patterns 2. Validation of tensor shapes and data types 3. Initialization of hardware-specific parameters and memory layouts 4. Configuration of TMA (Tensor Memory Accelerator) operations 5. Grid and work scheduling computation 6.
        Kernel launch with appropriate parameters
        """
        # setup static attributes before smem/grid/tma computation
        self.q_dtype = mQ.element_type
        self.k_dtype = mK.element_type
        self.v_dtype = mV.element_type
        self.o_dtype = mO.element_type
        mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
        Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
        mQ = cute.make_tensor(mQ.iterator, cute.select(mQ.layout, mode=Q_layout_transpose))
        # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there's cu_seqlens_k or (page_size, d, h_k, num_pages) if there's page_table
        KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
        mK, mV = [
            cute.make_tensor(t.iterator, cute.select(t.layout, mode=KV_layout_transpose))
            for t in (mK, mV)
        ]
        if const_expr(self.is_split_kv):
            O_layout_transpose = [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0]
            LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0]
            num_splits = mO.shape[0]
        else:
            O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
            LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
            num_splits = Int32(1)
        mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
        mLSE = (
            cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
            if const_expr(mLSE is not None)
            else None
        )
        # (s, d, h, b) -> (d, s, h, b)
        V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2]
        mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose))

        # check type consistency
        if const_expr(self.q_dtype != self.k_dtype):
            raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}")
        if const_expr(self.q_dtype != self.v_dtype):
            raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}")
        self._setup_attributes()
        self.use_tma_O = self.arch >= Arch.sm_90 and mCuSeqlensQ is None and mSeqUsedQ is None
        # This can be tuned
        # This is currently very ad-hoc, we should tune it systematically
        self.ex2_emu_freq = 0
        # self.ex2_emu_start_frg = 1 if self.is_causal else 0
        self.ex2_emu_start_frg = 1
        if const_expr(self.enable_ex2_emu):
            self.ex2_emu_freq = 16
            if const_expr(self.head_dim_padded == 128 and self.use_2cta_instrs):
                self.ex2_emu_freq = 12
            if const_expr(
                self.pack_gqa and self.head_dim_padded > 64 and not self.is_causal and not self.is_local
            ):
                self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10
            if const_expr(self.head_dim_padded > 64 and self.is_causal):
                self.ex2_emu_freq = 10

        cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
        q_major_mode = tcgen05.OperandMajorMode.K
        k_major_mode = tcgen05.OperandMajorMode.K
        v_major_mode = tcgen05.OperandMajorMode.MN
        self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO)
        # the intermediate tensor P is read from tmem and is K-major
        p_source = tcgen05.OperandSource.TMEM
        p_major_mode = tcgen05.OperandMajorMode.K
        tiled_mma_qk = sm100_utils_basic.make_trivial_tiled_mma(
            self.q_dtype,
            q_major_mode,
            k_major_mode,
            self.qk_acc_dtype,
            cta_group,
            self.mma_tiler_qk[:2],
        )
        tiled_mma_pv = sm100_utils_basic.make_trivial_tiled_mma(
            self.v_dtype,
            p_major_mode,
            v_major_mode,
            self.pv_acc_dtype,
            cta_group,
            self.mma_tiler_pv[:2],
            p_source,
        )

        self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)
        cta_layout_vmnk = cute.tiled_divide(
            cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)
        )

        # epi_tile is per-CTA (not full 2CTA) since each CTA writes its own O portion
        self.epi_tile = (self.m_block_size, self.head_dim_v_padded)

        sQ_layout = sm100_utils_basic.make_smem_layout_a(
            tiled_mma_qk, self.mma_tiler_qk, self.q_dtype, self.q_stage
        )
        sK_layout = sm100_utils_basic.make_smem_layout_b(
            tiled_mma_qk, self.mma_tiler_qk, self.k_dtype, self.kv_stage
        )
        tP_layout = sm100_utils_basic.make_smem_layout_a(
            tiled_mma_pv, self.mma_tiler_pv, self.q_dtype, self.s_stage
        )
        sV_layout = sm100_utils_basic.make_smem_layout_b(
            tiled_mma_pv, self.mma_tiler_pv, self.v_dtype, self.kv_stage
        )
        sO_layout = sm100_utils_basic.make_smem_layout_epi(
            self.o_dtype, self.o_layout, self.epi_tile, self.q_stage
        )
        if const_expr(not self.same_hdim_kv_padded):
            # sK and sV are using the same physical smem so we need to adjust the stride so that they line up
            stride_sK = const_expr(
                max(sK_layout.outer.stride[-1], 0)
            )  # take max to turn tuple to Int32
            stride_sV = const_expr(max(sV_layout.outer.stride[-1], 0))
            stage_stride = const_expr(
                max(stride_sK, stride_sV)
                if not self.uneven_kv_smem
                else (stride_sK + stride_sV) // 2
            )
            sK_layout = cute.make_composed_layout(
                sK_layout.inner,
                0,
                cute.make_layout(
                    (*sK_layout.outer.shape[:-1], self.kv_stage),
                    stride=(*sK_layout.outer.stride[:-1], stage_stride),
                ),
            )
            sV_layout = cute.make_composed_layout(
                sV_layout.inner,
                0,
                cute.make_layout(
                    (*sV_layout.outer.shape[:-1], self.kv_stage),
                    stride=(*sV_layout.outer.stride[:-1], stage_stride),
                ),
            )

        if const_expr(self.pack_gqa):
            nheads_kv = mK.shape[2]
            mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2)
            mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2)
            if const_expr(mLSE is not None):
                mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1)

        self.tma_copy_bytes = {
            name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
            for name, mX, layout in [
                ("Q", mQ, sQ_layout),
                ("K", mK, sK_layout),
                ("V", mV, sV_layout),
            ]
        }
        for name in ("Q", "K", "V"):
            self.tma_copy_bytes[name] *= self.cta_group_size

        # TMA load for Q
        tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
        tma_store_op = cpasync.CopyBulkTensorTileS2GOp()

        tma_atom_Q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
            tma_load_op,
            mQ,
            cute.select(sQ_layout, mode=[0, 1, 2]),
            self.mma_tiler_qk,
            tiled_mma_qk,
            cta_layout_vmnk.shape,
        )

        tma_atom_K = None
        tma_atom_V = None
        if const_expr(self.use_tma_KV):
            # TMA load for K
            tma_atom_K, mK = cute.nvgpu.make_tiled_tma_atom_B(
                tma_load_op,
                mK,
                cute.select(sK_layout, mode=[0, 1, 2]),
                self.mma_tiler_qk,
                tiled_mma_qk,
                cta_layout_vmnk.shape,
            )
            # TMA load for V
            tma_atom_V, mV = cute.nvgpu.make_tiled_tma_atom_B(
                tma_load_op,
                mV,
                cute.select(sV_layout, mode=[0, 1, 2]),
                self.mma_tiler_pv,
                tiled_mma_pv,
                cta_layout_vmnk.shape,
            )

        self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)
        if const_expr(self.use_tma_O):
            tma_atom_O, mO = cpasync.make_tiled_tma_atom(
                tma_store_op, mO, cute.select(sO_layout, mode=[0, 1]), self.epi_tile
            )
            gmem_tiled_copy_O = None
        else:
            tma_atom_O = None
            universal_copy_bits = 128
            async_copy_elems = universal_copy_bits // self.o_dtype.width
            atom_universal_copy = cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(),
                self.o_dtype,
                num_bits_per_copy=universal_copy_bits,
            )
            tO_shape_dim_1 = sO_layout.outer.shape[1][0] // async_copy_elems
            tO_layout = cute.make_ordered_layout(
                (self.num_epilogue_threads // tO_shape_dim_1, tO_shape_dim_1),
                order=(1, 0),
            )
            # So that we don't have to check if we overshoot kBlockM when we store O
            assert self.m_block_size % tO_layout.shape[0] == 0
            vO_layout = cute.make_layout((1, async_copy_elems))
            gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)

        if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
            TileScheduler = SingleTileVarlenScheduler
        else:
            if const_expr(self.is_causal or self.is_local):
                TileScheduler = SingleTileLPTScheduler
            else:
                TileScheduler = (
                    SingleTileScheduler
                    if const_expr(not self.is_persistent)
                    else StaticPersistentTileScheduler
                )
        tile_sched_args = TileSchedulerArguments(
            cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]),
            cute.size(mQ.shape[2]),
            cute.size(mQ.shape[3])
            if const_expr(mCuSeqlensQ is None)
            else cute.size(mCuSeqlensQ.shape[0] - 1),
            num_splits,
            cute.size(mK.shape[0])
            if const_expr(mPageTable is None)
            else mK.shape[0] * mPageTable.shape[1],
            mQ.shape[1],
            mV.shape[0],  # Note that this is different from Sm90 since we transpose mV in Sm100
            total_q=cute.size(mQ.shape[0])
            if const_expr(mCuSeqlensQ is not None)
            else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
            tile_shape_mn=self.cta_tiler[:2],
            mCuSeqlensQ=mCuSeqlensQ,
            mSeqUsedQ=mSeqUsedQ,
            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
            element_size=self.k_dtype.width // 8,
            is_persistent=self.is_persistent,
            lpt=self.is_causal or self.is_local,
            is_split_kv=self.is_split_kv,
            cluster_shape_mn=self.cluster_shape_mn,
        )
        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
        self.tile_scheduler_cls = TileScheduler
        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

        sO_size = cute.cosize(sO_layout) if const_expr(not self.overlap_sO_sQ) else 0
        sQ_size = (
            cute.cosize(sQ_layout)
            if const_expr(not self.overlap_sO_sQ)
            else cutlass.max(
                cute.cosize(sQ_layout),
                cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width,
            )
        )

        @cute.struct
        class SharedStorage:
            # m_barriers for pipelines
            mbar_load_Q: cute.struct.MemRange[Int64, self.q_stage * 2]
            mbar_load_KV: cute.struct.MemRange[Int64, self.kv_stage * 2]
            mbar_S_full_P_full_O_rescaled: cute.struct.MemRange[Int64, self.q_stage * 2]
            mbar_P_full_lastsplit: cute.struct.MemRange[Int64, self.q_stage * 2]
            mbar_O_full: cute.struct.MemRange[Int64, self.q_stage * 2]
            mbar_softmax_stats: cute.struct.MemRange[Int64, self.q_stage * 2]
            # mbar_softmax_stats: cute.struct.MemRange[Int64, self.q_stage * 4 * 2]
            mbar_O_epi: cute.struct.MemRange[Int64, self.q_stage * 2]
            mbar_s0_s1_sequence: cute.struct.MemRange[Int64, 2 * 2]
            # Tmem dealloc cluster barrier
            tmem_dealloc_mbar_ptr: Int64
            # Tmem holding buffer
            tmem_holding_buf: Int32
            # Smem tensors
            # store row max and row sum
            sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2]
            sO: cute.struct.Align[
                cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes
            ]
            sQ: cute.struct.Align[
                cute.struct.MemRange[self.q_dtype, sQ_size], self.buffer_align_bytes
            ]
            sK: cute.struct.Align[
                # cute.cosize(sK_layout) is correct even in the case of self.uneven_kv_smem
                cute.struct.MemRange[self.k_dtype, cute.cosize(sK_layout)],
                self.buffer_align_bytes,
            ]

        self.shared_storage = SharedStorage

        softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod)
        window_size_left = Int32(window_size_left) if window_size_left is not None else None
        window_size_right = Int32(window_size_right) if window_size_right is not None else None
        fastdiv_mods = utils.compute_fastdiv_mods(mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable)

        head_divmod = None
        if cutlass.const_expr(self.pack_gqa):
            head_divmod = FastDivmodDivisor(self.qhead_per_kvhead)

        self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
        if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None):
            raise NotImplementedError("Block sparsity + paged KV not supported on SM100")

        # Launch the kernel on the given stream (the launch is asynchronous with respect to the host)
        self.kernel(
            mQ,
            mK,
            mV,
            mO,
            mLSE,
            mCuSeqlensQ,
            mCuSeqlensK,
            mSeqUsedQ,
            mSeqUsedK,
            mPageTable,
            tma_atom_Q,
            tma_atom_K,
            tma_atom_V,
            tma_atom_O,
            softmax_scale_log2,
            softmax_scale,
            window_size_left,
            window_size_right,
            learnable_sink,
            blocksparse_tensors,
            sQ_layout,
            sK_layout,
            tP_layout,
            sV_layout,
            sO_layout,
            gmem_tiled_copy_O,
            tiled_mma_qk,
            tiled_mma_pv,
            tile_sched_params,
            num_splits,
            aux_tensors,
            fastdiv_mods,
            head_divmod,
        ).launch(
            grid=grid_dim,
            block=[self.threads_per_cta, 1, 1],
            cluster=self.cluster_shape_mnk if cute.size(self.cluster_shape_mnk) > 1 else None,
            stream=stream,
            min_blocks_per_mp=1,
        )

    # GPU device kernel
    @cute.kernel
    def kernel(
        self,
        mQ: cute.Tensor,  # (s_q, d, h, b) or (total_q, d, h) if there is cu_seqlens_q
        mK: cute.Tensor,  # (s_k, d, h_k, b_k) or (total_k, d, h_k) if there is cu_seqlens_k or (page_size, d, h_k, num_pages) if there is page_table
        mV: cute.Tensor,  # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        mCuSeqlensQ: Optional[cute.Tensor],
        mCuSeqlensK: Optional[cute.Tensor],
        mSeqUsedQ: Optional[cute.Tensor],
        mSeqUsedK: Optional[cute.Tensor],
        mPageTable: Optional[cute.Tensor],
        tma_atom_Q: cute.CopyAtom,
        tma_atom_K: Optional[cute.CopyAtom],
        tma_atom_V: Optional[cute.CopyAtom],
        tma_atom_O: Optional[cute.CopyAtom],
        softmax_scale_log2: Float32,
        softmax_scale: Float32 | None,
        window_size_left: Optional[Int32],
        window_size_right: Optional[Int32],
        learnable_sink: Optional[cute.Tensor],
        blocksparse_tensors: Optional[BlockSparseTensors],
        sQ_layout: cute.ComposedLayout,
        sK_layout: cute.ComposedLayout,
        tP_layout: cute.ComposedLayout,
        sV_layout: cute.ComposedLayout,
        sO_layout: cute.ComposedLayout,
        gmem_tiled_copy_O: Optional[cute.TiledCopy],
        tiled_mma_qk: cute.TiledMma,
        tiled_mma_pv: cute.TiledMma,
        tile_sched_params: ParamsBase,
        num_splits: Int32,
        aux_tensors: Optional[list] = None,
        fastdiv_mods=(None, None),
        head_divmod=None,
    ):
        """The device kernel implementation of the Fused Multi-Head Attention.

        This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation:
        1. Load warp(s): Load Q, K, V from global memory to shared memory using TMA
           (when use_tma_KV is off, K and V are loaded through the paged-KV manager instead)
        2. MMA warp: Performs matrix multiplications (Q*K^T and P*V)
        3. Softmax warps: Compute softmax normalization on attention scores
        4. Correction warps: Rescale the O accumulator as the running softmax statistics
           (row max and row sum) change, and finalize O
        5. Epilogue warps: Copy the final O tile from shared memory to global memory
           (unless the correction warps handle the epilogue themselves)

        The kernel implements a complex pipeline with overlapping computation and memory operations,
        using the Tensor Memory Accelerator (TMA) for efficient data loading, warp specialization for
        different computation phases, and optional attention masking.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())

        # Prefetch tma descriptors
        if warp_idx == 0:
            for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O):
                if const_expr(tma_atom is not None):
                    cpasync.prefetch_descriptor(tma_atom)

        cta_layout_vmnk = cute.tiled_divide(
            cute.make_layout(self.cluster_shape_mnk), (tiled_mma_qk.thr_id.shape,)
        )
        # Setup cta/thread coordinates
        bidx, _, _ = cute.arch.block_idx()
        if const_expr(cute.size(tiled_mma_qk.thr_id.shape) == 1):
            mma_tile_coord_v = 0
        else:
            mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape)
        is_leader_cta = mma_tile_coord_v == 0

        # Alloc
        smem = cutlass.utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)

        tmem_alloc_barrier = pipeline.NamedBarrier(
            barrier_id=int(NamedBarrierFwdSm100.TmemPtr),
            num_threads=cute.arch.WARP_SIZE
            * len(
                (
                    self.mma_warp_id,
                    *self.softmax0_warp_ids,
                    *self.softmax1_warp_ids,
                    *self.correction_warp_ids,
                )
            ),
        )
        # Tensor memory dealloc barrier init
        tmem = cutlass.utils.TmemAllocator(
            storage.tmem_holding_buf,
            barrier_for_retrieve=tmem_alloc_barrier,
            allocator_warp_id=self.mma_warp_id,
            is_two_cta=self.use_2cta_instrs,
            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
        )

        ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread)
        mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id]))
        load_warps = ThreadCooperativeGroup(len(self.load_warp_ids))
        tma_warp = ThreadCooperativeGroup(1)
        softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids))
        softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids))
        # softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE)
        correction_threads = ThreadCooperativeGroup(
            cute.arch.WARP_SIZE * len(self.correction_warp_ids)
        )
        # correction_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE)
        softmax_correction_threads = ThreadCooperativeGroup(
            cute.arch.WARP_SIZE * len(self.softmax0_warp_ids + self.correction_warp_ids)
        )
        epilogue_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.epilogue_warp_ids))
        # For UMMA-bridging pipelines: the non-MMA side spans both CTAs in the cluster,
        # so the thread count must include warps from both CTAs.
        softmax_warps_cluster = ThreadCooperativeGroup(
            len(self.softmax0_warp_ids) * self.cta_group_size
        )
        correction_threads_cluster = ThreadCooperativeGroup(
            cute.arch.WARP_SIZE * len(self.correction_warp_ids) * self.cta_group_size
        )
        softmax_correction_threads_cluster = ThreadCooperativeGroup(
            cute.arch.WARP_SIZE
            * len(self.softmax0_warp_ids + self.correction_warp_ids)
            * self.cta_group_size
        )
        pipeline_q = pipeline_custom.PipelineTmaUmma.create(
            barrier_storage=storage.mbar_load_Q.data_ptr(),
            num_stages=self.q_stage,
            producer_group=tma_warp,
            consumer_group=mma_warp,
            tx_count=self.tma_copy_bytes["Q"],
            cta_layout_vmnk=cta_layout_vmnk,
            defer_sync=True,
        )
        if const_expr(self.use_tma_KV):
            pipeline_kv = pipeline_custom.PipelineTmaUmma.create(
                barrier_storage=storage.mbar_load_KV.data_ptr(),
                num_stages=self.kv_stage,
                producer_group=tma_warp,
                consumer_group=mma_warp,
                tx_count=self.tma_copy_bytes["K"],
                cta_layout_vmnk=cta_layout_vmnk,
                defer_sync=True,
            )
        else:
            cpasync_producer_group = pipeline.CooperativeGroup(
                pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE
            )
            pipeline_kv = pipeline.PipelineAsyncUmma.create(
                barrier_storage=storage.mbar_load_KV.data_ptr(),
                num_stages=self.kv_stage,
                producer_group=cpasync_producer_group,
                consumer_group=mma_warp,
                cta_layout_vmnk=cta_layout_vmnk,
                defer_sync=True,
            )
        # This pipeline is not the typical producer-consumer pipeline. The "producer" mma warp
        # uses it to signal that S is ready, and the softmax threads wait for S to be ready.
        # When softmax threads write P to tmem and the correction threads have rescaled O, they
        # signal as "consumer". The mma warp then waits for that signal to do the P @ V gemm.
        pipeline_s_p_o = pipeline_custom.PipelineUmmaAsync.create(
            barrier_storage=storage.mbar_S_full_P_full_O_rescaled.data_ptr(),
            num_stages=self.q_stage,
            producer_group=mma_warp,
            consumer_group=softmax_correction_threads_cluster,
            cta_layout_vmnk=cta_layout_vmnk,
            defer_sync=True,
        )
        pipeline_p_lastsplit = pipeline_custom.PipelineAsyncUmma.create(
            barrier_storage=storage.mbar_P_full_lastsplit.data_ptr(),
            num_stages=self.q_stage,
            producer_group=softmax_warps_cluster,
            consumer_group=mma_warp,
            cta_layout_vmnk=cta_layout_vmnk,
            defer_sync=True,
        )
        # MMA warp uses this to signal to the correction warps that O is ready.
        pipeline_o_acc = pipeline_custom.PipelineUmmaAsync.create(
            barrier_storage=storage.mbar_O_full.data_ptr(),
            num_stages=self.q_stage,
            producer_group=mma_warp,
            consumer_group=correction_threads_cluster,
            cta_layout_vmnk=cta_layout_vmnk,
            defer_sync=True,
        )
        pipeline_s0_s1_sequence = None
        if const_expr(self.s0_s1_barrier and self.q_stage > 1):
            # This is not a typical producer-consumer pipeline. We will directly use
            # pipeline_s0_s1_sequence.sync_object_full and will not use
            # pipeline_s0_s1_sequence.sync_object_empty.
            pipeline_s0_s1_sequence = pipeline_custom.PipelineAsync.create(
                barrier_storage=storage.mbar_s0_s1_sequence.data_ptr(),
                num_stages=2,
                producer_group=softmax_threads,
                consumer_group=softmax_threads,
                defer_sync=True,
            )
        pipeline_sm_stats = pipeline_custom.PipelineAsync.create(
            barrier_storage=storage.mbar_softmax_stats.data_ptr(),
            num_stages=self.q_stage,
            producer_group=softmax_threads,
            consumer_group=correction_threads,
            defer_sync=True,
        )
        # Should put the NamedBarrier inside the pipeline class so we'll just have pipeline_sm_stats
        sm_stats_barrier = pipeline_custom.NamedBarrier(
            barrier_id=int(NamedBarrierFwdSm100.SoftmaxStatsW0), num_threads=cute.arch.WARP_SIZE * 2
        )
        pipeline_o_epi = None
        if const_expr(not self.use_correction_warps_for_epi):
            pipeline_o_epi = pipeline_custom.PipelineAsync.create(
                barrier_storage=storage.mbar_O_epi.data_ptr(),
                num_stages=self.q_stage,
                producer_group=correction_threads,
                consumer_group=epilogue_threads,
                defer_sync=True,
            )

        # Cluster arrive after barrier init
        pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True)

        # Generate smem tensor Q/K/V/O
        # (MMA, MMA_Q, MMA_D, PIPE)
        sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner)
        # (MMA, MMA_K, MMA_D, PIPE)
        sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
        # (MMA, MMA_K, MMA_D, PIPE)
        # Strip swizzle info to reuse smem
        sV = cute.make_tensor(cute.recast_ptr(sK.iterator, sV_layout.inner), sV_layout.outer)
        if const_expr(not self.overlap_sO_sQ):
            sO = storage.sO.get_tensor(sO_layout.outer, swizzle=sO_layout.inner)
        else:
            sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner, self.o_dtype), sO_layout.outer)

        sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2))

        thr_mma_qk = tiled_mma_qk.get_slice(mma_tile_coord_v)
        thr_mma_pv = tiled_mma_pv.get_slice(mma_tile_coord_v)

        qk_acc_shape = thr_mma_qk.partition_shape_C(self.mma_tiler_qk[:2])
        # This is a fake tensor: strictly speaking we would need to retrieve tmem_ptr, but we always
        # request 512 columns of tmem, so we know that it starts at 0.
        tStS = thr_mma_qk.make_fragment_C(cute.append(qk_acc_shape, self.s_stage))

        pv_acc_shape = thr_mma_pv.partition_shape_C(self.mma_tiler_pv[:2])
        tOtO = thr_mma_pv.make_fragment_C(cute.append(pv_acc_shape, self.q_stage))
        tOtO = cute.make_tensor(tOtO.iterator + self.tmem_o_offset[0], tOtO.layout)
        tP = cute.make_tensor(tStS.iterator, tP_layout.outer)
        tOrP = thr_mma_pv.make_fragment_A(tP)[None, None, None, 0]
        # Need to multiply by the width ratio because tP is in v_dtype but tmem offsets are in FP32
        tP_width_ratio = Float32.width // self.v_dtype.width
        # Need to adjust the stage stride manually since the two stages aren't contiguous in tmem
        tP_stage_stride = (self.tmem_p_offset[1] - self.tmem_p_offset[0]) * tP_width_ratio
        tOrP = cute.make_tensor(
            tOrP.iterator + self.tmem_p_offset[0] * tP_width_ratio,
            cute.append(tOrP.layout, cute.make_layout((self.s_stage,), stride=(tP_stage_stride,))),
        )

        block_info = BlockInfo(
            # This is cta_tiler, not mma_tiler_qk, since we advance in blocks of (2 * mma_tiler[0], mma_tiler[1])
            self.cta_tiler[0],
            self.cta_tiler[1],
            self.is_causal,
            self.is_local,
            self.is_split_kv,
            window_size_left,
            window_size_right,
            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
        )
        SeqlenInfoCls = partial(
            SeqlenInfoQK.create,
            seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1],
            seqlen_k_static=mK.shape[0]
            if const_expr(mPageTable is None)
            else mK.shape[0] * mPageTable.shape[1],
            mCuSeqlensQ=mCuSeqlensQ,
            mCuSeqlensK=mCuSeqlensK,
            mSeqUsedQ=mSeqUsedQ,
            mSeqUsedK=mSeqUsedK,
        )
        AttentionMaskCls = partial(
            AttentionMask,
            self.m_block_size,
            self.n_block_size,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
        )
        TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)

        # Cluster wait before tensor memory alloc
        pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk)

        # ///////////////////////////////////////////////////////////////////////////////
        #  EMPTY
        # ///////////////////////////////////////////////////////////////////////////////
        for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
            if warp_idx == self.empty_warp_ids[i]:
                cute.arch.setmaxregister_decrease(self.num_regs_other)

        # ///////////////////////////////////////////////////////////////////////////////
        #  LOAD
        # ///////////////////////////////////////////////////////////////////////////////
        if warp_idx >= self.load_warp_ids[0] and warp_idx <= self.load_warp_ids[-1]:
            cute.arch.setmaxregister_decrease(self.num_regs_other)
            self.load(
                thr_mma_qk,
                thr_mma_pv,
                mQ,
                mK,
                mV,
                sQ,
                sK,
                sV,
                mPageTable,
                tma_atom_Q,
                tma_atom_K,
                tma_atom_V,
                pipeline_q,
                pipeline_kv,
                block_info,
                num_splits,
                SeqlenInfoCls,
                TileSchedulerCls,
                blocksparse_tensors,
            )

        # ///////////////////////////////////////////////////////////////////////////////
        #  MMA
        # ///////////////////////////////////////////////////////////////////////////////
        if warp_idx == self.mma_warp_id:
            cute.arch.setmaxregister_decrease(self.num_regs_other)
            # Alloc tensor memory buffer
            tmem.allocate(cute.arch.get_max_tmem_alloc_cols("sm_100"))
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
            self.mma(
                tiled_mma_qk,
                tiled_mma_pv,
                sQ,
                sK,
                sV,
                tStS,
                tOtO,
                tOrP,
                pipeline_q,
                pipeline_kv,
                pipeline_s_p_o,
                pipeline_p_lastsplit,
                pipeline_o_acc,
                is_leader_cta,
                block_info,
                num_splits,
                SeqlenInfoCls,
                TileSchedulerCls,
                blocksparse_tensors,
            )
            # Dealloc the tensor memory buffer
            tmem.relinquish_alloc_permit()
            tmem_alloc_barrier.arrive_and_wait()
            tmem.free(tmem_ptr)

        # ///////////////////////////////////////////////////////////////////////////////
        #  Epilogue
        # ///////////////////////////////////////////////////////////////////////////////
        if const_expr(not self.use_correction_warps_for_epi):
            if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:
                cute.arch.setmaxregister_decrease(self.num_regs_other)
                self.epilogue_s2g(
                    mO,
                    sO,
                    gmem_tiled_copy_O,
                    tma_atom_O,
                    pipeline_o_epi,
                    block_info,
                    num_splits,
                    SeqlenInfoCls,
                    TileSchedulerCls,
                    mma_tile_coord_v,
                )

        # ///////////////////////////////////////////////////////////////////////////////
        #  Softmax
        # ///////////////////////////////////////////////////////////////////////////////
        if (
            (const_expr(self.q_stage == 2) and warp_idx <= self.softmax1_warp_ids[-1])
            or (const_expr(self.q_stage == 1) and warp_idx <= self.softmax0_warp_ids[-1])
        ):
            # Increase the register budget (the other warp roles decreased theirs)
            cute.arch.setmaxregister_increase(self.num_regs_softmax)
            # sync with mma warp before retrieving tmem ptr
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
            softmax_loop = partial(
                self.softmax_loop,
                softmax_scale_log2=softmax_scale_log2,
                softmax_scale=softmax_scale,
                thr_mma_qk=thr_mma_qk,
                sScale=sScale,
                mLSE=mLSE,
                pipeline_s_p_o=pipeline_s_p_o,
                pipeline_p_lastsplit=pipeline_p_lastsplit,
                pipeline_sm_stats=pipeline_sm_stats,
                sm_stats_barrier=sm_stats_barrier,
                pipeline_s0_s1_sequence=pipeline_s0_s1_sequence,
                learnable_sink=learnable_sink,
                block_info=block_info,
                num_splits=num_splits,
                SeqlenInfoCls=SeqlenInfoCls,
                AttentionMaskCls=AttentionMaskCls,
                TileSchedulerCls=TileSchedulerCls,
                aux_tensors=aux_tensors,
                fastdiv_mods=fastdiv_mods,
                head_divmod=head_divmod,
                blocksparse_tensors=blocksparse_tensors,
            )

            if const_expr(not self.s0_s1_barrier):
                stage = Int32(0 if const_expr(self.q_stage == 1) or warp_idx < self.softmax1_warp_ids[0] else 1)
                softmax_loop(stage=stage, tStS=tStS)
            else:
                # If there's s0_s1_barrier, it's faster to have 2 WGs having different code
                if warp_idx < self.softmax1_warp_ids[0]:
                    softmax_loop(stage=0, tStS=tStS)
                if warp_idx < self.correction_warp_ids[0] and warp_idx >= self.softmax1_warp_ids[0]:
                    softmax_loop(stage=1, tStS=tStS)

            tmem_alloc_barrier.arrive()

        # ///////////////////////////////////////////////////////////////////////////////
        #  Correction
        # ///////////////////////////////////////////////////////////////////////////////
        if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id:
            cute.arch.setmaxregister_decrease(self.num_regs_correction)
            # sync with mma warp before retrieving tmem ptr
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
            self.correction_loop(
                thr_mma_qk,
                thr_mma_pv,
                tStS,
                tOtO,
                sScale,
                mO,
                mLSE,
                sO,
                pipeline_s_p_o,
                pipeline_o_acc,
                pipeline_sm_stats,
                sm_stats_barrier,
                pipeline_o_epi,
                learnable_sink,
                gmem_tiled_copy_O,
                tma_atom_O,
                softmax_scale_log2,
                block_info,
                num_splits,
                SeqlenInfoCls,
                TileSchedulerCls,
                blocksparse_tensors,
            )
            tmem_alloc_barrier.arrive()

        return

    @cute.jit
    def load(
        self,
        thr_mma_qk: cute.core.ThrMma,
        thr_mma_pv: cute.core.ThrMma,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        sQ: cute.Tensor,
        sK: cute.Tensor,
        sV: cute.Tensor,
        mPageTable: Optional[cute.Tensor],
        tma_atom_Q: cute.CopyAtom,
        tma_atom_K: Optional[cute.CopyAtom],
        tma_atom_V: Optional[cute.CopyAtom],
        pipeline_q: pipeline.PipelineAsync,
        pipeline_kv: pipeline.PipelineAsync,
        block_info: BlockInfo,
        num_splits: Int32,
        SeqlenInfoCls: Callable,
        TileSchedulerCls: Callable,
        blocksparse_tensors: Optional[BlockSparseTensors],
    ):
        num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE
        tidx = cute.arch.thread_idx()[0] % num_load_threads
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        q_producer_phase = Int32(1)
        kv_producer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Producer, self.kv_stage
        )
        tile_scheduler = TileSchedulerCls()
        work_tile = tile_scheduler.initial_work_tile_info()
        while work_tile.is_valid_tile:
            m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)
            mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
            tiler_gQ = ((self.mma_tiler_qk[0] * self.q_stage), self.head_dim_padded)
            gQ = cute.local_tile(mQ_cur, tiler_gQ, (m_block, 0))  # (128 * 2, 128)
            gQ = layout_utils.select(
                cute.flat_divide(gQ, (self.mma_tiler_qk[0],)), mode=[0, 2, 1]
            )  # (128, 128, 2)

            head_idx_kv = (
                head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx
            )
            if const_expr(mPageTable is None):
                if const_expr(not seqlen.has_cu_seqlens_k):
                    mK_cur, mV_cur = [t[None, None, head_idx_kv, batch_idx] for t in (mK, mV)]
                else:
                    mK_cur = cute.domain_offset((seqlen.offset_k, 0), mK[None, None, head_idx_kv])
                    mV_cur = cute.domain_offset((0, seqlen.offset_k), mV[None, None, head_idx_kv])
                gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0))
                gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None))
            else:
                # Need to keep batch coord None since we'll index into it with page idx
                mK_cur, mV_cur = [t[None, None, head_idx_kv, None] for t in (mK, mV)]
                gK = cute.local_tile(
                    mK_cur, cute.select(self.mma_tiler_qk, mode=[1, 2]), (None, 0, None)
                )
                gV = cute.local_tile(
                    mV_cur, cute.select(self.mma_tiler_pv, mode=[1, 2]), (0, None, None)
                )
            tSgQ = thr_mma_qk.partition_A(gQ)
            tSgK = thr_mma_qk.partition_B(gK)
            tOgV = thr_mma_pv.partition_B(gV)
            load_Q_fn, _, _ = copy_utils.tma_get_copy_fn(
                tma_atom_Q, 0, cute.make_layout(1), tSgQ, sQ
            )

            if const_expr(self.use_tma_KV):
                tKsK, tKgK = cpasync.tma_partition(
                    tma_atom_K,
                    0,  # no multicast
                    cute.make_layout(1),
                    cute.group_modes(sK, 0, 3),
                    cute.group_modes(tSgK, 0, 3),
                )
                tVsV, tVgV = cpasync.tma_partition(
                    tma_atom_V,
                    0,  # no multicast
                    cute.make_layout(1),
                    cute.group_modes(sV, 0, 3),
                    cute.group_modes(tOgV, 0, 3),
                )
                paged_kv_manager = None
            else:
                page_size = mK.shape[0]
                paged_kv_manager = PagedKVManager.create(
                    mPageTable,
                    mK,
                    mV,
                    FastDivmodDivisor(page_size),
                    batch_idx,
                    head_idx_kv,
                    tidx,
                    seqlen.seqlen_k,
                    0,  # leftpad_k
                    self.n_block_size,
                    self.head_dim_padded,
                    self.head_dim_v_padded,
                    num_load_threads,
                    mK.element_type,
                )
                tKsK, tKgK = None, None
                tVsV, tVgV = None, None

            load_Q = partial(self.load_Q, load_Q_fn, pipeline_q=pipeline_q, phase=q_producer_phase)
            load_K = partial(
                self.load_KV,
                tma_atom_K,
                tKgK,
                tKsK,
                paged_kv_manager,
                sK,
                pipeline_kv=pipeline_kv,
                K_or_V="K",
            )
            load_V = partial(
                self.load_KV,
                tma_atom_V,
                tVgV,
                tVsV,
                paged_kv_manager,
                sV,
                pipeline_kv=pipeline_kv,
                K_or_V="V",
            )

            if const_expr(not self.use_block_sparsity):
                n_block_min, n_block_max = block_info.get_n_block_min_max(
                    seqlen, m_block, split_idx, num_splits
                )
                if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
                    n_block_first = n_block_max - 1 if n_block_max > 0 else 0
                    page_idx = (
                        mPageTable[batch_idx, n_block_first]
                        if const_expr(mPageTable is not None and self.use_tma_KV)
                        else None
                    )
                    if const_expr(not self.use_tma_KV):
                        paged_kv_manager.load_page_table(n_block_first)
                    load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # K0
                    # load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx, extra_tx_count=self.tma_copy_bytes["Q"])  # K0
                    if const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]:
                        # load_Q(block=0, stage=0)  # Q0
                        pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
                        # pipeline_q.sync_object_empty.wait(0, q_producer_phase)
                        tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(0)
                        # tma_bar_ptr = pipeline_kv.producer_get_barrier(kv_producer_state)
                        load_Q_fn(src_idx=0, dst_idx=0, tma_bar_ptr=tma_bar_ptr)
                    kv_producer_state.advance()
                    if const_expr(self.q_stage == 2) and (
                        const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]
                    ):
                        # load_Q(block=1, stage=1)  # Q1
                        pipeline_q.producer_acquire_w_index_phase(1, q_producer_phase)
                        tma_bar_ptr = pipeline_q.sync_object_full.get_barrier(1)
                        load_Q_fn(src_idx=1, dst_idx=1, tma_bar_ptr=tma_bar_ptr)
                    q_producer_phase ^= 1
                    load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx)  # V0
                    kv_producer_state.advance()
                    for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
                        n_block = n_block_max - 2 - i
                        page_idx = (
                            mPageTable[batch_idx, n_block]
                            if const_expr(mPageTable is not None and self.use_tma_KV)
                            else None
                        )
                        if const_expr(not self.use_tma_KV):
                            paged_kv_manager.load_page_table(n_block)
                        # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx)
                        load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Ki
                        kv_producer_state.advance()
                        load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)  # Vi
                        kv_producer_state.advance()

            else:
                kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100(
                    blocksparse_tensors,
                    batch_idx,
                    head_idx,
                    m_block,
                    kv_producer_state,
                    load_Q,
                    load_K,
                    load_V,
                    pipeline_kv,
                    self.q_stage,
                    q_producer_phase,
                    self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
                    self.q_subtile_factor if self.q_subtile_factor is not None else 1,
                )

            tile_scheduler.prefetch_next_work()
            tile_scheduler.advance_to_next_work()
            work_tile = tile_scheduler.get_current_work()
            # End of persistent scheduler loop

        pipeline_kv.producer_tail(kv_producer_state)
        # This is equivalent to pipeline_q.producer_tail
        if const_expr(len(self.load_warp_ids) == 1) or warp_idx == self.load_warp_ids[0]:
            pipeline_q.producer_acquire_w_index_phase(self.q_stage - 1, q_producer_phase)

    @cute.jit
    def mma(
        self,
        tiled_mma_qk: cute.core.ThrMma,
        tiled_mma_pv: cute.core.ThrMma,
        sQ: cute.Tensor,
        sK: cute.Tensor,
        sV: cute.Tensor,
        tStS: cute.Tensor,
        tOtO: cute.Tensor,
        tOrP: cute.Tensor,
        pipeline_q: pipeline.PipelineAsync,
        pipeline_kv: pipeline.PipelineAsync,
        pipeline_s_p_o: pipeline.PipelineAsync,
        pipeline_p_lastsplit: pipeline.PipelineAsync,
        pipeline_o_acc: pipeline.PipelineAsync,
        is_leader_cta: Boolean,
        block_info: BlockInfo,
        num_splits: Int32,
        SeqlenInfoCls: Callable,
        TileSchedulerCls: Callable,
        blocksparse_tensors: Optional[BlockSparseTensors],
    ):
        tSrQ = tiled_mma_qk.make_fragment_A(sQ)
        tSrK = tiled_mma_qk.make_fragment_B(sK)
        tOrV = tiled_mma_pv.make_fragment_B(sV)
        if const_expr(self.q_stage == 2):
            tSrQs = (tSrQ[None, None, None, 0], tSrQ[None, None, None, 1])
        else:
            tSrQs = (tSrQ[None, None, None, 0],)
qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op qk_mma_idesc, pv_mma_idesc = sm100_desc.mma_op_to_idesc(qk_mma_op), sm100_desc.mma_op_to_idesc(pv_mma_op) q_smem_base = sm100_desc.smem_desc_base_from_tensor(sQ, sm100_desc.Major.K) k_smem_base = sm100_desc.smem_desc_base_from_tensor(sK, sm100_desc.Major.K) v_smem_base = sm100_desc.smem_desc_base_from_tensor(sV, sm100_desc.Major.MN) q_smem_start = [sm100_desc.make_smem_desc_start_addr(sQ[None, None, None, stage].iterator) for stage in range(self.q_stage)] sm100_utils.declare_ptx_smem_desc(q_smem_start[self.q_stage - 1], q_smem_base, tSrQ[None, None, None, 0].layout, var_name_prefix="fa_fwd_q_smem_desc") sm100_utils.declare_ptx_idesc(qk_mma_op, var_name="fa_fwd_qk_mma_idesc") sm100_utils.declare_ptx_idesc(pv_mma_op, var_name="fa_fwd_pv_mma_idesc") sQ_stage_stride = (sQ.layout.stride[-1] * sQ.element_type.width // 8) >> 4 if const_expr(self.q_stage == 1): sQ_stage_stride = 0 gemm_Si = [ partial( # sm100_utils.gemm_ptx_precomputed, # self.tmem_s_offset[stage], # smem_desc_start_a=q_smem_start[stage], # idesc=qk_mma_idesc, # smem_desc_base_a=q_smem_base, # smem_desc_base_b=k_smem_base, # tCrA_layout=tSrQ[None, None, None, 0].layout, sm100_utils.gemm_ptx_precomputed_varname, self.tmem_s_offset[stage], # idesc=qk_mma_idesc, smem_desc_base_b=k_smem_base, tCrB_layout=tSrK[None, None, None, 0].layout, smem_var_name_prefix=f"fa_fwd_q_smem_desc", idesc_var_name=f"fa_fwd_qk_mma_idesc", smem_offset=-sQ_stage_stride if stage == 0 else sQ_stage_stride, zero_init=True, cta_group=self.cta_group_size, ) for stage in range(self.q_stage) ] # gemm_Si = [ # partial( # sm100_utils.gemm, # tiled_mma_qk, # tStS[None, None, None, stage], # tCrA=tSrQ[None, None, None, stage], # zero_init=True, # ) # for stage in range(self.q_stage) # ] gemm_Pi = [ partial( # sm100_utils.gemm_ptx_precomputed, sm100_utils.gemm_ptx_partial, pv_mma_op, self.tmem_o_offset[stage], tOrP[None, None, None, stage], sA=None, split_arrive=self.split_P_arrive if 
self.split_P_arrive > 0 else None, # smem_desc_start_a=tOrP[None, None, None, stage].iterator.toint(), # smem_desc_start_a=self.tmem_p_offset[stage], # idesc=pv_mma_idesc, # smem_desc_base_a=None, # smem_desc_base_b=v_smem_base, # tCrA_layout=tOrP[None, None, None, 0].layout, # tCrB_layout=tOrV[None, None, None, 0].layout cta_group=self.cta_group_size, ) for stage in range(self.q_stage) ] # gemm_Pi = [ # partial( # sm100_utils.gemm, tOtO[None, None, None, stage], tCrA=tOrP[None, None, None, stage] # ) # for stage in range(self.q_stage) # ] mma_q_consumer_phase = Int32(0) mma_kv_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.kv_stage ) P_full_O_rescaled_phase = Int32(0) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) block_iter_count = Int32(0) process_tile = False if const_expr(self.use_block_sparsity): block_iter_count = get_total_block_count( blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, ) process_tile = block_iter_count > Int32(0) else: n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) block_iter_count = n_block_max - n_block_min if const_expr(not self.is_split_kv): process_tile = True else: process_tile = n_block_min < n_block_max if process_tile and is_leader_cta: for stage in cutlass.range_constexpr(self.q_stage): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) # 1. wait for Q0 / Q1 pipeline_q.consumer_wait_w_index_phase(stage, mma_q_consumer_phase) # 2. 
wait for K0 if const_expr(stage == 0): pipeline_kv.consumer_wait(mma_kv_consumer_state) Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase tSrKi = tSrK[None, None, None, Ki_index] # We don't need to acquire empty S0 / S1. # For the first iteration, we don't need to wait as we're guaranteed S0 / S1 # are empty. For subsequent iterations, the wait happened at the end # of the while loop. # 3. gemm # sm100_utils.gemm(tiled_mma_qk, tStS[None, None, None, stage], tSrQ[None, None, None, stage], tSrKi, zero_init=True) sK_cur = sK[None, None, None, Ki_index] if const_expr(self.uneven_kv_smem): sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) # gemm_Si[stage](tCrB=tSrKi, sB=sK_cur) gemm_Si[stage]( smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sK_cur.iterator) ) # gemm_Si[stage](tCrB=tSrKi) # 4. release S0 / S1 pipeline_s_p_o.producer_commit_w_index(stage) mma_q_consumer_phase ^= 1 # 5. release K0 pipeline_kv.consumer_release(mma_kv_consumer_state) mma_kv_consumer_state.advance() # End of GEMM (Q1 * K0 -> S1) # Note: Q0 & Q1 are still needed in the seqlen_kv loop # so we need to release them after the seqlen_kv loop # O hasn't been accumulated yet, so its first MMA calculation doesn't need to accumulate block_loop_count = block_iter_count - 1 O_should_accumulate = False for i in cutlass.range(block_loop_count, unroll=1): # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop # 1. wait for V0 pipeline_kv.consumer_wait(mma_kv_consumer_state) mma_kv_release_state = mma_kv_consumer_state.clone() Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase tOrVi = tOrV[None, None, None, Vi_index] for stage in cutlass.range_constexpr(self.q_stage): # 2. acquire corrected O0/O1_partial and P0 / P1 # For the first iteration in this work tile, waiting for O0/O1_partial # means that the correction warps have finished reading tO during # the last iteration of the previous work tile.
pipeline_s_p_o.producer_acquire_w_index_phase(stage, P_full_O_rescaled_phase) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) sV_cur = sV[None, None, None, Vi_index] if const_expr(self.uneven_kv_smem): sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) gemm_Pi[stage]( tCrB=tOrVi, sB=sV_cur, # smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sV_cur.iterator), zero_init=not O_should_accumulate, mbar_ptr=pipeline_p_lastsplit.sync_object_full.get_barrier(stage) if self.split_P_arrive > 0 else None, mbar_phase=P_full_O_rescaled_phase, ) # Don't need to signal O_full to the correction warps since the # correction warps wait for the softmax warps anyway. By the time the softmax # warps have finished, S_i for the next iteration must have been done, so O_i-1 # must have been done as well. # pipeline_o_acc.producer_commit_w_index(stage) # 4. release V(i-1) if const_expr(stage == self.q_stage - 1): pipeline_kv.consumer_release(mma_kv_release_state) mma_kv_release_state.advance() # End of GEMM_PV00 (P0 * V0 -> O0_partial) # GEMM_QK0i (Q0 * Ki -> S0) # 1. wait for Ki if const_expr(stage == 0): mma_kv_consumer_state.advance() pipeline_kv.consumer_wait(mma_kv_consumer_state) Ki_index, Ki_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase # 2. gemm # Don't need to wait for the softmax warp to have finished reading the previous # Si, since this gemm is scheduled after the PV gemm, which guarantees that Si # has been read and Pi has been written.
# sm100_utils.gemm(tiled_mma_qk, tStS[None, None, None, stage], tSrQ[None, None, None, stage], tSrK[None, None, None, Ki_index], zero_init=True) sK_cur = sK[None, None, None, Ki_index] if const_expr(self.uneven_kv_smem): sK_cur = self.offset_kv_smem(sK_cur, Ki_index, Ki_phase) # gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index], sB=sK_cur) gemm_Si[stage]( smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sK_cur.iterator) ) # gemm_Si[stage](tCrB=tSrK[None, None, None, Ki_index]) # 3. release S0 / S1 pipeline_s_p_o.producer_commit_w_index(stage) # End of GEMM_QK0i (Q0 * Ki -> S0) # 4. release Ki pipeline_kv.consumer_release(mma_kv_consumer_state) mma_kv_consumer_state.advance() P_full_O_rescaled_phase ^= 1 O_should_accumulate = True # End of seqlen_kv loop # release Q0 & Q1 for stage in cutlass.range(self.q_stage): pipeline_q.consumer_release_w_index(stage) # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop # 1. wait for V0 pipeline_kv.consumer_wait(mma_kv_consumer_state) Vi_index, Vi_phase = mma_kv_consumer_state.index, mma_kv_consumer_state.phase tOrVi = tOrV[None, None, None, Vi_index] for stage in cutlass.range_constexpr(self.q_stage): # 2. acquire corrected Oi_partial and Pi pipeline_s_p_o.producer_acquire_w_index_phase(stage, P_full_O_rescaled_phase) # 3. gemm # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) sV_cur = sV[None, None, None, Vi_index] if const_expr(self.uneven_kv_smem): sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) gemm_Pi[stage]( tCrB=tOrVi, sB=sV_cur, # smem_desc_start_b=sm100_desc.make_smem_desc_start_addr(sV_cur.iterator), zero_init=not O_should_accumulate, mbar_ptr=pipeline_p_lastsplit.sync_object_full.get_barrier(stage) if self.split_P_arrive > 0 else None, mbar_phase=P_full_O_rescaled_phase, ) # 4. 
release accumulated O0_partial # We do need O_full here since for the last tile, by the time the softmax warp # has signaled to the correction warps, the softmax warp has just finished # computing the row sum of the current tile. It does not guarantee that the 1st # tile of the next work tile has been computed yet. pipeline_o_acc.producer_commit_w_index(stage) # End of GEMM_PV00 (P0 * V0 -> O0_partial) P_full_O_rescaled_phase ^= 1 # 5. release Vi_end pipeline_kv.consumer_release(mma_kv_consumer_state) mma_kv_consumer_state.advance() # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) # Advance to next tile tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop # We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end # pipeline_s_p_o.producer_acquire_w_index_phase(self.q_stage - 1, P_full_O_rescaled_phase) # We don't need pipeline_o_acc.producer_tail() since we don't call # pipeline_o_acc.producer_acquire() inside the loop. # Shared by both the softmax0 and softmax1 warp groups @cute.jit def softmax_loop( self, stage: int | Int32, softmax_scale_log2: Float32, softmax_scale: Float32, thr_mma_qk: cute.core.ThrMma, tStS: cute.Tensor, # ((TILE_M, TILE_N), 1, 1, q_stage) sScale: cute.Tensor, mLSE: Optional[cute.Tensor], pipeline_s_p_o: pipeline.PipelineAsync, pipeline_p_lastsplit: pipeline.PipelineAsync, pipeline_sm_stats: pipeline.PipelineAsync, sm_stats_barrier: pipeline.NamedBarrier, pipeline_s0_s1_sequence: Optional[pipeline.PipelineAsync], learnable_sink: Optional[cute.Tensor], block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, TileSchedulerCls: Callable, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, blocksparse_tensors: Optional[BlockSparseTensors] = None, ): """Compute softmax on attention scores from QK matrix multiplication.
This method handles the softmax computation for either the first or second half of the attention matrix, depending on the 'stage' parameter. It calculates row-wise maximum and sum values needed for stable softmax computation, applies optional masking, and transforms raw attention scores into probability distributions. The implementation uses specialized memory access patterns and efficient math operations for computing exp(x) using exp2 functions. It also coordinates pipeline synchronization between MMA, correction, and sequence processing stages. """ tidx = cute.arch.thread_idx()[0] % ( cute.arch.WARP_SIZE # * (len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids) * (len(self.softmax0_warp_ids)) ) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 cta_qk_tiler = (self.mma_tiler_qk[0] // thr_mma_qk.thr_id.shape, self.mma_tiler_qk[1]) tSAcc = tStS[(None, None), 0, 0, stage] # (128, 128) tStScale = cute.composition(tSAcc, cute.make_layout((self.m_block_size, 1))) tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) tScS = tScS[(None, None), 0, 0] # (128, 128) tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width tStP_layout = cute.composition( tSAcc.layout, cute.make_layout((self.m_block_size, tilePlikeFP32)) ) tStP = cute.make_tensor(tSAcc.iterator + self.tmem_s_to_p_offset, tStP_layout) tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype ) thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tSAcc).get_slice(tidx) tStS_t2r = thr_tmem_load.partition_S(tSAcc) # (((32,32),1),1,4) tmem_store_scale_atom = cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), Float32 ) thr_tmem_store_scale = tcgen05.make_tmem_copy(tmem_store_scale_atom, tStScale).get_slice( tidx ) tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) tmem_store_atom 
= cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 ) thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) tStP_r2t = thr_tmem_store.partition_D(tStP) # (((16,32),1),1,4) mma_si_consumer_phase = Int32(0) sm_stats_producer_phase = Int32(1) s0_s1_sequence_phase = Int32(1 if stage == 0 else 0) # self.warp_scheduler_barrier_init() warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) mask = AttentionMaskCls(seqlen) shared_mask_kwargs = dict( m_block=(self.q_stage * m_block + stage) * self.cta_group_size, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, mask_local=self.is_local, batch_idx=batch_idx, head_idx=head_idx, aux_tensors=aux_tensors, ) # Recompute fastdiv_mods if necessary recompute_fastdiv_mods_q = cutlass.const_expr( aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) ) recompute_fastdiv_mods_k = cutlass.const_expr( aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k) ) if cutlass.const_expr(fastdiv_mods is not None): seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods fastdiv_mods = ( seqlen_q_divmod if not recompute_fastdiv_mods_q else FastDivmodDivisor(seqlen.seqlen_q), seqlen_k_divmod if not recompute_fastdiv_mods_k else FastDivmodDivisor(seqlen.seqlen_k), ) mask_mod = self.mask_mod if const_expr(self.mask_mod is not None) else None mask_fn = partial( mask.apply_mask_sm100, mask_mod=mask_mod, fastdiv_mods=fastdiv_mods, head_divmod=head_divmod, **shared_mask_kwargs, ) if const_expr(self.use_block_sparsity): # Full blocks don't need mask_mod mask_fn_none = partial( mask.apply_mask_sm100, mask_mod=None,
fastdiv_mods=fastdiv_mods, head_divmod=head_divmod, **shared_mask_kwargs, ) else: mask_fn_none = None softmax = SoftmaxSm100.create( softmax_scale_log2, rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, softmax_scale=softmax_scale, ) softmax.reset() if const_expr(self.use_block_sparsity): tile_block_count = get_total_block_count( blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, ) has_work = tile_block_count > Int32(0) else: tile_block_count = n_block_max - n_block_min has_work = const_expr(not self.is_split_kv) or tile_block_count > Int32(0) softmax_step = partial( self.softmax_step, softmax=softmax, thr_mma_qk=thr_mma_qk, pipeline_s_p_o=pipeline_s_p_o, pipeline_p_lastsplit=pipeline_p_lastsplit, pipeline_sm_stats=pipeline_sm_stats, sm_stats_barrier=sm_stats_barrier, pipeline_s0_s1_sequence=pipeline_s0_s1_sequence, thr_tmem_load=thr_tmem_load, thr_tmem_store=thr_tmem_store, thr_tmem_store_scale=thr_tmem_store_scale, tStS_t2r=tStS_t2r, tStScale_r2t=tStScale_r2t, tStP_r2t=tStP_r2t, sScale=sScale, stage=stage, batch_idx=batch_idx, head_idx=head_idx, m_block=(self.q_stage * m_block + stage) * self.cta_group_size, seqlen=seqlen, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, head_divmod=head_divmod, ) if const_expr(self.use_block_sparsity) or has_work: # See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract]. pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase) sm_stats_producer_phase ^= 1 # Block sparse or dense iteration if const_expr(self.use_block_sparsity): # When aux_tensors exist, Q indices beyond seqlen_q must be wrapped to avoid # OOB aux_tensor access. Only edge tiles (where m_tile_end > seqlen_q) need this. 
if const_expr(aux_tensors is not None): m_tile_end = ((self.q_stage * m_block + stage + 1) * self.cta_group_size) * self.m_block_size check_m_boundary = m_tile_end > seqlen.seqlen_q else: check_m_boundary = False ( mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, empty_tile, ) = softmax_block_sparse_sm100( blocksparse_tensors, batch_idx, head_idx, m_block, softmax_step, mask_fn, mask_fn_none, mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, pipeline_sm_stats, sm_stats_barrier, self.q_stage, Int32(stage), check_m_boundary, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, ) if not empty_tile: sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] if const_expr(mLSE is not None or learnable_sink is not None): sScale[ tidx + stage * self.m_block_size + self.q_stage * self.m_block_size ] = softmax.row_max[0] # if tidx == 0: # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) # See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract]. 
# pipeline_sm_stats.producer_commit_w_index(stage) sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx) # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) else: if const_expr(not self.is_split_kv) or tile_block_count > Int32(0): mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step( mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, n_block_max - 1, is_first=True, mask_fn=partial(mask_fn, mask_seqlen=True), ) n_block_max -= 1 # Next couple of iterations with causal masking if const_expr(self.is_causal or self.is_local): n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( seqlen, m_block, n_block_min ) for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): n_block = n_block_max - 1 - n_tile mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = ( softmax_step( mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False), ) ) n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) # The remaining iterations have no masking (but may still need mask_mod) n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( seqlen, m_block, n_block_min ) for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): n_block = n_block_max - n_tile - 1 if const_expr(self.mask_mod is not None): mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step( mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False), ) else: mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = softmax_step( mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, n_block, ) # Separate iterations with local masking on the left if const_expr(self.is_local and block_info.window_size_left is not None): n_block_max = 
cutlass.min(n_block_max, n_block_min_before_local_mask) for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): n_block = n_block_max - 1 - n_tile mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase = ( softmax_step( mma_si_consumer_phase, sm_stats_producer_phase, s0_s1_sequence_phase, n_block, mask_fn=partial(mask_fn, mask_seqlen=False), ) ) # Now that we no longer already have the 1st iteration, need mask_seqlen=True here # Dense path always writes scale / signals sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] if const_expr(mLSE is not None or learnable_sink is not None): sScale[ tidx + stage * self.m_block_size + self.q_stage * self.m_block_size ] = softmax.row_max[0] # pipeline_sm_stats.producer_commit_w_index(stage) sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx) # # Write LSE to gmem # if const_expr(mLSE is not None): # acc_O_mn_row_is_zero_or_nan = softmax.row_sum[0] == 0.0 or softmax.row_sum[0] != softmax.row_sum[0] # scale = ( # cute.arch.rcp_approx(softmax.row_sum[0] if not acc_O_mn_row_is_zero_or_nan else 1.0) # ) # LN2 = math.log(2.0) # lse = ( # (softmax.row_max[0] * softmax.scale_log2 + cute.math.log2(softmax.row_sum[0], fastmath=True)) * LN2 # if not acc_O_mn_row_is_zero_or_nan else -Float32.inf # ) # if const_expr(not seqlen.has_cu_seqlens_q): # mLSE_cur = mLSE[None, head_idx, batch_idx] # else: # mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[None, head_idx]) # gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block * 2 + stage,)) # if tidx < seqlen.seqlen_q - (m_block * 2 + stage) * self.m_block_size: # gLSE[tidx] = lse # Advance to next tile tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop # This is equivalent to pipeline_sm_stats.producer_tail pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase) # This is equivalent to pipeline_s0_s1.producer_tail if 
const_expr(self.s0_s1_barrier): if stage == 0: pipeline_s0_s1_sequence.sync_object_full.wait(stage, s0_s1_sequence_phase) @cute.jit def softmax_step( self, mma_si_consumer_phase: Int32, sm_stats_producer_phase: Int32, s0_s1_sequence_phase: Int32, n_block: Int32, softmax: SoftmaxSm100, thr_mma_qk: cute.core.ThrMma, pipeline_s_p_o: pipeline.PipelineAsync, pipeline_p_lastsplit: pipeline.PipelineAsync, pipeline_sm_stats: pipeline.PipelineAsync, sm_stats_barrier: pipeline.NamedBarrier, pipeline_s0_s1_sequence: Optional[pipeline.PipelineAsync], thr_tmem_load: cute.CopyAtom, thr_tmem_store: cute.CopyAtom, thr_tmem_store_scale: cute.CopyAtom, tStS_t2r: cute.Tensor, tStScale_r2t: cute.Tensor, tStP_r2t: cute.Tensor, sScale: cute.Tensor, stage: int | Int32, batch_idx: Int32, head_idx: Int32, m_block: Int32, seqlen, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, mask_fn: Optional[Callable] = None, is_first: bool = False, ) -> Tuple[cute.Int32, cute.Int32, cute.Int32]: """Perform a single step of the softmax computation on a block of attention scores. This method processes one block of the attention matrix, computing a numerically stable softmax by first finding the row maximum, subtracting it from all elements, applying the exponential function, and then accumulating the row sum of the exponentials. The final normalization by this sum is applied later, in the correction warps' epilogue. It also handles optional masking of attention scores. The method involves several key operations: 1. Loading attention scores from tensor memory 2. Applying optional masking based on position 3. Computing row-wise maximum values for numerical stability 4. Transforming scores using exp2(x*scale - max*scale) 5. Computing row sums for normalization 6.
Coordinating pipeline synchronization between different processing stages """ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 tilePlikeFP32 = self.mma_tiler_qk[1] // Float32.width * self.v_dtype.width tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) tScS = tScS[(None, None), 0, 0] # (128, 128) # tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) cta_qk_tiler = (self.mma_tiler_qk[0] // thr_mma_qk.thr_id.shape, self.mma_tiler_qk[1]) tScS_shape = cta_qk_tiler # (128, 128) tScP_shape = (tScS_shape[0], tilePlikeFP32) # (128, 64) # Wait for Si pipeline_s_p_o.consumer_wait_w_index_phase(stage, mma_si_consumer_phase) tSrS_t2r = cute.make_fragment(thr_tmem_load.partition_D(tScS).shape, self.qk_acc_dtype) cute.copy(thr_tmem_load, tStS_t2r, tSrS_t2r) # tSrS_t2r = copy_utils.load_t2r(thr_tmem_load, tScS_shape, tStS_t2r) if cutlass.const_expr(self.score_mod is not None): self.apply_score_mod( tSrS_t2r, thr_tmem_load, thr_mma_qk, batch_idx, head_idx, m_block, n_block, softmax, seqlen, aux_tensors, fastdiv_mods, head_divmod, ) if const_expr(mask_fn is not None): mask_fn(tSrS_t2r, n_block=n_block) row_max, acc_scale = softmax.update_row_max(tSrS_t2r.load(), is_first) if const_expr(not is_first): # tSrScale_r2t = cute.make_fragment(thr_tmem_store_scale.partition_S(tScScale).shape, Float32) # tSrScale_r2t[0] = acc_scale # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) # cute.arch.fence_view_async_tmem_store() thread_idx = thr_tmem_load.thr_idx sScale[thread_idx + stage * self.m_block_size] = acc_scale # if thread_idx == 0: cute.printf("softmax acc_scale stage %d: %f, row_max = %f\n", stage, acc_scale, row_max) # Notify correction wg that row_max is ready # pipeline_sm_stats.producer_commit_w_index(stage) sm_stats_barrier.arrive_w_index(index=stage * 4 + warp_idx) # if thread_idx == 0 and stage == 0: cute.print_tensor(tSrS_t2r) softmax.scale_subtract_rowmax(tSrS_t2r, row_max) # Sequence barrier wait 
if const_expr(self.s0_s1_barrier): pipeline_s0_s1_sequence.sync_object_full.wait(stage, s0_s1_sequence_phase) tSrP_r2t_f32 = cute.make_fragment( thr_tmem_store.partition_S(cute.make_identity_tensor(tScP_shape)).shape, Float32 ) tSrP_r2t = cute.make_tensor( cute.recast_ptr(tSrP_r2t_f32.iterator, dtype=self.q_dtype), tSrS_t2r.layout ) # softmax.scale_apply_exp2_convert(tSrS_t2r, row_max, tSrP_r2t) softmax.apply_exp2_convert( tSrS_t2r, tSrP_r2t, ex2_emu_freq=self.ex2_emu_freq if const_expr(mask_fn is None) else 0, ex2_emu_start_frg=self.ex2_emu_start_frg, ) # Sequence barrier arrive if const_expr(self.s0_s1_barrier): pipeline_s0_s1_sequence.sync_object_full.arrive(1 - stage, dst=None) # print(tSrP_r2t_f32, tStP_r2t) # cute.copy(thr_tmem_store, tSrP_r2t_f32, tStP_r2t) for i in cutlass.range_constexpr(cute.size(tStP_r2t.shape[2])): cute.copy(thr_tmem_store, tSrP_r2t_f32[None, None, i], tStP_r2t[None, None, i]) if const_expr(self.split_P_arrive > 0): split_P_arrive_idx = cute.size(tStP_r2t.shape[2]) * self.split_P_arrive // self.n_block_size if const_expr(i + 1 == split_P_arrive_idx): # Notify mma warp that the 1st half of P is ready cute.arch.fence_view_async_tmem_store() pipeline_s_p_o.consumer_release_w_index(stage) # Notify mma warp that the 2nd half of P is ready cute.arch.fence_view_async_tmem_store() if const_expr(self.split_P_arrive > 0): cute.arch.sync_warp() with cute.arch.elect_one(): pipeline_p_lastsplit.producer_commit_w_index(stage) else: pipeline_s_p_o.consumer_release_w_index(stage) pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase) softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) # acc_scale = cute.math.exp2(acc_scale_, fastmath=True) return mma_si_consumer_phase ^ 1, sm_stats_producer_phase ^ 1, s0_s1_sequence_phase ^ 1 @cute.jit def correction_loop( self, thr_mma_qk: cute.core.ThrMma, thr_mma_pv: cute.core.ThrMma, tStS: cute.Tensor, tOtO: cute.Tensor, sScale: cute.Tensor, mO: cute.Tensor, mLSE: cute.Tensor, sO: 
cute.Tensor, pipeline_s_p_o: pipeline.PipelineAsync, pipeline_o_acc: pipeline.PipelineAsync, pipeline_sm_stats: pipeline.PipelineAsync, sm_stats_barrier: pipeline.NamedBarrier, pipeline_o_epi: pipeline.PipelineAsync, learnable_sink: Optional[cute.Tensor], gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: cute.CopyAtom, softmax_scale_log2: Float32, block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors] = None, ): tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 mma_tile_coord_v = thr_mma_qk.thr_idx tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) tStScale_layout = cute.composition(tStS.layout, cute.make_layout((self.m_block_size, 1))) tStScales = tuple( cute.make_tensor(tStS.iterator + self.tmem_vec_offset[stage], tStScale_layout) for stage in range(self.q_stage) ) tScScale = cute.composition(tScS, cute.make_layout((self.m_block_size, 1))) tmem_load_v_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(1)), self.qk_acc_dtype ) thr_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStScales[0]).get_slice(tidx) tStScales_t2r = [thr_tmem_load_vec.partition_S(tStScales[stage]) for stage in range(self.q_stage)] tSrScale_t2r_shape = thr_tmem_load_vec.partition_D(tScScale).shape # First iter: no correction is required # Notify mma warp that O has been rescaled for stage in cutlass.range(self.q_stage): pipeline_s_p_o.consumer_release_w_index(stage) sm_stats_consumer_phase = Int32(0) o_corr_consumer_phase = Int32(0) corr_epi_producer_phase = Int32(1) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, 
m_block, split_idx, num_splits) if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded) gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0)) # (128 * 2, 128) gO = layout_utils.select( cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1] ) # (128, 128, 2) gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None] # Default LSE to -inf for invalid split_idx tiles stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage if const_expr(self.use_block_sparsity): total_block_count = get_total_block_count( blocksparse_tensors, batch_idx, head_idx, m_block, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, ) has_work = total_block_count > Int32(0) else: total_block_count = n_block_max - n_block_min has_work = const_expr(not self.is_split_kv) or total_block_count > Int32(0) if has_work: # Ignore first signal from softmax as no correction is required # pipeline_sm_stats.consumer_wait_w_index_phase(0, sm_stats_consumer_phase) sm_stats_barrier.arrive_and_wait_w_index(index=0 * 4 + warp_idx) pipeline_sm_stats.consumer_release_w_index(0) if const_expr(self.q_stage == 2): # pipeline_sm_stats.consumer_wait_w_index_phase(1, sm_stats_consumer_phase) sm_stats_barrier.arrive_and_wait_w_index(index=1 * 4 + warp_idx) sm_stats_consumer_phase ^= 1 tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) for i in cutlass.range(total_block_count - 1, unroll=1): for stage in cutlass.range_constexpr(self.q_stage): # wait for S0 / S1 # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase) sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + 
warp_idx) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] scale = sScale[tidx + stage * self.m_block_size] should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 # should_rescale = True # if tidx == 0: cute.printf("Correction scale i = %d, for stage %d: %f, should_rescale = %d\n", i, stage, scale, should_rescale) # Don't need O_full anymore, since by the time softmax has signaled the correction # warps, S_i must have been done, so O_i-1 must have been done as well. # pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase) if should_rescale: self.correction_rescale(thr_mma_pv, tOtO[None, None, None, stage], tidx, scale) # Notify mma warp that O has been rescaled pipeline_s_p_o.consumer_release_w_index(stage) pipeline_sm_stats.consumer_release_w_index(self.q_stage - 1 - stage) sm_stats_consumer_phase ^= 1 # o_corr_consumer_phase ^= 1 if const_expr(self.q_stage == 2): pipeline_sm_stats.consumer_release_w_index(1) # End of seqlen_corr_loop_steps # Even in the case of self.overlap_sO_sQ, we can write to stage 0 of sO without # additional sync because the MMA in the top half must have been done. # Similarly we can write to stage 1 of sO without additional sync. 
learnable_sink_val = [None] * self.q_stage if const_expr(learnable_sink is not None): if const_expr(not self.pack_gqa): sink_val = Float32(learnable_sink[head_idx]) learnable_sink_val = [sink_val] * self.q_stage else: # Each thread might have a different sink value due to different q_head for stage in cutlass.range_constexpr(self.q_stage): q_head_idx = ( ((m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v) * self.m_block_size + tidx ) % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead learnable_sink_val[stage] = Float32(learnable_sink[q_head_idx]) for stage in cutlass.range_constexpr(self.q_stage): # pipeline_sm_stats.consumer_wait_w_index_phase(stage, sm_stats_consumer_phase) sm_stats_barrier.arrive_and_wait_w_index(index=stage * 4 + warp_idx) # cute.copy(tiled_tmem_load_vec, tStScales_t2r[stage], tSrScale_t2r) # cute.arch.fence_view_async_tmem_load() # scale = tSrScale_t2r[0] row_sum = sScale[tidx + stage * self.m_block_size] if const_expr(mLSE is not None or learnable_sink is not None): row_max = sScale[tidx + stage * self.m_block_size + self.q_stage * self.m_block_size] else: row_max = None pipeline_sm_stats.consumer_release_w_index(stage) if const_expr(learnable_sink is not None): LOG2_E = math.log2(math.e) sink_val = learnable_sink_val[stage] if const_expr(not self.is_split_kv) or split_idx == 0: if row_max == -Float32.inf: # It's possible to have an empty row with splitKV. 
row_max = sink_val * (LOG2_E / softmax_scale_log2) row_sum = Float32(1.0) else: row_sum += cute.math.exp2( sink_val * LOG2_E - row_max * softmax_scale_log2, fastmath=True ) acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) # Wait for the last O to be ready from the MMA warp pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase) if const_expr(not self.use_correction_warps_for_epi): pipeline_o_epi.producer_acquire_w_index_phase(stage, corr_epi_producer_phase) self.correction_epilogue( thr_mma_pv, tOtO[None, None, None, stage], tidx, stage, m_block, seqlen.seqlen_q, scale, sO[None, None, stage], mO_cur, gO[None, None, stage], gmem_tiled_copy_O, ) # Signal for the next work tile that O buffers in tmem are already read, so # mma warp can write to them pipeline_s_p_o.consumer_release_w_index(stage) if const_expr(not self.use_correction_warps_for_epi): pipeline_o_epi.producer_commit_w_index(stage) # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) o_corr_consumer_phase ^= 1 sm_stats_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 else: gmem_tiled_copy_O_for_empty_tile = None if const_expr(self.use_correction_warps_for_epi): gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O if const_expr(self.use_block_sparsity): ( sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase, ) = handle_block_sparse_empty_tile_correction_sm100( tidx, self.q_stage, self.m_block_size, self.qhead_per_kvhead, self.pack_gqa, self.is_split_kv, learnable_sink, mLSE, seqlen, m_block, head_idx, batch_idx, split_idx, sScale, stats, self.correction_epilogue, thr_mma_pv, tOtO, sO, pipeline_sm_stats, sm_stats_barrier, pipeline_o_epi, sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase, softmax_scale_log2, mO_cur, gO, gmem_tiled_copy_O_for_empty_tile, ) if 
const_expr(mLSE is not None): if const_expr(not seqlen.has_cu_seqlens_q): if const_expr(self.is_split_kv): mLSE_cur = mLSE[None, head_idx, batch_idx, split_idx] else: mLSE_cur = mLSE[None, head_idx, batch_idx] else: offset = ( seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) ) if const_expr(self.is_split_kv): mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx, split_idx]) else: mLSE_cur = cute.domain_offset((offset,), mLSE[None, head_idx]) for stage in cutlass.range_constexpr(self.q_stage): m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_tile_idx,)) row_sum, row_max, acc_O_mn_row_is_zero_or_nan = stats[stage] # if tidx == 0 and stage <= 1: # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) LN2 = math.log(2.0) lse = ( (row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) * LN2 if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) seqlen_q = ( seqlen.seqlen_q if const_expr(not self.pack_gqa) else seqlen.seqlen_q * self.qhead_per_kvhead ) if tidx < seqlen_q - m_tile_idx * self.m_block_size: # This actually just works with PackGQA too gLSE[tidx] = lse # Advance to next tile tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop # This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps if const_expr(not self.use_correction_warps_for_epi): pipeline_o_epi.producer_acquire_w_index_phase(self.q_stage - 1, corr_epi_producer_phase) @cute.jit def correction_rescale( self, thr_mma: cute.core.ThrMma, tOtO: cute.Tensor, tidx: Int32, scale: Float32, ): """Rescale intermediate attention results based on softmax normalization factor. This method performs a crucial correction step in the attention computation pipeline. 
        When processing attention in blocks, the softmax normalization factors may change
        as new blocks are processed. This method rescales previously computed partial
        output values to account for updated normalization factors.

        The implementation uses efficient tensor memory operations to:
        1. Load existing partial attention output from tensor memory
        2. Apply the scaling factor to all elements
        3. Store the rescaled results back to tensor memory
        """
        tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2]))
        corr_tile_size = 16  # tuneable parameter
        tmem_load_atom = cute.make_copy_atom(
            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.pv_acc_dtype
        )
        tmem_store_atom = cute.make_copy_atom(
            tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
            self.pv_acc_dtype,
        )
        tOtO_i = cute.composition(tOtO, cute.make_layout((self.m_block_size, corr_tile_size)))
        tOcO_i = cute.composition(tOcO, cute.make_layout((self.m_block_size, corr_tile_size)))
        thr_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i).get_slice(tidx)
        thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i).get_slice(tidx)
        tOtO_t2r = thr_tmem_load.partition_S(tOtO_i)
        tOrO_t2r_shape = thr_tmem_load.partition_D(tOcO_i).shape
        tOtO_r2t = thr_tmem_store.partition_D(tOtO_i)

        frg_count = self.head_dim_v_padded // corr_tile_size
        tOrO_frg = cute.make_fragment((tOrO_t2r_shape, frg_count), self.pv_acc_dtype)
        for i in cutlass.range_constexpr(frg_count):
            tOrO_frg = cute.make_fragment(tOrO_t2r_shape, self.pv_acc_dtype)
            tOtO_t2r_i = cute.make_tensor(tOtO_t2r.iterator + i * corr_tile_size, tOtO_t2r.layout)
            cute.copy(thr_tmem_load, tOtO_t2r_i, tOrO_frg)
            for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True):
                tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2(
                    (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale)
                )
            tOtO_r2t_i = cute.make_tensor(tOtO_r2t.iterator + i * corr_tile_size, tOtO_r2t.layout)
            cute.copy(thr_tmem_store, tOrO_frg, tOtO_r2t_i)
        cute.arch.fence_view_async_tmem_store()

    @cute.jit
    def correction_epilogue(
        self,
        thr_mma: cute.core.ThrMma,
        tOtO: cute.Tensor,
        tidx: Int32,
        stage: Int32,
        m_block: Int32,
        seqlen_q: Int32,
        scale: Float32,
        sO: cute.Tensor,
        mO_cur: Optional[cute.Tensor] = None,
        gO: Optional[cute.Tensor] = None,
        gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
    ):
        """Apply final scaling and transformation to attention output before writing to global memory.

        This correction_epilogue function handles the final processing step for attention output values.
        It applies a scaling factor to the accumulated attention results and prepares the
        data for efficient transfer back to global memory.

        The method performs:
        1. Loading of accumulated attention results from tensor memory
        2. Application of the final output scaling factor
        3. Type conversion if necessary (typically from higher precision accumulator to output precision)
        4. Reorganization of data for optimal memory access patterns
        5. Preparation for efficient TMA store operations

        :param thr_mma: Thread MMA operation for the computation
        :type thr_mma: cute.core.ThrMma
        :param tOtO: Tensor containing accumulated attention output
        :type tOtO: cute.Tensor
        :param scale: Final scaling factor to apply to the output
        :type scale: Float32
        :param sO: Shared memory tensor for the final output
        :type sO: cute.Tensor
        """
        corr_tile_size = 8 * 32 // self.o_dtype.width
        # Use CTA 0 mapping for smem partitioning since sO is per-CTA sized
        tOsO = thr_mma.get_slice(0).partition_C(sO)
        tOcO = thr_mma.partition_C(cute.make_identity_tensor(self.mma_tiler_pv[:2]))

        tOtO_i = cute.logical_divide(tOtO, cute.make_layout((self.m_block_size, corr_tile_size)))
        tOcO_i = cute.logical_divide(tOcO, cute.make_layout((self.m_block_size, corr_tile_size)))
        tOsO_i = cute.logical_divide(tOsO, cute.make_layout((self.m_block_size, corr_tile_size)))

        epi_subtile = (self.epi_tile[0], corr_tile_size)
        tmem_copy_atom = sm100_utils_basic.get_tmem_load_op(
            self.mma_tiler_pv,
            self.o_layout,
            self.o_dtype,
            self.pv_acc_dtype,
            epi_subtile,
            use_2cta_instrs=self.use_2cta_instrs,
        )
        tiled_tmem_load = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i[(None, None), 0])
        thr_tmem_load = tiled_tmem_load.get_slice(tidx)
        smem_copy_atom = sm100_utils_basic.get_smem_store_op(
            self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load
        )
        tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)

        tOtO_t2r = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
        tOsO_s2r = copy_utils.partition_D_position_independent(thr_tmem_load, tOsO_i[(None, None), None])
        tOcO_t2r = thr_tmem_load.partition_D(tOcO_i[(None, None), None])
        for i in cutlass.range(self.head_dim_v_padded // corr_tile_size, unroll_full=True):
            tOtO_t2r_i = tOtO_t2r[None, 0, 0, i]
            tOsO_r2s_i = tOsO_s2r[None, 0, 0, i]
            tOrO_frg = cute.make_fragment(tOcO_t2r[None, 0, 0, i].shape, self.pv_acc_dtype)
            cute.copy(tiled_tmem_load, tOtO_t2r_i, tOrO_frg)
            for j in cutlass.range(0, cute.size(tOrO_frg), 2, unroll_full=True):
                tOrO_frg[j], tOrO_frg[j + 1] = cute.arch.mul_packed_f32x2(
                    (tOrO_frg[j], tOrO_frg[j + 1]), (scale, scale)
                )
            copy_utils.cvt_copy(tiled_smem_store, tOrO_frg, tOsO_r2s_i)
        cute.arch.fence_view_async_shared()

        if const_expr(self.use_correction_warps_for_epi):
            assert(not self.use_tma_O)
            assert(gmem_tiled_copy_O is not None)
            cute.arch.barrier(barrier_id=int(NamedBarrierFwdSm100.Epilogue),
                              number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE)
            mma_tile_coord_v = thr_mma.thr_idx
            m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
            self._store_O_to_gmem(
                sO, gO, mO_cur, gmem_tiled_copy_O, tidx, seqlen_q, m_tile_idx
            )

    @cute.jit
    def _store_O_to_gmem(
        self,
        sO_stage: cute.Tensor,
        gO: cute.Tensor,
        mO_cur: cute.Tensor,
        gmem_tiled_copy_O: cute.TiledCopy,
        tidx: Int32,
        seqlen_q: Int32,
        m_tile_idx: Int32,
    ):
        """Copy a single stage of O from smem to gmem via registers."""
        gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
        tOsO = gmem_thr_copy_O.partition_S(sO_stage)
        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
        tOgO = gmem_thr_copy_O.partition_D(gO)
        tOcO = gmem_thr_copy_O.partition_S(cO)
        t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
        tOpO = copy_utils.predicate_k(tOcO, limit=mO_cur.shape[1])
        pack_gqa = PackGQA(
            self.m_block_size,
            self.head_dim_v_padded,
            self.check_hdim_v_oob,
            self.qhead_per_kvhead,
        )

        # load acc O from smem to rmem for wider vectorization
        tOrO = cute.make_fragment_like(tOsO, self.o_dtype)
        cute.autovec_copy(tOsO, tOrO)
        # copy acc O from rmem to gmem
        if const_expr(not self.pack_gqa):
            for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
                if (
                    t0OcO[0, rest_m, 0][0]
                    < seqlen_q - m_tile_idx * self.m_block_size - tOcO[0][0]
                ):
                    cute.copy(
                        gmem_tiled_copy_O,
                        tOrO[None, rest_m, None],
                        tOgO[None, rest_m, None],
                        pred=tOpO[None, rest_m, None]
                        if const_expr(self.check_hdim_v_oob)
                        else None,
                    )
        else:
            pack_gqa.store_O(
                mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_tile_idx, seqlen_q
            )

    @cute.jit
    def epilogue_s2g(
        self,
        mO: cute.Tensor,
        sO: cute.Tensor,
        gmem_tiled_copy_O: cute.TiledCopy,
        tma_atom_O: Optional[cute.CopyAtom],
        pipeline_o_epi: pipeline.PipelineAsync,
        block_info: BlockInfo,
        num_splits: int,
        SeqlenInfoCls: Callable,
        TileSchedulerCls: Callable,
        mma_tile_coord_v: Int32 = 0,
    ):
        epi_consumer_phase = Int32(0)
        tile_scheduler = TileSchedulerCls()
        work_tile = tile_scheduler.initial_work_tile_info()
        while work_tile.is_valid_tile:
            m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)
            n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)

            if const_expr(not self.is_split_kv) or n_block_min < n_block_max:
                if const_expr(self.is_split_kv):
                    mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
                else:
                    mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
                tiler_gO = ((self.mma_tiler_pv[0] * self.q_stage), self.head_dim_v_padded)
                gO = cute.local_tile(mO_cur, tiler_gO, (m_block, 0))  # (128 * 2, 128)
                gO = layout_utils.select(
                    cute.flat_divide(gO, (self.mma_tiler_pv[0],)), mode=[0, 2, 1]
                )  # (128, 128, 2)
                gO = cute.flat_divide(gO, (self.mma_tiler_pv[0] // self.cta_group_size,))[None, mma_tile_coord_v, None, None]

                if const_expr(self.use_tma_O):
                    store_O, _, _ = copy_utils.tma_get_copy_fn(
                        tma_atom_O, 0, cute.make_layout(1), sO, gO
                    )
                    for stage in cutlass.range(self.q_stage, unroll_full=True):
                        # wait from corr, issue tma store on smem
                        # 1. wait for O0 / O1 final
                        pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase)
                        # 2. copy O0 / O1 to gmem
                        store_O(src_idx=stage, dst_idx=stage)
                        cute.arch.cp_async_bulk_commit_group()
                    for stage in cutlass.range_constexpr(self.q_stage):
                        # Ensure O0 / O1 buffer is ready to be released
                        cute.arch.cp_async_bulk_wait_group(self.q_stage - 1 - stage, read=True)
                        pipeline_o_epi.consumer_release_w_index(stage)
                else:
                    tidx = cute.arch.thread_idx()[0] % (
                        cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)
                    )
                    for stage in cutlass.range_constexpr(self.q_stage):
                        # wait from corr, then copy O from smem to gmem via registers
                        # 1. wait for O0 / O1 final
                        pipeline_o_epi.consumer_wait_w_index_phase(stage, epi_consumer_phase)
                        # 2. copy O0 / O1 to gmem
                        m_tile_idx = (m_block * self.q_stage + stage) * self.cta_group_size + mma_tile_coord_v
                        self._store_O_to_gmem(
                            sO[None, None, stage],
                            gO[None, None, stage],
                            mO_cur,
                            gmem_tiled_copy_O,
                            tidx,
                            seqlen.seqlen_q,
                            m_tile_idx,
                        )
                        pipeline_o_epi.consumer_release_w_index(stage)

                epi_consumer_phase ^= 1

            # Advance to next tile
            tile_scheduler.advance_to_next_work()
            work_tile = tile_scheduler.get_current_work()

    def load_Q(
        self,
        load_Q_fn: Callable,
        pipeline_q: pipeline.PipelineAsync,
        block: Int32,
        stage: int,
        phase: Int32,
    ):
        pipeline_q.producer_acquire_w_index_phase(stage, phase)
        load_Q_fn(src_idx=block, dst_idx=stage, tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(stage))

    @cute.jit
    def load_KV(
        self,
        tma_atom: Optional[cute.CopyAtom],
        tXgX: Optional[cute.Tensor],
        tXsX: Optional[cute.Tensor],
        paged_kv_manager: Optional[PagedKVManager],
        sX: cute.Tensor,
        block: Int32,
        pipeline_kv: pipeline.PipelineAsync,
        producer_state: pipeline.PipelineState,
        K_or_V: Literal["K", "V"],
        page_idx: Optional[Int32] = None,
        extra_tx_count: Optional[Int32] = None,
    ):
        assert K_or_V in ("K", "V")
        stage, phase = producer_state.index, producer_state.phase
        extra_tx_count_kv = self.tma_copy_bytes[K_or_V] - self.tma_copy_bytes["K"]
        extra_tx_count = (
            extra_tx_count_kv + (extra_tx_count if extra_tx_count is not None else 0)
            if const_expr(self.use_tma_KV)
            else None
        )
        extra_kwargs = {"extra_tx_count": extra_tx_count} if const_expr(self.use_tma_KV) else {}
        pipeline_kv.producer_acquire(producer_state, **extra_kwargs)
        if const_expr(K_or_V == "K" and self.uneven_kv_smem):
            # Before this round, the smem location was occupied by V, which is smaller than
            # K. So we need to wait for the stage after that (stage 1) to be empty as well.
            if stage == 0:
                pipeline_kv.sync_object_empty.wait(1, phase)

        if const_expr(self.use_tma_KV):
            assert tXgX is not None and tXsX is not None and tma_atom is not None
            tXsX_cur = tXsX[None, stage]
            if const_expr(self.uneven_kv_smem):
                # Since this is the producer_state, the phase starts at 1, so we have to invert it
                tXsX_cur = self.offset_kv_smem(tXsX_cur, stage, phase ^ 1)
            # Currently we assume that page_size == n_block_size so we index into tXgX with block = 0
            tXgX_cur = tXgX[None, block] if const_expr(page_idx is None) else tXgX[None, 0, page_idx]
            cute.copy(tma_atom, tXgX_cur, tXsX_cur, tma_bar_ptr=pipeline_kv.producer_get_barrier(producer_state))
        else:
            assert paged_kv_manager is not None
            assert extra_tx_count is None
            sX_cur = sX[None, None, None, stage]
            if const_expr(self.uneven_kv_smem):
                sX_cur = self.offset_kv_smem(sX_cur, stage, phase ^ 1)
            paged_kv_manager.load_KV(block, sX_cur, K_or_V)
            cute.arch.cp_async_commit_group()
            pipeline_kv.sync_object_full.arrive_cp_async_mbarrier(stage)

    @cute.jit
    def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32):
        if const_expr(self.uneven_kv_smem):
            # smem layout is [smem_large, smem_small, smem_large], and the current stride is
            # (smem_large + smem_small) // 2. So for stage == 1, move right by offset if
            # phase == 0, or left by offset if phase == 1.
            offset = 0 if stage != 1 else self.uneven_kv_smem_offset * (1 - 2 * phase)
            # Hint that the offset is 128-bit aligned so that
            # ptr + offset preserves the alignment needed by cp.async.
            offset = cute.assume(offset, divby=128 // self.k_dtype.width)
            return cute.make_tensor(sX.iterator + offset, sX.layout)
        else:
            return sX

    # @cute.jit
    # def warp_scheduler_barrier_init(self):
    #     warp_group_idx = utils.canonical_warp_group_idx(sync=False)
    #     if warp_group_idx == 0:
    #         cute.arch.barrier_arrive(
    #             barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1), number_of_threads=2 * 128,
    #         )

    # def warp_scheduler_barrier_sync(self):
    #     cute.arch.barrier(
    #         barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + utils.canonical_warp_group_idx(sync=False),
    #         number_of_threads=2 * 128
    #     )

    # def warp_scheduler_barrier_arrive(self):
    #     cur_wg = utils.canonical_warp_group_idx(sync=False)
    #     next_wg = 1 - cur_wg
    #     cute.arch.barrier_arrive(
    #         barrier_id=int(NamedBarrierFwdSm100.WarpSchedulerWG1) + next_wg, number_of_threads=2 * 128,
    #     )

    @cute.jit
    def apply_score_mod(
        self,
        tSrS_t2r,
        thr_tmem_load,
        thr_mma_qk,
        batch_idx,
        head_idx,
        m_block,
        n_block,
        softmax,
        seqlen: SeqlenInfoQK,
        aux_tensors=None,
        fastdiv_mods=(None, None),
        head_divmod=None,
    ):
        """Apply score modification for SM100 (constant q_idx)."""
        # Prepare index tensor with extra partition
        cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size))
        cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS)
        tScS = thr_mma_qk.partition_C(cS)
        tScS = tScS[(None, None), 0, 0]
        tScS_t2r = thr_tmem_load.partition_D(tScS)

        # Shared q_idx for all scores
        q_idx_logical = tScS_t2r[0][0]

        # For Pack-GQA, compute the logical head index for this tile
        if cutlass.const_expr(self.pack_gqa):
            assert head_divmod is not None
            # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead)
            q_physical = q_idx_logical
            q_idx_logical, head_offset = divmod(q_physical, head_divmod)
            head_idx = head_idx * self.qhead_per_kvhead + head_offset

        if cutlass.const_expr(aux_tensors is not None):
            seqlen_q_divmod, _ = fastdiv_mods
            _, q_idx_logical = divmod(q_idx_logical,
                seqlen_q_divmod)

        apply_score_mod_inner(
            tSrS_t2r,
            tScS_t2r,
            self.score_mod,
            batch_idx,
            head_idx,
            softmax.softmax_scale,
            self.vec_size,
            self.qk_acc_dtype,
            aux_tensors,
            fastdiv_mods,
            seqlen_info=seqlen,
            constant_q_idx=q_idx_logical,
            qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,
        )

================================================
FILE: flash_attn/cute/flash_fwd_sm120.py
================================================
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# SM120 (Blackwell GeForce / DGX Spark) forward pass.
#
# SM120 uses the same SM80-era MMA instructions (mma.sync.aligned.m16n8k16) but has
# a smaller shared memory capacity (99 KB vs 163 KB on SM80). This module subclasses
# FlashAttentionForwardSm80 and overrides the SMEM capacity check accordingly.

import cutlass
import cutlass.utils as utils_basic

from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80


class FlashAttentionForwardSm120(FlashAttentionForwardSm80):
    # Keep arch = 80 to use CpAsync code paths (no TMA for output).
    # The compilation target is determined by the GPU at compile time, not this field.
    arch = 80

    @staticmethod
    def can_implement(
        dtype,
        head_dim,
        head_dim_v,
        tile_m,
        tile_n,
        num_stages,
        num_threads,
        is_causal,
        Q_in_regs=False,
    ) -> bool:
        """Check if the kernel can be implemented on SM120.

        Same logic as SM80 but uses SM120's shared memory capacity (99 KB).
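
        Worked example of the shared-memory budget computed below. The numbers are
        illustrative, not a tested configuration, and the example assumes that
        utils_basic.get_smem_capacity_in_bytes("sm_120") returns 99 * 1024 = 101376
        bytes, matching the 99 KB figure quoted in this module. With fp16/bf16
        (2 bytes per element), head_dim = head_dim_v = 128, tile_m = 128,
        tile_n = 64, num_stages = 2 and Q_in_regs = False:

            smem_usage_Q = 128 * 128 * 2     = 32768 bytes
            smem_usage_K = 64 * 128 * 2 * 2  = 32768 bytes
            smem_usage_V = 64 * 128 * 2 * 2  = 32768 bytes
            smem_usage   = (Q + V) + K       = 98304 bytes (96 KB), within budget

        With num_stages = 3, K and V each grow to 49152 bytes, and the total of
        131072 bytes exceeds the budget. With Q_in_regs = True, the Q/V term becomes
        max(32768, 49152) = 49152, so the total drops back to 98304 bytes.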
""" if dtype not in [cutlass.Float16, cutlass.BFloat16]: return False if head_dim % 8 != 0: return False if head_dim_v % 8 != 0: return False if tile_n % 16 != 0: return False if num_threads % 32 != 0: return False # Shared memory usage: Q tile + (K tile + V tile) smem_usage_Q = tile_m * head_dim * 2 smem_usage_K = tile_n * head_dim * num_stages * 2 smem_usage_V = tile_n * head_dim_v * num_stages * 2 smem_usage_QV = ( (smem_usage_Q + smem_usage_V) if not Q_in_regs else max(smem_usage_Q, smem_usage_V) ) smem_usage = smem_usage_QV + smem_usage_K # SM120 has 99 KB shared memory (vs 163 KB on SM80) smem_capacity = utils_basic.get_smem_capacity_in_bytes("sm_120") if smem_usage > smem_capacity: return False if (tile_m * 2) % num_threads != 0: return False return True ================================================ FILE: flash_attn/cute/flash_fwd_sm90.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # SM90 (Hopper) forward pass for flash attention, extracted from flash_fwd.py. 

from types import SimpleNamespace
from typing import Callable, Literal, Optional
from functools import partial

import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, const_expr
from cutlass.cute.nvgpu import cpasync, warpgroup
from cutlass.utils import LayoutEnum
import cutlass.utils.hopper_helpers as sm90_utils_basic
from cutlass import pipeline
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
from cutlass.base_dsl.arch import Arch

from quack import copy_utils
from quack import layout_utils
from quack import sm90_utils

from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import utils
from flash_attn.cute.mask import AttentionMask
from flash_attn.cute.softmax import Softmax, apply_score_mod_inner
from flash_attn.cute.seqlen_info import SeqlenInfoQK
from flash_attn.cute.block_info import BlockInfo
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute.block_sparse_utils import (
    produce_block_sparse_loads,
    consume_block_sparse_loads,
)
from flash_attn.cute import pipeline as pipeline_custom
from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom
from flash_attn.cute.paged_kv import PagedKVManager
from flash_attn.cute.named_barrier import NamedBarrierFwd
from quack.cute_dsl_utils import ParamsBase
from flash_attn.cute.tile_scheduler import (
    TileSchedulerArguments,
    SingleTileScheduler,
    SingleTileLPTScheduler,
    SingleTileVarlenScheduler,
)
from cutlass.cute import FastDivmodDivisor

from flash_attn.cute.flash_fwd import FlashAttentionForwardBase


class FlashAttentionForwardSm90(FlashAttentionForwardBase):
    def __init__(
        self,
        *args,
        intra_wg_overlap: bool = True,
        mma_pv_is_rs: bool = True,
        paged_kv_non_tma: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.intra_wg_overlap = intra_wg_overlap
        self.mma_pv_is_rs = mma_pv_is_rs
        self.buffer_align_bytes = 1024
        self.use_tma_KV = not paged_kv_non_tma
        assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), (
            "Paged KV does not support irregular head dim"
        )

        self.cluster_shape_mn = (1, 1)
        assert self.arch >= Arch.sm_90 and self.arch <= Arch.sm_90a, "Only SM 9.x is supported"

    def _get_smem_layout_atom(self):
        sQ_layout_atom = warpgroup.make_smem_layout_atom(
            sm90_utils_basic.get_smem_layout_atom(LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdim),
            self.dtype,
        )
        sK_layout_atom = sQ_layout_atom
        sV_layout_atom = warpgroup.make_smem_layout_atom(
            sm90_utils_basic.get_smem_layout_atom(
                LayoutEnum.ROW_MAJOR, self.dtype, self.tile_hdimv
            ),
            self.dtype,
        )
        sO_layout_atom = sV_layout_atom
        if not self.mma_pv_is_rs:
            sP_layout_atom = warpgroup.make_smem_layout_atom(
                sm90_utils_basic.get_smem_layout_atom(
                    LayoutEnum.ROW_MAJOR, self.dtype, self.tile_n
                ),
                self.dtype,
            )
        else:
            sP_layout_atom = None
        return sQ_layout_atom, sK_layout_atom, sV_layout_atom, sO_layout_atom, sP_layout_atom

    def _get_tiled_mma(self):
        tiled_mma_qk = sm90_utils_basic.make_trivial_tiled_mma(
            self.dtype,
            self.dtype,
            warpgroup.OperandMajorMode.K,
            warpgroup.OperandMajorMode.K,
            Float32,
            atom_layout_mnk=(self.tile_m // 64, 1, 1),
            tiler_mn=(64, self.tile_n),
        )
        tiled_mma_pv = sm90_utils_basic.make_trivial_tiled_mma(
            self.dtype,
            self.dtype,
            warpgroup.OperandMajorMode.K,
            warpgroup.OperandMajorMode.MN,
            Float32,
            atom_layout_mnk=(self.tile_m // 64, 1, 1),  # Might need (1, 2, 1) for hdim 512
            tiler_mn=(64, self.tile_hdimv),
            a_source=warpgroup.OperandSource.RMEM
            if self.mma_pv_is_rs
            else warpgroup.OperandSource.SMEM,
        )
        return tiled_mma_qk, tiled_mma_pv

    def _get_shared_storage_cls(self):
        sQ_struct, sK_struct, sV_struct = [
            cute.struct.Align[
                cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes
            ]
            for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
        ]
        cosize_sQV = max(cute.cosize(self.sQ_layout), cute.cosize(self.sV_layout))
        sQV_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sQV], 1024]
        cosize_sP = cute.cosize(self.sP_layout) if const_expr(self.sP_layout is not None) else 0
        sP_struct = cute.struct.Align[cute.struct.MemRange[self.dtype, cosize_sP], 1024]
        # 1 stage * 2 for Q pipeline (full + empty), self.num_stages * 2 for K, self.num_stages * 2 for V.
        mbar_ptr_Q_struct = cute.struct.MemRange[cutlass.Int64, 1 * 2]
        mbar_ptr_K_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]
        mbar_ptr_V_struct = cute.struct.MemRange[cutlass.Int64, self.num_stages * 2]

        @cute.struct
        class SharedStorageQKV:
            mbar_ptr_Q: mbar_ptr_Q_struct
            mbar_ptr_K: mbar_ptr_K_struct
            mbar_ptr_V: mbar_ptr_V_struct
            sV: sV_struct
            sQ: sQ_struct
            sK: sK_struct
            sP: sP_struct

        @cute.struct
        class SharedStorageSharedQV:
            mbar_ptr_Q: mbar_ptr_Q_struct
            mbar_ptr_K: mbar_ptr_K_struct
            mbar_ptr_V: mbar_ptr_V_struct
            sQ: sQV_struct
            sK: sK_struct
            sP: sP_struct

        return SharedStorageQKV if const_expr(not self.Q_in_regs) else SharedStorageSharedQV

    @cute.jit
    def __call__(
        self,
        mQ: cute.Tensor,  # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
        mK: cute.Tensor,  # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table
        mV: cute.Tensor,  # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table
        mO: cute.Tensor,  # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
        mLSE: Optional[cute.Tensor],
        softmax_scale: Float32,
        mCuSeqlensQ: Optional[cute.Tensor] = None,
        mCuSeqlensK: Optional[cute.Tensor] = None,
        mSeqUsedQ: Optional[cute.Tensor] = None,
        mSeqUsedK: Optional[cute.Tensor] = None,
        mPageTable: Optional[cute.Tensor] = None,  # (b_k, max_num_pages_per_seq)
        window_size_left: Int32 | int | None = None,
        window_size_right: Int32 | int | None = None,
        learnable_sink: Optional[cute.Tensor] = None,
        blocksparse_tensors: Optional[BlockSparseTensors] = None,
        aux_tensors: Optional[list] = None,
        # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
        stream: cuda.CUstream = None,
    ):
        """Configures and launches the flash attention kernel.

        mQ/mK/mV/mO have the same data type (fp16 or bf16) and the same layout:
        (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1)
        """
        self._check_type(
            *(
                t.element_type if t is not None else None
                for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)
            )
        )

        self.varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None

        mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
        QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
        mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)]
        KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
        mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)]
        LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
        mLSE = (
            layout_utils.select(mLSE, LSE_layout_transpose)
            if const_expr(mLSE is not None)
            else None
        )

        tiled_mma_qk, tiled_mma_pv = self._get_tiled_mma()
        self.num_mma_threads = tiled_mma_qk.size
        self.num_threads_per_warp_group = 128
        self.num_wg_mma = self.num_mma_threads // self.num_threads_per_warp_group
        assert self.num_wg_mma in [1, 2, 3]
        self.num_threads = self.num_threads_per_warp_group * (self.num_wg_mma + 1)
        self.num_producer_threads = 32
        self.num_Q_load_threads = self.num_mma_threads  # If not TMA_Q, MMA threads load Q
        self.num_epilogue_threads = self.num_mma_threads
        self.num_mma_regs, self.num_producer_regs = {1: (256, 56), 2: (240, 24), 3: (160, 32)}[
            self.num_wg_mma
        ]
        self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
        self.use_scheduler_barrier = (
            (self.num_wg_mma >= 2 and self.tile_hdim <= 128)
            if const_expr(self.intra_wg_overlap)
            else (self.num_wg_mma == 2)
        )
        self.use_tma_Q = self.arch >= Arch.sm_90 and not (
            self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0
        )
        self.use_tma_O = self.use_tma_Q
        self.rescale_O_before_gemm = self.tile_hdimv > 128 and self.intra_wg_overlap
        self._setup_attributes()
        # TODO: we prob don't need most of what's in _setup_attributes
        self.sQ_layout, self.sK_layout, self.sV_layout, self.sO_layout = [
            sm90_utils.make_smem_layout(mX.element_type, LayoutEnum.ROW_MAJOR, shape, stage)
            for mX, shape, stage in [
                (mQ, (self.tile_m, self.tile_hdim), None),
                (mK, (self.tile_n, self.tile_hdim), self.num_stages),
                (mV, (self.tile_n, self.tile_hdimv), self.num_stages),
                (mO, (self.tile_m, self.tile_hdimv), None),
            ]
        ]
        self.sP_layout = None
        if const_expr(not self.mma_pv_is_rs):
            self.sP_layout = sm90_utils.make_smem_layout(
                mV.element_type, LayoutEnum.ROW_MAJOR, (self.tile_m, self.tile_n)
            )

        SharedStorage = self._get_shared_storage_cls()

        mQ_og, mO_og = mQ, mO
        if const_expr(self.pack_gqa):
            nheads_kv = mK.shape[2]
            mQ = pack_gqa_layout(mQ, self.qhead_per_kvhead, nheads_kv, head_idx=2)
            mO = pack_gqa_layout(mO, self.qhead_per_kvhead, nheads_kv, head_idx=2)
            if const_expr(mLSE is not None):
                mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, nheads_kv, head_idx=1)

        # TMA
        gmem_tiled_copy_Q = cpasync.CopyBulkTensorTileG2SOp()
        gmem_tiled_copy_KV = cpasync.CopyBulkTensorTileG2SOp()  # Might multicast
        gmem_tiled_copy_O = cpasync.CopyBulkTensorTileS2GOp()
        self.tma_copy_bytes = {
            name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1]))
            for name, mX, layout in [
                ("Q", mQ, self.sQ_layout),
                ("K", mK, self.sK_layout),
                ("V", mV, self.sV_layout),
            ]
        }
        make_tiled_tma_atom_fn = (
            partial(make_packgqa_tiled_tma_atom, qhead_per_kvhead=self.qhead_per_kvhead, head_idx=2)
            if const_expr(self.pack_gqa)
            else cpasync.make_tiled_tma_atom
        )
        tma_atom_Q, tma_tensor_Q = None, None
        if const_expr(self.use_tma_Q):
            tma_atom_Q, tma_tensor_Q = make_tiled_tma_atom_fn(
                gmem_tiled_copy_Q,
                mQ_og if const_expr(self.pack_gqa) else mQ,
                self.sQ_layout,
                (self.tile_m, self.tile_hdim),  # No mcast
            )
        tma_atom_K, tma_tensor_K = None, None
        tma_atom_V, tma_tensor_V = None, None
        if const_expr(self.use_tma_KV):
            tma_atom_K, tma_tensor_K = cpasync.make_tiled_tma_atom(
                gmem_tiled_copy_KV,
                mK,
                cute.select(self.sK_layout, mode=[0, 1]),
                (self.tile_n, self.tile_hdim),
                1,  # No mcast for now
            )
            tma_atom_V, tma_tensor_V = cpasync.make_tiled_tma_atom(
                gmem_tiled_copy_KV,
                mV,
                cute.select(self.sV_layout, mode=[0, 1]),
                (self.tile_n, self.tile_hdimv),
                1,  # No mcast for now
            )
        tma_atom_O, tma_tensor_O = None, None
        if const_expr(self.use_tma_O):
            mO_tma = mO_og if const_expr(self.pack_gqa) else mO
            if const_expr(self.varlen_q):
                mO_tma = copy_utils.create_ragged_tensor_for_tma(
                    mO_tma, ragged_dim=0, ptr_shift=True
                )
            tma_atom_O, tma_tensor_O = make_tiled_tma_atom_fn(
                gmem_tiled_copy_O,
                mO_tma,
                self.sO_layout,
                (self.tile_m, self.tile_hdimv),  # No mcast
            )
        if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
            TileScheduler = SingleTileVarlenScheduler
        else:
            TileScheduler = (
                SingleTileScheduler
                if const_expr(not self.is_causal or self.is_local)
                else SingleTileLPTScheduler
            )
        tile_sched_args = TileSchedulerArguments(
            cute.ceil_div(cute.size(mQ.shape[0]), self.tile_m),
            cute.size(mQ.shape[2]),
            cute.size(mQ.shape[3])
            if const_expr(mCuSeqlensQ is None)
            else cute.size(mCuSeqlensQ.shape[0] - 1),
            1,  # num_splits
            cute.size(mK.shape[0])
            if const_expr(mPageTable is None)
            else mK.shape[0] * mPageTable.shape[1],
            mQ.shape[1],
            mV.shape[1],
            total_q=cute.size(mQ.shape[0])
            if const_expr(mCuSeqlensQ is not None)
            else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]),
            tile_shape_mn=(self.tile_m, self.tile_n),
            mCuSeqlensQ=mCuSeqlensQ,
            mSeqUsedQ=mSeqUsedQ,
            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
            element_size=self.dtype.width // 8,
            is_persistent=False,
            lpt=self.is_causal or self.is_local,
        )
        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)

        softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(
            softmax_scale, self.score_mod
        )
        window_size_left = Int32(window_size_left) if window_size_left is not None else None
        window_size_right = Int32(window_size_right) if window_size_right is not None else None
        fastdiv_mods = utils.compute_fastdiv_mods(
            mQ, mK, self.qhead_per_kvhead, self.pack_gqa, aux_tensors, mPageTable
        )

        self.kernel(
            tma_tensor_Q if const_expr(self.use_tma_Q) else mQ,
            tma_tensor_K if const_expr(self.use_tma_KV) else mK,
            tma_tensor_V if const_expr(self.use_tma_KV) else mV,
            tma_tensor_O if const_expr(self.use_tma_O) else mO,
            mLSE,
            mCuSeqlensQ,
            mCuSeqlensK,
            mSeqUsedQ,
            mSeqUsedK,
            mPageTable,
            tma_atom_Q,
            tma_atom_K,
            tma_atom_V,
            tma_atom_O,
            softmax_scale_log2,
            softmax_scale,
            window_size_left,
            window_size_right,
            learnable_sink,
            blocksparse_tensors,
            self.sQ_layout,
            self.sK_layout,
            self.sV_layout,
            self.sO_layout,
            self.sP_layout,
            self.gmem_tiled_copy_Q,
            self.gmem_tiled_copy_K,
            self.gmem_tiled_copy_V,
            self.gmem_tiled_copy_O,
            tiled_mma_qk,
            tiled_mma_pv,
            tile_sched_params,
            TileScheduler,
            SharedStorage,
            aux_tensors,
            fastdiv_mods,
        ).launch(
            grid=grid_dim,
            block=[self.num_threads, 1, 1],
            stream=stream,
            min_blocks_per_mp=1,
        )

    @cute.kernel
    def kernel(
        self,
        mQ: cute.Tensor,
        mK: cute.Tensor,
        mV: cute.Tensor,
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        mCuSeqlensQ: Optional[cute.Tensor],
        mCuSeqlensK: Optional[cute.Tensor],
        mSeqUsedQ: Optional[cute.Tensor],
        mSeqUsedK: Optional[cute.Tensor],
        mPageTable: Optional[cute.Tensor],
        tma_atom_Q: Optional[cute.CopyAtom],
        tma_atom_K: Optional[cute.CopyAtom],
        tma_atom_V: Optional[cute.CopyAtom],
        tma_atom_O: Optional[cute.CopyAtom],
        softmax_scale_log2: Float32,
        softmax_scale: Optional[Float32],
        window_size_left: Optional[Int32],
        window_size_right: Optional[Int32],
        learnable_sink: Optional[cute.Tensor],
        blocksparse_tensors: Optional[BlockSparseTensors],
        sQ_layout: cute.ComposedLayout,
        sK_layout: cute.ComposedLayout,
        sV_layout: cute.ComposedLayout,
        sO_layout: cute.ComposedLayout,
        sP_layout: cute.ComposedLayout | None,
        gmem_tiled_copy_Q: cute.TiledCopy,
        gmem_tiled_copy_K: cute.TiledCopy,
        gmem_tiled_copy_V:
cute.TiledCopy, gmem_tiled_copy_O: cute.TiledCopy, tiled_mma_qk: cute.TiledMma, tiled_mma_pv: cute.TiledMma, tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], SharedStorage: cutlass.Constexpr[Callable], aux_tensors=Optional[list[cute.Tensor]], fastdiv_mods=None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # Prefetch tma descriptor if warp_idx == 0: for tma_atom in (tma_atom_Q, tma_atom_K, tma_atom_V, tma_atom_O): if const_expr(tma_atom is not None): cpasync.prefetch_descriptor(tma_atom) smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) # Mbarrier / pipeline init mbar_ptr_Q = storage.mbar_ptr_Q.data_ptr() if const_expr(not self.use_tma_Q): if warp_idx == 1: cute.arch.mbarrier_init(mbar_ptr_Q, self.num_Q_load_threads) ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread) tma_warp = ThreadCooperativeGroup(1) load_threads = ThreadCooperativeGroup(self.num_threads_per_warp_group) mma_warps = ThreadCooperativeGroup(self.num_mma_threads // cute.arch.WARP_SIZE) mma_threads = ThreadCooperativeGroup(self.num_mma_threads) pipeline_q = None if const_expr(self.use_tma_Q): pipeline_q = pipeline_custom.PipelineTmaAsync.create( barrier_storage=mbar_ptr_Q, num_stages=1, producer_group=tma_warp, consumer_group=mma_warps, tx_count=self.tma_copy_bytes["Q"], defer_sync=True, ) # We rely on pipeline_k and pipeline_v to initialize the mbarrier fence and sync if const_expr(self.use_tma_KV): # PipelineTmaAsync: consumer_release has internal per-warp gating # (is_signalling_thread), so arrive count = num_consumer_warps pipeline_k = pipeline_custom.PipelineTmaAsync.create( barrier_storage=storage.mbar_ptr_K.data_ptr(), num_stages=self.num_stages, producer_group=tma_warp, consumer_group=mma_warps, tx_count=self.tma_copy_bytes["K"], defer_sync=True, ) pipeline_v = pipeline_custom.PipelineTmaAsync.create( barrier_storage=storage.mbar_ptr_V.data_ptr(), num_stages=self.num_stages, 
producer_group=tma_warp, consumer_group=mma_warps, tx_count=self.tma_copy_bytes["V"], defer_sync=True, ) else: # PipelineAsync: no thread gating in producer_commit/consumer_release, # so arrive counts must equal actual thread counts pipeline_k = pipeline.PipelineAsync.create( num_stages=self.num_stages, producer_group=load_threads, consumer_group=mma_threads, barrier_storage=storage.mbar_ptr_K.data_ptr(), defer_sync=True, ) pipeline_v = pipeline.PipelineAsync.create( num_stages=self.num_stages, producer_group=load_threads, consumer_group=mma_threads, barrier_storage=storage.mbar_ptr_V.data_ptr(), defer_sync=True, ) # Cluster arrive after barrier init pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) # /////////////////////////////////////////////////////////////////////////////// # Get shared memory buffer # /////////////////////////////////////////////////////////////////////////////// sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner) if const_expr(not self.Q_in_regs): sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner) else: sV = storage.sQ.get_tensor( sV_layout.outer, swizzle=sV_layout.inner, dtype=mV.element_type ) # Transpose view of V to tensor with layout (head_dim_v, tile_n) for tiled mma sVt = layout_utils.transpose_view(sV) sP = None if const_expr(sP_layout is not None): sP = storage.sP.get_tensor(sP_layout.outer, swizzle=sP_layout.inner) # reuse sQ's data iterator sO = storage.sQ.get_tensor(sO_layout.outer, swizzle=sO_layout.inner, dtype=self.dtype) block_info = BlockInfo( self.tile_m, self.tile_n, self.is_causal, self.is_local, False, # is_split_kv window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) SeqlenInfoCls = partial( SeqlenInfoQK.create, seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], seqlen_k_static=mK.shape[0] 
if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], mCuSeqlensQ=mCuSeqlensQ, mCuSeqlensK=mCuSeqlensK, mSeqUsedQ=mSeqUsedQ, mSeqUsedK=mSeqUsedK, # Don't need to pass in tile_mn because we won't access offset_padded ) AttentionMaskCls = partial( AttentionMask, self.tile_m, self.tile_n, window_size_left=window_size_left, window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) TileSchedulerCls = partial(TileScheduler.create, tile_sched_params) # Cluster wait before starting pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) if warp_idx < 4: # Producer cute.arch.setmaxregister_decrease(self.num_producer_regs) self.load( mQ, mK, mV, sQ, sK, sV, tma_atom_Q, tma_atom_K, tma_atom_V, pipeline_k, pipeline_v, pipeline_q, mPageTable, blocksparse_tensors, block_info, SeqlenInfoCls, TileSchedulerCls, ) else: # Consumer cute.arch.setmaxregister_increase(self.num_mma_regs) # /////////////////////////////////////////////////////////////////////////////// # Tile MMA compute thread partitions and allocate accumulators # /////////////////////////////////////////////////////////////////////////////// tidx, _, _ = cute.arch.thread_idx() tidx = tidx - 128 self.mma( tiled_mma_qk, tiled_mma_pv, mQ, mO, mLSE, sQ, sK, sVt, sP, sO, learnable_sink, pipeline_k, pipeline_v, pipeline_q, mbar_ptr_Q, gmem_tiled_copy_Q, gmem_tiled_copy_O, tma_atom_O, tidx, softmax_scale_log2, softmax_scale, block_info, SeqlenInfoCls, AttentionMaskCls, TileSchedulerCls, blocksparse_tensors, aux_tensors, fastdiv_mods, ) @cute.jit def load( self, mQ: cute.Tensor, mK: cute.Tensor, mV: cute.Tensor, sQ: cute.Tensor, sK: cute.Tensor, sV: cute.Tensor, tma_atom_Q: Optional[cute.CopyAtom], tma_atom_K: Optional[cute.CopyAtom], tma_atom_V: Optional[cute.CopyAtom], pipeline_k: pipeline.PipelineAsync, pipeline_v: pipeline.PipelineAsync, pipeline_q: Optional[pipeline.PipelineAsync], mPageTable: Optional[cute.Tensor], 
blocksparse_tensors: Optional[BlockSparseTensors], block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, ): warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 tidx, _, _ = cute.arch.thread_idx() # TMA: only warp 0 loads. cp_async: all warps load is_load_warp = warp_idx_in_wg == 0 or const_expr(not self.use_tma_KV) if is_load_warp: q_producer_phase = Int32(1) kv_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.num_stages ) tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: # if work_tile.is_valid_tile: m_block, head_idx, batch_idx, _ = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] head_idx_kv = ( head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx ) load_Q = None if const_expr(self.use_tma_Q): gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0)) load_Q, _, _ = copy_utils.tma_get_copy_fn( tma_atom_Q, 0, cute.make_layout(1), gQ, sQ, single_stage=True ) paged_kv_manager = None tma_load_K_fn = None tma_load_V_fn = None if const_expr(self.use_tma_KV): # === TMA path (non-paged and paged with page_size == n_block_size) === if const_expr(mPageTable is not None): # Paged TMA: keep page dimension indexable mK_cur = mK[None, None, head_idx_kv, None] mV_cur = mV[None, None, head_idx_kv, None] gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (0, 0, None)) gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (0, 0, None)) else: # Non-paged TMA mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[ None, None, head_idx_kv ] mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[ None, None, head_idx_kv ] gK = cute.local_tile(mK_cur, (self.tile_n, self.tile_hdim), (None, 0)) gV = cute.local_tile(mV_cur, (self.tile_n, self.tile_hdimv), (None, 0)) # TODO: mcast tma_load_K_fn, _, _ = 
copy_utils.tma_get_copy_fn( tma_atom_K, 0, cute.make_layout(1), gK, sK ) tma_load_K_fn = copy_utils.tma_producer_copy_fn(tma_load_K_fn, pipeline_k) tma_load_V_fn, _, _ = copy_utils.tma_get_copy_fn( tma_atom_V, 0, cute.make_layout(1), gV, sV ) tma_load_V_fn = copy_utils.tma_producer_copy_fn(tma_load_V_fn, pipeline_v) else: # === cp_async path (paged KV with page_size != n_block_size) === paged_kv_manager = PagedKVManager.create( mPageTable, mK, mV, FastDivmodDivisor(mK.shape[0]), batch_idx, head_idx_kv, tidx, seqlen.seqlen_k, 0, # leftpad_k self.tile_n, self.tile_hdim, self.tile_hdimv, self.num_threads_per_warp_group, mK.element_type, arch=self.arch.major * 10 + self.arch.minor, ) load_K = partial( self.load_KV, tma_load_K_fn, paged_kv_manager, sK, pipeline_kv=pipeline_k, K_or_V="K", ) load_V = partial( self.load_KV, tma_load_V_fn, paged_kv_manager, sV, pipeline_kv=pipeline_v, K_or_V="V", ) if const_expr(not self.use_block_sparsity): n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) # if cute.arch.thread_idx()[0] == 0: # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) # Clamp n_block to 0 when n_block_max == 0 (can happen with causal # + pack_gqa when seqlen_k < tile_n). TMA handles n_block=-1 # gracefully (fills zeros), but cp.async would crash on # out-of-bounds page table access. 
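                    # Load-order sketch (derived from the code below; it adds no logic).
                    # Blocks are visited from n_block_max - 1 down to n_block_min. Let
                    # n = n_block_max - 1 (clamped to 0 on the cp.async path). Q is issued
                    # here only when use_tma_Q is set:
                    #   no intra-WG overlap (or cp.async): K[n], Q, V[n], K[n-1], V[n-1], ..., K[n_min], V[n_min]
                    #   intra-WG overlap (TMA KV):         K[n], Q, K[n-1], V[n], K[n-2], V[n-1], ..., V[n_min]
                    # Under overlap, V for block n is issued only after K for block n - 1.
                    # This matches the consumer, which runs the QK GEMM of block n - 1 while
                    # the PV GEMM of block n is in flight.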
                    n_block = (
                        n_block_max - 1
                        if const_expr(self.use_tma_KV)
                        else cutlass.max(n_block_max - 1, 0)
                    )
                    page_idx = (
                        mPageTable[batch_idx, n_block]
                        if const_expr(mPageTable is not None and self.use_tma_KV)
                        else None
                    )
                    # First iteration: load Q on pipeline_q, K on pipeline_k
                    pipeline_k.producer_acquire(kv_producer_state)
                    if const_expr(not self.use_tma_KV):
                        paged_kv_manager.load_page_table(n_block)
                    load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)
                    if const_expr(self.use_tma_Q):
                        if warp_idx_in_wg == 0:
                            pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
                            load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0))
                        q_producer_phase ^= 1
                    if const_expr(not self.intra_wg_overlap or not self.use_tma_KV):
                        pipeline_v.producer_acquire(kv_producer_state)
                        load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)
                        kv_producer_state.advance()
                        for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
                            n_block = n_block_max - 1 - i - 1
                            page_idx = (
                                mPageTable[batch_idx, n_block]
                                if const_expr(mPageTable is not None and self.use_tma_KV)
                                else None
                            )
                            if const_expr(not self.use_tma_KV):
                                paged_kv_manager.load_page_table(n_block)
                            pipeline_k.producer_acquire(kv_producer_state)
                            load_K(
                                block=n_block, producer_state=kv_producer_state, page_idx=page_idx
                            )
                            pipeline_v.producer_acquire(kv_producer_state)
                            load_V(
                                block=n_block, producer_state=kv_producer_state, page_idx=page_idx
                            )
                            kv_producer_state.advance()
                    else:
                        for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1):
                            n_block_prev = n_block_max - i - 1
                            n_block = n_block_prev - 1
                            page_idx = (
                                mPageTable[batch_idx, n_block]
                                if const_expr(mPageTable is not None)
                                else None
                            )
                            page_idx_prev = (
                                mPageTable[batch_idx, n_block_prev]
                                if const_expr(mPageTable is not None)
                                else None
                            )
                            kv_producer_state_prev = kv_producer_state.clone()
                            kv_producer_state.advance()
                            pipeline_k.producer_acquire(kv_producer_state)
                            load_K(
                                block=n_block, producer_state=kv_producer_state, page_idx=page_idx
                            )
                            pipeline_v.producer_acquire(kv_producer_state_prev)
                            load_V(
                                block=n_block_prev,
                                producer_state=kv_producer_state_prev,
                                page_idx=page_idx_prev,
                            )
                        n_block = n_block_min
                        page_idx = (
                            mPageTable[batch_idx, n_block]
                            if const_expr(mPageTable is not None)
                            else None
                        )
                        pipeline_v.producer_acquire(kv_producer_state)
                        load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx)
                        kv_producer_state.advance()
                else:
                    # Block sparsity: use TMA closures directly (not paged)
                    # Load Q on pipeline_q, separate from K/V pipeline
                    if const_expr(self.use_tma_Q):
                        if warp_idx_in_wg == 0:
                            pipeline_q.producer_acquire_w_index_phase(0, q_producer_phase)
                            load_Q(tma_bar_ptr=pipeline_q.sync_object_full.get_barrier(0))
                        q_producer_phase ^= 1
                    kv_producer_state = produce_block_sparse_loads(
                        blocksparse_tensors,
                        batch_idx,
                        head_idx,
                        m_block,
                        kv_producer_state,
                        tma_load_K_fn,
                        tma_load_V_fn,
                        pipeline_k,
                        pipeline_v,
                        self.intra_wg_overlap,
                        self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
                        self.q_subtile_factor if self.q_subtile_factor is not None else 1,
                    )

                tile_scheduler.prefetch_next_work()
                tile_scheduler.advance_to_next_work()
                work_tile = tile_scheduler.get_current_work()
                # End of persistent scheduler loop

            # Producer tail is only useful for cluster to avoid early exit of blocks.
            # We only need producer_tail on V since that's the last that's loaded, we don't
            # need it for Q (no cluster) and K.
            pipeline_v.producer_tail(kv_producer_state)

    @cute.jit
    def load_KV(
        self,
        tma_load_fn: Optional[Callable],
        paged_kv_manager: Optional[PagedKVManager],
        sX: cute.Tensor,
        block: Int32,
        pipeline_kv: pipeline.PipelineAsync,
        producer_state: pipeline.PipelineState,
        K_or_V: Literal["K", "V"],
        page_idx: Optional[Int32] = None,
    ):
        if const_expr(self.use_tma_KV):
            src_idx = block if const_expr(page_idx is None) else page_idx
            tma_load_fn(src_idx=src_idx, producer_state=producer_state)
        else:
            paged_kv_manager.load_KV(block, sX[None, None, producer_state.index], K_or_V)
            cute.arch.cp_async_commit_group()
            cute.arch.cp_async_wait_group(0)
            pipeline_kv.producer_commit(producer_state)

    @cute.jit
    def mma(
        self,
        tiled_mma_qk: cute.TiledMma,
        tiled_mma_pv: cute.TiledMma,
        # softmax: Softmax,
        # acc_O: cute.Tensor,
        mQ: cute.Tensor,
        mO: cute.Tensor,
        mLSE: Optional[cute.Tensor],
        sQ: cute.Tensor,
        sK: cute.Tensor,
        sVt: cute.Tensor,
        sP: Optional[cute.Tensor],
        sO: cute.Tensor,
        learnable_sink: Optional[cute.Tensor],
        pipeline_k: pipeline.PipelineAsync,
        pipeline_v: pipeline.PipelineAsync,
        pipeline_q: Optional[pipeline.PipelineAsync],
        mbar_ptr_Q: cutlass.Pointer,
        gmem_tiled_copy_Q: cute.TiledCopy,
        gmem_tiled_copy_O: cute.TiledCopy,
        tma_atom_O: Optional[cute.CopyAtom],
        tidx: Int32,
        softmax_scale_log2: Float32,
        softmax_scale: Optional[Float32],
        block_info: BlockInfo,
        SeqlenInfoCls: Callable,
        AttentionMaskCls: Callable,
        TileSchedulerCls: Callable,
        blocksparse_tensors: Optional[BlockSparseTensors],
        aux_tensors: Optional[list],
        fastdiv_mods=None,
    ):
        warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
        warp_group_thread_layout = cute.make_layout(
            self.num_wg_mma, stride=self.num_threads_per_warp_group
        )
        thr_mma_qk = tiled_mma_qk.get_slice(tidx)
        wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx))
        wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx))
        _, tSrQ, tSrK = sm90_utils.partition_fragment_ABC(
            wg_mma_qk, (self.tile_m, self.tile_n, self.tile_hdim), sQ, sK
        )
        mma_qk_fn = partial(
            sm90_utils.gemm_zero_init, tiled_mma_qk, (self.tile_m, self.tile_n), tSrQ, tSrK
        )
        acc_O, tOrP, tOrVt = sm90_utils.partition_fragment_ABC(
            wg_mma_pv, (self.tile_m, self.tile_hdimv, self.tile_n), sP, sVt
        )
        mma_pv_fn = partial(sm90_utils.gemm_w_idx, tiled_mma_pv, acc_O, tOrP, tOrVt)

        # ///////////////////////////////////////////////////////////////////////////////
        # Smem copy atom tiling
        # ///////////////////////////////////////////////////////////////////////////////
        smem_copy_atom_P = utils.get_smem_store_atom(
            self.arch.major * 10 + self.arch.minor, self.dtype
        )
        smem_thr_copy_P = cute.make_tiled_copy_C(smem_copy_atom_P, tiled_mma_qk).get_slice(tidx)
        tPsP = smem_thr_copy_P.partition_D(sP) if const_expr(sP is not None) else None
        smem_copy_params = SimpleNamespace(smem_thr_copy_P=smem_thr_copy_P, tPsP=tPsP)

        self.mma_init()

        mma_q_consumer_phase = Int32(0)
        q_consumer_phase = Int32(0)
        kv_consumer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.num_stages
        )

        tile_scheduler = TileSchedulerCls()
        work_tile = tile_scheduler.initial_work_tile_info()
        softmax = Softmax.create(
            softmax_scale_log2,
            num_rows=acc_O.shape[0][0] * acc_O.shape[1],
            softmax_scale=softmax_scale,
        )
        # For RescaleOBeforeGemm: persistent scores_scale across iterations
        scores_scale = None
        if const_expr(self.rescale_O_before_gemm):
            scores_scale = cute.make_rmem_tensor_like(softmax.row_max, Float32)

        mma_one_n_block_all = partial(
            self.mma_one_n_block_intrawg_overlap
            if const_expr(self.intra_wg_overlap)
            else self.mma_one_n_block,
            mma_qk_fn=mma_qk_fn,
            pipeline_k=pipeline_k,
            pipeline_v=pipeline_v,
            acc_O=acc_O,
            tOrP=tOrP,
            smem_copy_params=smem_copy_params,
            check_inf=True,
            scores_scale=scores_scale,
        )
        process_first_half_block = partial(
            self.first_half_block_overlap,
            mma_qk_fn=mma_qk_fn,
            pipeline_k=pipeline_k,
            tOrP=tOrP,
            smem_copy_params=smem_copy_params,
            scores_scale=scores_scale,
            softmax=softmax,
            acc_O=acc_O,
        )
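        # Schedule sketch for the intra-warpgroup-overlap path (derived from the helpers used
        # below). Blocks are visited from n_block_max - 1 down to n_block_min, so the block
        # processed before n is n + 1:
        #   process_first_half_block: S = Q @ K[n_max-1].T, online softmax -> P[n_max-1]
        #   mma_one_n_block (overlap): issue S = Q @ K[n].T and O += P[n+1] @ V[n+1] back to
        #                              back, run the online softmax on S once the QK GEMM
        #                              completes (PV may still be running) -> P[n]
        #   process_last_half_block:  O += P[n_min] @ V[n_min]
        # Without overlap, mma_one_n_block runs S = Q @ K[n].T, the softmax, and
        # O += P[n] @ V[n] for the same block n.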
        process_last_half_block = partial(
            self.last_half_block_overlap,
            pipeline_v=pipeline_v,
            mma_pv_fn=mma_pv_fn,
            scores_scale=scores_scale,
            softmax=softmax,
            acc_O=acc_O,
        )
        while work_tile.is_valid_tile:
            # if work_tile.is_valid_tile:

            # shape: (atom_v_m * rest_m)
            m_block, head_idx, batch_idx, _ = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)

            # Recompute fastdiv_mods if necessary for varlen with aux_tensors
            recompute_fastdiv_mods_q = cutlass.const_expr(
                aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
            )
            recompute_fastdiv_mods_k = cutlass.const_expr(
                aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
            )
            if cutlass.const_expr(fastdiv_mods is not None):
                seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
                fastdiv_mods = (
                    seqlen_q_divmod
                    if not recompute_fastdiv_mods_q
                    else FastDivmodDivisor(seqlen.seqlen_q),
                    seqlen_k_divmod
                    if not recompute_fastdiv_mods_k
                    else FastDivmodDivisor(seqlen.seqlen_k),
                )

            mask = AttentionMaskCls(seqlen)
            mask_fn = partial(
                mask.apply_mask,
                batch_idx=batch_idx,
                head_idx=head_idx,
                m_block=m_block,
                thr_mma=thr_mma_qk,
                mask_causal=self.is_causal,
                mask_local=self.is_local,
                aux_tensors=aux_tensors,
                fastdiv_mods=fastdiv_mods,
            )
            score_mod_fn = None
            if const_expr(self.score_mod is not None):
                score_mod_fn = partial(
                    self.apply_score_mod,
                    thr_mma_qk,
                    batch_idx,
                    head_idx,
                    m_block,
                    softmax_scale=softmax_scale,
                    aux_tensors=aux_tensors,
                    fastdiv_mods=fastdiv_mods,
                )
            mma_one_n_block = partial(
                mma_one_n_block_all, seqlen=seqlen, softmax=softmax, score_mod_fn=score_mod_fn
            )
            # Load Q if not TMA_Q
            if const_expr(not self.use_tma_Q):
                pack_gqa = PackGQA(
                    self.tile_m, self.tile_hdim, self.check_hdim_oob, self.qhead_per_kvhead
                )
                mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
                # gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx)
                # gQ = cute.local_tile(mQ_cur, (self.tile_m, self.tile_hdim), (m_block, 0))
                # self.load_Q(gmem_thr_copy_Q, gQ, sQ, m_block, seqlen=seqlen.seqlen_q,
                #             headdim=mQ.shape[1])
                pack_gqa.load_Q(mQ_cur, sQ, gmem_tiled_copy_Q, tidx, m_block, seqlen.seqlen_q)
                cute.arch.cp_async_mbarrier_arrive_noinc(mbar_ptr_Q)

            n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
            if const_expr(self.use_tma_Q):
                pipeline_q.consumer_wait_w_index_phase(0, mma_q_consumer_phase)
            else:
                cute.arch.mbarrier_wait(mbar_ptr_Q, phase=q_consumer_phase)
                q_consumer_phase ^= 1
            # For performance reasons, we separate out two kinds of iterations:
            # those that need masking on S, and those that don't.
            # We need masking on S for the very last block when the length of K and V is not a
            # multiple of tile_n. We also need masking on S for the last several blocks if causal.
            # softmax.reset()  # Don't need reset as we explicitly call softmax with is_first=True
            O_should_accumulate = False

            # ==========================================
            # MAINLOOP
            # ==========================================
            if const_expr(not self.use_block_sparsity):
                # ==========================================
                # No block-sparsity (original path)
                # ==========================================
                # First iteration with seqlen masking
                if const_expr(self.intra_wg_overlap):
                    kv_consumer_state = process_first_half_block(
                        n_block=n_block_max - 1,
                        seqlen=seqlen,
                        kv_consumer_state=kv_consumer_state,
                        mask_fn=partial(mask_fn, mask_mod=self.mask_mod),
                        score_mod_fn=score_mod_fn,
                        is_first_block=True,
                    )
                else:
                    self.warp_scheduler_barrier_sync()
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=n_block_max - 1,
                        seqlen=seqlen,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=True),
                        is_first_n_block=True,
                        mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
                    )
                    O_should_accumulate = True
                # if cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block_max = {}, n_block_min = {}", m_block, n_block_max, n_block_min)
                n_block_max -= 1

                # Next couple of iterations with causal masking
                if const_expr(self.is_causal or self.is_local):
                    n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask(
                        seqlen, m_block, n_block_min
                    )
                    # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_causal_local_mask = {}", n_block_min_causal_local_mask)
                    for n_tile in cutlass.range(
                        n_block_max - n_block_min_causal_local_mask, unroll=1
                    ):
                        kv_consumer_state = mma_one_n_block(
                            kv_consumer_state,
                            n_block=n_block_max - 1 - n_tile,
                            seqlen=seqlen,
                            mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                            mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
                        )
                        O_should_accumulate = True
                    n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask)

                # The remaining iterations have no masking
                n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask(
                    seqlen, m_block, n_block_min
                )
                # if cute.arch.thread_idx()[0] == 128: cute.printf("n_block_min_before_local_mask = {}, n_block_min = {}", n_block_min_before_local_mask, n_block_min)
                for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1):
                    kv_consumer_state = mma_one_n_block(
                        kv_consumer_state,
                        n_block=n_block_max - 1 - n_tile,
                        seqlen=seqlen,
                        mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                        mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
                    )
                    O_should_accumulate = True

                # Separate iterations with local masking on the left
                if const_expr(self.is_local and block_info.window_size_left is not None):
                    n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask)
                    for n_tile in cutlass.range(n_block_max - n_block_min, unroll=1):
                        kv_consumer_state = mma_one_n_block(
                            kv_consumer_state,
                            n_block=n_block_max - 1 - n_tile,
                            seqlen=seqlen,
                            mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
                            mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
                        )
                        O_should_accumulate = True

                # Last "half" iteration
                if const_expr(self.intra_wg_overlap):
                    kv_consumer_state = process_last_half_block(
                        kv_consumer_state=kv_consumer_state,
                        zero_init=not O_should_accumulate,
                    )
O_should_accumulate = True else: self.warp_scheduler_barrier_arrive() else: # ========================================== # Block sparsity # ========================================== kv_consumer_state, O_should_accumulate, processed_any = consume_block_sparse_loads( blocksparse_tensors, batch_idx, head_idx, m_block, seqlen, kv_consumer_state, mma_pv_fn, mma_one_n_block, process_first_half_block, process_last_half_block, mask_fn, score_mod_fn, O_should_accumulate, self.mask_mod, fastdiv_mods, self.intra_wg_overlap, self.warp_scheduler_barrier_sync, self.warp_scheduler_barrier_arrive, self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, self.q_subtile_factor if self.q_subtile_factor is not None else 1, ) # Handle empty case (when no blocks to process) if not processed_any: softmax.reset() acc_O.fill(0.0) sink_val = None if const_expr(learnable_sink is not None): if const_expr(not self.pack_gqa): sink_val = Float32(learnable_sink[head_idx]) else: # Each thread might have a different sink value due to different q_head sink_val = cute.make_rmem_tensor_like(softmax.row_max, Float32) cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS)) for r in cutlass.range(cute.size(sink_val), unroll_full=True): row = m_block * self.tile_m + tScS_mn[r][0] q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead sink_val[r] = Float32(learnable_sink[q_head_idx]) # normalize acc_O by row_sum and calculate the lse row_scale = softmax.finalize(sink_val=sink_val) softmax.rescale_O(acc_O, row_scale) # /////////////////////////////////////////////////////////////////////////////// # Epilogue # /////////////////////////////////////////////////////////////////////////////// self.epilogue( acc_O, softmax.row_sum, mO, mLSE, sO, seqlen, gmem_tiled_copy_O, tma_atom_O, tiled_mma_pv, tidx, m_block, head_idx, batch_idx, ) # Release Q pipeline so the producer can load the next tile's Q if 
const_expr(self.use_tma_Q): pipeline_q.consumer_release_w_index(0) mma_q_consumer_phase ^= 1 tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() @cute.jit def first_half_block_overlap( self, n_block: Int32, mma_qk_fn: Callable, kv_consumer_state, pipeline_k, tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, seqlen: SeqlenInfoQK, scores_scale: Optional[cute.Tensor] = None, acc_O: Optional[cute.Tensor] = None, mask_fn: Callable = None, score_mod_fn: Optional[Callable] = None, is_first_block: bool = False, ): """Processes the first half block when using intra-warpgroup-overlap""" pipeline_k.consumer_wait(kv_consumer_state, pipeline_k.consumer_try_wait(kv_consumer_state)) acc_S = mma_qk_fn(B_idx=kv_consumer_state.index, wg_wait=0) pipeline_k.consumer_release(kv_consumer_state) # Apply score modification if present if const_expr(score_mod_fn is not None): score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) # Apply mask; mask_seqlen always True for first block # Caveat: if full block further right than mask block, seqlen masking is redundant; # however, masking is being applied anyway, so essentially no perf hit mask_fn(acc_S, n_block=n_block, mask_seqlen=True) row_scale = softmax.online_softmax(acc_S, is_first=is_first_block) tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) tOrP_cur = ( tOrP if const_expr(self.mma_pv_is_rs) else cute.make_rmem_tensor_like(tOrP_acc, self.dtype) ) tOrP_cur.store(tOrP_acc.load().to(self.dtype)) if const_expr(not self.mma_pv_is_rs): tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) # Fence and barrier to make smem store visible to WGMMA cute.arch.fence_view_async_shared() cute.arch.sync_warp() # For RescaleOBeforeGemm: initialize acc_O if const_expr(self.rescale_O_before_gemm): acc_O.fill(0.0) scores_scale.store(row_scale.load()) return kv_consumer_state @cute.jit def last_half_block_overlap( self, 
kv_consumer_state, pipeline_v, mma_pv_fn: Callable, zero_init: bool, scores_scale: Optional[cute.Tensor] = None, softmax: Optional[Softmax] = None, acc_O: Optional[cute.Tensor] = None, ): """Processes the final PV GEMM when using intra-warpgroup-overlap""" # For RescaleOBeforeGemm: rescale O before the final PV GEMM if const_expr(self.rescale_O_before_gemm): softmax.rescale_O(acc_O, scores_scale) pipeline_v.consumer_wait(kv_consumer_state, pipeline_v.consumer_try_wait(kv_consumer_state)) mma_pv_fn(B_idx=kv_consumer_state.index, zero_init=zero_init, wg_wait=0) pipeline_v.consumer_release(kv_consumer_state) kv_consumer_state.advance() return kv_consumer_state @cute.jit def mma_one_n_block( self, smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple, n_block: Int32, mma_qk_fn: Callable, mma_pv_fn: Callable, pipeline_k: pipeline.PipelineAsync, pipeline_v: pipeline.PipelineAsync, acc_O: cute.Tensor, tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, seqlen: SeqlenInfoQK, scores_scale: Optional[cute.Tensor] = None, # not used score_mod_fn: Optional[Callable] = None, mask_fn: Optional[Callable] = None, is_first_n_block: cutlass.Constexpr = False, check_inf: cutlass.Constexpr = True, ): pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) # S = Q @ K.T acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(0) pipeline_k.consumer_release(smem_pipe_read) # handle score mods and masking if const_expr(score_mod_fn is not None): score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) if const_expr(mask_fn is not None): mask_fn(acc_S=acc_S, n_block=n_block) row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf) # if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) tOrP_cur = ( tOrP if const_expr(self.mma_pv_is_rs) else 
cute.make_rmem_tensor_like(tOrP_acc, self.dtype) ) # tOrP.store(tOrP_acc.load().to(self.dtype)) # the "to(self.dtype)" conversion fails to vectorize for block sizes other # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of # 2 elements. So we just call ptx directly. utils.cvt_f16(tOrP_acc, tOrP_cur) if const_expr(not self.mma_pv_is_rs): tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) softmax.rescale_O(acc_O, row_scale) if const_expr(not self.mma_pv_is_rs): # Fence and barrier to make sure smem store is visible to WGMMA cute.arch.fence_view_async_shared() cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() # O += P @ V mma_pv_fn(B_idx=smem_pipe_read.index, wg_wait=0) pipeline_v.consumer_release(smem_pipe_read) smem_pipe_read.advance() return smem_pipe_read @cute.jit def mma_one_n_block_intrawg_overlap( self, smem_pipe_read: pipeline.PipelineState | pipeline_custom.PipelineStateSimple, n_block: Int32, mma_qk_fn: Callable, mma_pv_fn: Callable, pipeline_k: pipeline.PipelineAsync, pipeline_v: pipeline.PipelineAsync, acc_O: cute.Tensor, tOrP: cute.Tensor, smem_copy_params: SimpleNamespace, softmax: Softmax, seqlen: SeqlenInfoQK, scores_scale: Optional[cute.Tensor] = None, score_mod_fn: Optional[Callable] = None, mask_fn: Optional[Callable] = None, check_inf: cutlass.Constexpr = True, ): smem_pipe_read_v = smem_pipe_read.clone() smem_pipe_read.advance() pipeline_k.consumer_wait(smem_pipe_read, pipeline_k.consumer_try_wait(smem_pipe_read)) self.warp_scheduler_barrier_sync() # S = Q @ K.T acc_S = mma_qk_fn(B_idx=smem_pipe_read.index, wg_wait=-1) # RescaleOBeforeGemm: rescale O while QK GEMM is in flight, before PV GEMM if const_expr(self.rescale_O_before_gemm): softmax.rescale_O(acc_O, scores_scale) 
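Aside (illustrative, not part of the flash-attention source): the per-block step in `mma_one_n_block` follows the standard online-softmax recurrence. That step is the QK GEMM, optional `score_mod_fn`/`mask_fn`, `softmax.online_softmax` returning `row_scale`, `softmax.rescale_O` on `acc_O`, and then the PV GEMM. The pure-Python sketch below shows the recurrence for a single query row under simplifying assumptions: no `softmax_scale`, no masking, no fully masked rows, and normalization by the row sum at the end. The function names are hypothetical.

```python
import math


def attention_reference(q, keys, values):
    """Plain softmax attention for one query row (softmax over q . k)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / denom for j in range(dim_v)]


def attention_online(q, keys, values, block_n):
    """Visit K/V in blocks of block_n, mirroring the per-n_block loop:
    S = q @ K_blk^T, update the running max, rescale O, then O += P @ V_blk."""
    dim_v = len(values[0])
    row_max = -math.inf
    row_sum = 0.0
    acc_o = [0.0] * dim_v
    for start in range(0, len(keys), block_n):
        k_blk = keys[start:start + block_n]
        v_blk = values[start:start + block_n]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(row_max, max(scores))
        # row_scale corrects everything accumulated under the old max;
        # on the first block exp(-inf) == 0.0, so the zero state stays zero.
        row_scale = math.exp(row_max - new_max)
        p = [math.exp(s - new_max) for s in scores]
        row_sum = row_sum * row_scale + sum(p)
        acc_o = [o * row_scale for o in acc_o]  # analogue of rescale_O
        for pj, v in zip(p, v_blk):             # analogue of O += P @ V
            for j in range(dim_v):
                acc_o[j] += pj * v[j]
        row_max = new_max
    return [o / row_sum for o in acc_o]
```

`row_scale = exp(old_max - new_max)` multiplies `acc_O` before the next `P @ V` is accumulated, so the result does not depend on `block_n`, up to rounding.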
pipeline_v.consumer_wait(smem_pipe_read_v, pipeline_v.consumer_try_wait(smem_pipe_read_v)) # O += P @ V mma_pv_fn(B_idx=smem_pipe_read_v.index, wg_wait=-1) self.warp_scheduler_barrier_arrive() warpgroup.wait_group(1) pipeline_k.consumer_release(smem_pipe_read) # handle score mods and masking if const_expr(score_mod_fn is not None): score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen) if const_expr(mask_fn is not None): mask_fn(acc_S=acc_S, n_block=n_block) # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S)) row_scale = softmax.online_softmax(acc_S, check_inf=check_inf) warpgroup.wait_group(0) pipeline_v.consumer_release(smem_pipe_read_v) tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S) tOrP_cur = ( tOrP if const_expr(self.mma_pv_is_rs) else cute.make_rmem_tensor_like(tOrP_acc, self.dtype) ) # tOrP_cur.store(tOrP_acc.load().to(self.dtype)) # the "to(self.dtype)" conversion fails to vectorize for block sizes other # than 128 x 128, i.e. it calls convert on 1 fp32 element at a time instead of # 2 elements. So we just call ptx directly. 
utils.cvt_f16(tOrP_acc, tOrP_cur) if const_expr(not self.mma_pv_is_rs): tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur) cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP) if const_expr(not self.rescale_O_before_gemm): softmax.rescale_O(acc_O, row_scale) if const_expr(self.rescale_O_before_gemm): scores_scale.store(row_scale.load()) if const_expr(not self.mma_pv_is_rs): # Fence and barrier to make sure smem store is visible to WGMMA cute.arch.fence_view_async_shared() cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV return smem_pipe_read @cute.jit def mma_init(self): warp_group_idx = utils.canonical_warp_group_idx(sync=False) if const_expr(self.use_scheduler_barrier): if warp_group_idx == 1: cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1), number_of_threads=2 * self.num_threads_per_warp_group, ) @cute.jit def apply_score_mod( self, thr_mma_qk, batch_idx, head_idx, m_block, acc_S, n_block, softmax_scale, seqlen, aux_tensors: Optional[list] = None, fastdiv_mods=None, ): # Prepare index tensor cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) tScS = thr_mma_qk.partition_C(cS) apply_score_mod_inner( acc_S, tScS, self.score_mod, batch_idx, head_idx, softmax_scale, self.vec_size, self.qk_acc_dtype, aux_tensors, fastdiv_mods, seqlen_info=seqlen, constant_q_idx=None, qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) def warp_scheduler_barrier_sync(self): if const_expr(self.use_scheduler_barrier): cute.arch.barrier( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) - 1 + utils.canonical_warp_group_idx(sync=False), number_of_threads=2 * self.num_threads_per_warp_group, ) def warp_scheduler_barrier_arrive(self): if const_expr(self.use_scheduler_barrier): assert self.num_wg_mma in [2, 3] cur_wg = utils.canonical_warp_group_idx(sync=False) - 1 if 
const_expr(self.num_wg_mma == 2): next_wg = 1 - cur_wg else: t = cur_wg + 1 next_wg = t % self.num_wg_mma cute.arch.barrier_arrive( barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg, number_of_threads=2 * self.num_threads_per_warp_group, ) ================================================ FILE: flash_attn/cute/interface.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need to install nvidia-cutlass-dsl==4.2.0. # Supported features: # - BF16 & FP16 dtype # - noncausal & causal attention # - MHA, GQA, MQA # - hdim 64, 96, 128. # - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape) # - varlen # - sliding window # - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) # Features not supported yet: # - split (i.e. FlashDecoding) # - tuned block sizes # - paged KV # - append KV to existing KV cache # - FP8 # - bwd pass optimized for Hopper/Blackwell import os import math from dataclasses import dataclass from functools import lru_cache from typing import Optional, Tuple, Callable import torch import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute from cutlass import Int32, Float32 from quack.compile_utils import make_fake_tensor as fake_tensor from flash_attn.cute.cache_utils import get_jit_cache from flash_attn.cute.testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: from flash_attn.cute import cute_dsl_ptxas # noqa: F401 # Patch to dump ptx and then use system ptxas to compile to cubin cute_dsl_ptxas.patch() from flash_attn.cute import utils from flash_attn.cute import fa_logging from flash_attn.cute.cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 from
flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, to_cute_block_sparse_tensors, normalize_block_sparse_config, normalize_block_sparse_config_bwd, ) def _parse_arch_str(arch_str): """Parse arch string (e.g. 'sm_80', 'sm_90a', '80', '100') to int (e.g. 80, 90, 100).""" import re match = re.match(r"^(?:sm_?|SM_?)?(\d+)(\d)([af]?)$", arch_str) if not match: raise ValueError(f"Invalid arch format: {arch_str}") major, minor, _ = match.groups() return int(major) * 10 + int(minor) @lru_cache(maxsize=None) def _get_device_arch(): """Cached device arch check. Override with FLASH_ATTENTION_ARCH (e.g. 'sm_80' or '80') to select which kernel path to use (SM80/SM90/SM100/SM120) independently of the compilation target (CUTE_DSL_ARCH). 
For CPU-only compilation (no GPU), set both: FLASH_ATTENTION_ARCH=sm_80 (kernel selection) CUTE_DSL_ARCH=sm_80 (compilation target) """ arch_override = os.environ.get("FLASH_ATTENTION_ARCH", None) if arch_override is not None: return _parse_arch_str(arch_override) major, minor = torch.cuda.get_device_capability() return major * 10 + int(minor) def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None: """Validate head dimension constraints based on compute capability.""" is_deepseek_shape = head_dim == 192 and head_dim_v == 128 is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128 is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256 if compute_capability == 9: assert is_sm90_range and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. " f"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}." ) elif compute_capability in [10, 11]: assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, ( f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. " f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek." ) @dataclass(frozen=True) class FwdConfig: m_block_size: int n_block_size: int mma_pv_is_rs: bool intra_wg_overlap: bool def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, use_block_sparsity): """Return FwdConfig for SM90 forward. Tile sizes and flags based on tile_size_fwd_sm90 in hopper/tile_size.h, adjusted for the Python kernel's different register/smem tradeoffs (benchmarked on H100 SXM). """ if head_dim <= 64: # C++: 192×192 non-causal, 192×128 causal/local. # Python: 192×128 RS+OL is consistently best across seqlens. return FwdConfig(192, 128, True, True) elif head_dim <= 96: # C++: 192×144 noRS+OL for all cases. 
# Python: RS is catastrophic with 192× tiles (~300 vs ~600 TFLOPS). # noRS+OL is always required. Causal: 192×128 slightly better short seqlen. if is_causal or is_local: return FwdConfig(192, 128, False, True) else: return FwdConfig(192, 144, False, True) elif head_dim <= 128: return FwdConfig(128, 128, True, True) elif head_dim <= 192: tile_n = 96 if is_local else (128 if head_dim_v <= 128 else 112) return FwdConfig(128, tile_n, True, True) else: # hdim 256 tile_n = 64 if is_local else 80 return FwdConfig(128, tile_n, True, True) @dataclass(frozen=True) class BwdConfig: m_block_size: int n_block_size: int num_stages_Q: int num_stages_dO: int num_stages_PdS: int SdP_swapAB: bool dKV_swapAB: bool dQ_swapAB: bool AtomLayoutMSdP: int AtomLayoutNdKV: int AtomLayoutMdQ: int num_wg: int = 2 # MMA warp groups (total threads = (num_wg + 1) * 128) dQ_single_wg: bool = False def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local): """Return BwdConfig for SM90. Configs based on C++ FA3 hopper/flash_bwd_launch_template.h, benchmarked on H100 SXM. 
""" if head_dim <= 64: # C++ FA3: 128, 128, 64, ..., 2, 2, true, false, false, 2, 1, 2, 2 return BwdConfig( m_block_size=128, n_block_size=128, num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False, AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=2, ) elif head_dim <= 96: # C++ FA3: 64, 128, 96, dQ_swapAB=False return BwdConfig( m_block_size=64, n_block_size=128, num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=False, AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, dQ_single_wg=True, ) elif head_dim <= 128: # C++ FA3: causal/local: 64, 128; non-causal: 80, 128 with dQ_swapAB is_causal_or_local = causal or local m_block_size = 64 if is_causal_or_local else 80 return BwdConfig( m_block_size=m_block_size, n_block_size=128, num_stages_Q=2, num_stages_dO=2, num_stages_PdS=2, SdP_swapAB=True, dKV_swapAB=False, dQ_swapAB=m_block_size % 64 != 0, AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, ) elif head_dim <= 192: hdimv128 = head_dim_v <= 128 if hdimv128: return BwdConfig( m_block_size=64, n_block_size=96, num_stages_Q=2, num_stages_dO=2, num_stages_PdS=1, SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False, AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, num_wg=2, ) else: return BwdConfig( m_block_size=64, n_block_size=96, num_stages_Q=2, num_stages_dO=1, num_stages_PdS=1, SdP_swapAB=False, dKV_swapAB=True, dQ_swapAB=False, AtomLayoutMSdP=1, AtomLayoutNdKV=2, AtomLayoutMdQ=1, num_wg=2, ) else: # hdim 256 return BwdConfig( m_block_size=64, n_block_size=64, num_stages_Q=1, num_stages_dO=1, num_stages_PdS=1, SdP_swapAB=False, dKV_swapAB=False, dQ_swapAB=False, AtomLayoutMSdP=1, AtomLayoutNdKV=1, AtomLayoutMdQ=1, ) def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): assert t.shape == expected_shape, f"{name} shape {t.shape} != expected 
{expected_shape}" assert t.dtype == expected_dtype, f"{name} dtype {t.dtype} != expected {expected_dtype}" assert t.device == expected_device, f"{name} device {t.device} != expected {expected_device}" if not is_fake_mode(): assert t.is_cuda, f"{name} must be on CUDA" torch2cute_dtype_map = { torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, } def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): # If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512. if num_n_blocks <= 4: return 1 # NOTE: We should revisit this heuristic after persistence is supported for split KV. # Sometimes, it's ideal to over-schedule splits for better efficiency. return min(num_SMs // total_mblocks, max_splits, num_n_blocks) def _resolve_causal_local_window(causal, window_size_left, window_size_right, mask_mod=None): """Resolve causal/local/window settings into canonical form. Returns (causal, local, window_size_left, window_size_right). 
""" if mask_mod is not None: return False, False, window_size_left, window_size_right if causal: window_size_right = 0 if window_size_left is not None and window_size_right is not None and window_size_left + window_size_right < 0: window_size_left = None window_size_right = None if window_size_left is not None or window_size_right is not None: if window_size_left is None and window_size_right == 0: causal, local = True, False window_size_right = None else: causal, local = False, True else: local = False return causal, local, window_size_left, window_size_right def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, softcap: Optional[float] = None, window_size_left: Optional[int] = None, window_size_right: Optional[int] = None, learnable_sink: Optional[torch.Tensor] = None, tile_mn: Optional[Tuple[int, int]] = None, mma_pv_is_rs: Optional[bool] = None, intra_wg_overlap: Optional[bool] = None, num_threads: int = 384, num_splits: int = 1, pack_gqa: Optional[bool] = None, _arch: Optional[int] = None, score_mod: Optional[Callable] = None, mask_mod: Optional[Callable] = None, block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, aux_tensors: Optional[list[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. Args: ... score_mod: A callable that takes the attention scores and applies a modification. mask_mod: A callable that takes token position information and selectively masks block_sparse_tensors: A tuple of tensors used for block sparsity. 
attention scores. return_lse: Whether to return the log-sum-exp (LSE) of the attention scores. If set to True, the LSE is always computed, even when no input requires grad. The returned LSE supports taking gradient. out: Optional pre-allocated output tensor. If None, will be allocated internally. lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. """ q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: batch_size, seqlen_q = q.shape[:2] total_q = batch_size * seqlen_q else: batch_size = cu_seqlens_q.shape[0] - 1 seqlen_q = None total_q = q.shape[0] if page_table is not None: assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k" assert page_table.dtype == torch.int32, "page_table must be int32" assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension" max_num_pages_per_seq = page_table.shape[1] assert page_table.shape == (batch_size, max_num_pages_per_seq) num_pages, page_size = k.shape[:2] seqlen_k = num_pages * page_size else: num_pages, page_size = None, None seqlen_k = k.shape[-3] num_head_kv = k.shape[-2] head_dim_v = v.shape[-1] if cu_seqlens_k is None: if page_table is None: assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) else: assert k.shape == (num_pages, page_size, num_head_kv, head_dim) assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v) else: assert k.shape == (seqlen_k, num_head_kv, head_dim) assert v.shape == (seqlen_k, num_head_kv, head_dim_v) assert cu_seqlens_k.shape == (batch_size + 1,), ( "cu_seqlens_k must have shape (batch_size + 1,)" ) if cu_seqlens_q is not None: assert cu_seqlens_q.shape == (batch_size + 1,), ( "cu_seqlens_q must have shape (batch_size + 1,)" ) assert seqused_q is None or seqused_q.shape == (batch_size,), ( "seqused_q must
have shape (batch_size,)" ) assert seqused_k is None or seqused_k.shape == (batch_size,), ( "seqused_k must have shape (batch_size,)" ) assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: if t is not None: assert t.dtype == torch.int32, ( "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" ) assert t.stride(0) == 1, ( "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" ) if learnable_sink is not None: assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" if not is_fake_mode(): assert all( t is None or t.is_cuda for t in ( q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink, ) ), "inputs must be on CUDA device" arch = _get_device_arch() if _arch is None else _arch assert arch // 10 in [8, 9, 10, 11, 12], "Unsupported compute capability. 
Supported: 8.x, 9.x, 10.x, 11.x, 12.x" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" alignment = 16 // q.element_size() if arch // 10 not in [8, 12]: _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment) if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) if softcap == 0.0: softcap = None qhead_per_kvhead = num_head // num_head_kv if pack_gqa is None: pack_gqa = qhead_per_kvhead > 1 out_torch_dtype = q.dtype device = q.device q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) requires_grad = q.requires_grad or k.requires_grad or v.requires_grad if out is None: out = torch.empty( *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device ) else: _validate_tensor(out, "out", (*q_batch_seqlen_shape, num_head, head_dim_v), out_torch_dtype, device) if lse is None: lse = ( torch.empty(lse_shape, dtype=torch.float32, device=device) if requires_grad or return_lse else None ) elif lse is not None: _validate_tensor(lse, "lse", lse_shape, torch.float32, device) dtype = torch2cute_dtype_map[q.dtype] use_block_sparsity = block_sparse_tensors is not None causal, local, window_size_left, window_size_right = _resolve_causal_local_window( causal, window_size_left, window_size_right, mask_mod ) # In fake mode (CPU-only compilation), use a fake stream placeholder. 
current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) # SM80/SM120: uses SM80 MMA, 128 threads (4 warps) if arch // 10 in [8, 12]: num_threads = 128 fwd_cfg = FwdConfig(128, 128, True, True) # default if tile_mn is None: if arch // 10 == 12: # SM120 tile sizes tuned for 99 KB SMEM capacity: # D<=64: 128x128 → 48 KB (good occupancy) # D>64: 128x64 → 64 KB (128x128 would use 96 KB, hurting occupancy) if head_dim <= 64: fwd_cfg = FwdConfig(128, 128, True, True) else: fwd_cfg = FwdConfig(128, 64, True, True) elif arch // 10 == 8: fwd_cfg = FwdConfig(128, 64, True, True) # SM80, should tune elif arch // 10 == 9: fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, use_block_sparsity) else: fwd_cfg = FwdConfig(tile_mn[0], tile_mn[1], fwd_cfg.mma_pv_is_rs, fwd_cfg.intra_wg_overlap) tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size if mma_pv_is_rs is None: mma_pv_is_rs = fwd_cfg.mma_pv_is_rs if intra_wg_overlap is None: intra_wg_overlap = fwd_cfg.intra_wg_overlap # TODO: fix GQA + SplitKV + non-varlen if pack_gqa and num_splits != 1 and cu_seqlens_q is None: pack_gqa = False if arch // 10 in [10, 11]: if pack_gqa and (128 % qhead_per_kvhead != 0): pack_gqa = False if max_seqlen_q is None: max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q if max_seqlen_k is None: max_seqlen_k = seqlen_k seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead if arch // 10 == 10: q_stage = 2 if seqlen_q_packgqa > tile_m else 1 else: q_stage = 1 m_block_size_effective = q_stage * tile_m seqlen_k_loaded = max_seqlen_k if not local else max(0, min(max_seqlen_k, (window_size_right or max_seqlen_k) + (window_size_left or max_seqlen_k) + 1 + tile_m)) num_m_blocks = (seqlen_q_packgqa + m_block_size_effective - 1) // m_block_size_effective total_mblocks = batch_size * num_head_kv * num_m_blocks num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n num_SMs = 132 if is_fake_mode() else torch.cuda.get_device_properties(device).multi_processor_count 
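Aside (illustrative, not part of the flash-attention source): the split count chosen next depends only on the integer block counts computed above. The sketch below copies `num_splits_heuristic` as written in this file. It wraps it in a hypothetical `estimate_num_splits` helper that mirrors the non-local, non-varlen block arithmetic. `q_stage` is passed in; in the code above it is 2 only on SM100 when the packed query length exceeds `tile_m`. The heuristic is only consulted when the caller passes `num_splits < 1`.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division, as in (x + tile - 1) // tile above."""
    return (a + b - 1) // b


def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits):
    # Same logic as num_splits_heuristic in interface.py.
    if num_n_blocks <= 4:
        return 1
    return min(num_SMs // total_mblocks, max_splits, num_n_blocks)


def estimate_num_splits(batch_size, num_head_kv, qhead_per_kvhead, seqlen_q, seqlen_k,
                        tile_m, tile_n, num_SMs, q_stage=1, max_splits=128):
    """Hypothetical helper mirroring the non-local block counting in _flash_attn_fwd."""
    seqlen_q_packgqa = seqlen_q * qhead_per_kvhead
    num_m_blocks = ceil_div(seqlen_q_packgqa, q_stage * tile_m)
    total_mblocks = batch_size * num_head_kv * num_m_blocks
    num_n_blocks = ceil_div(seqlen_k, tile_n)
    return num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits)
```

Consider single-token decoding with 8 KV heads, 4 query heads per KV head, `seqlen_k = 8192`, 128 x 128 tiles and 132 SMs. This gives `total_mblocks = 8` and `num_n_blocks = 64`, so the result is `min(132 // 8, 128, 64) = 16` splits. When `total_mblocks` exceeds `num_SMs`, the heuristic returns 0. The caller then computes `is_split_kv = num_splits > 1` as False, so this case behaves like a single split.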
if num_splits < 1: num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) # SplitKV uses float32 partial output, which doubles the O buffer size # in shared memory, causing OOM for diff-headdim (192, 128) if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: if num_n_blocks >= 64: tile_n = 64 num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) else: num_splits = 1 is_split_kv = num_splits > 1 if is_split_kv: out_partial = torch.empty(num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, dtype=torch.float32, device=device) lse_partial = torch.empty(num_splits, *lse_shape, dtype=torch.float32, device=device) use_2cta_instrs = ( arch // 10 in [10, 11] and not causal and not local and not is_split_kv and cu_seqlens_q is None and seqused_q is None and not use_block_sparsity and page_size in [None, 128] and int(math.ceil(head_dim / 16) * 16) == 128 and int(math.ceil(head_dim_v / 16) * 16) == 128 and seqlen_q_packgqa > 2 * tile_m ) # hash score and mask mods for compile cache score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) is_varlen = ( cu_seqlens_q is not None or cu_seqlens_k is not None or seqused_q is not None or seqused_k is not None ) if mask_mod is not None: if is_varlen: raise NotImplementedError( "mask_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR." ) if use_block_sparsity: if is_varlen: raise NotImplementedError( "Block sparsity is not yet supported for varlen sequences. This will be fixed in a future PR." 
) # NB: pack_gqa requires block sparse head dim == 1 (broadcasted) if pack_gqa and block_sparse_tensors.mask_block_cnt.shape[1] != 1: pack_gqa = False if is_split_kv: raise NotImplementedError( "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." ) # See get_broadcast_dims for why this is needed in compile key block_sparse_broadcast_pattern = None normalized_block_sparse_tensors = None q_subtile_factor = None if block_sparse_tensors is not None: if seqlen_q is None: raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") ( normalized_block_sparse_tensors, block_sparse_broadcast_pattern, q_subtile_factor, ) = normalize_block_sparse_config( block_sparse_tensors, batch_size=batch_size, num_head=num_head, seqlen_q=seqlen_q, seqlen_k=seqlen_k, block_size=(tile_m, tile_n), q_stage=q_stage, ) if aux_tensors is not None: aux_tensor_metadata = get_aux_tensor_metadata(aux_tensors) else: aux_tensor_metadata = None compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, score_mod_hash, mask_mod_hash, use_block_sparsity, block_sparse_broadcast_pattern, aux_tensor_metadata, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, page_table is not None, window_size_left is not None, window_size_right is not None, learnable_sink is not None, tile_m, tile_n, q_stage, num_threads, is_split_kv, pack_gqa, arch, page_size not in [None, tile_n], # paged KV non-TMA use_2cta_instrs, q_subtile_factor, mma_pv_is_rs, intra_wg_overlap, fa_logging.get_fa_log_level(), ) if compile_key not in _flash_attn_fwd.compile_cache: ( cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, learnable_sink_tensor, ) = [ to_cute_tensor(t, assumed_align=4, leading_dim=0) if t is not None else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] page_table_tensor = ( to_cute_tensor(page_table, assumed_align=4, leading_dim=1) 
if page_table is not None else None ) q_tensor, k_tensor, v_tensor, o_tensor = [ to_cute_tensor(t) for t in (q, k, v, out if not is_split_kv else out_partial) ] if is_split_kv: lse_tensor = to_cute_tensor(lse_partial, assumed_align=4) elif lse is not None: lse_tensor = to_cute_tensor(lse, assumed_align=4) else: lse_tensor = None sparse_tensors = None if normalized_block_sparse_tensors is not None: sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) cute_aux_tensors = None aux_tensor_metadata = None if aux_tensors is not None: cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] if arch // 10 == 8: assert page_table is None, "paged KV not supported on SM 8.0" assert not is_split_kv, "SplitKV not supported on SM 8.0" fa_fwd = FlashAttentionForwardSm80( dtype, head_dim, head_dim_v, qhead_per_kvhead, is_causal=causal, is_local=local, pack_gqa=pack_gqa, tile_m=tile_m, tile_n=tile_n, num_stages=1, num_threads=num_threads, Q_in_regs=False, score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, ) elif arch // 10 == 9: assert not is_split_kv, "SplitKV not supported on SM 9.0" fa_fwd = FlashAttentionForwardSm90( dtype, head_dim, head_dim_v, qhead_per_kvhead, is_causal=causal, is_local=local, pack_gqa=pack_gqa, tile_m=tile_m, tile_n=tile_n, # num_stages=1, num_stages=2, num_threads=num_threads, Q_in_regs=False, intra_wg_overlap=intra_wg_overlap, mma_pv_is_rs=mma_pv_is_rs, mask_mod=mask_mod, score_mod=score_mod, has_aux_tensors=aux_tensors is not None, q_subtile_factor=q_subtile_factor, paged_kv_non_tma=page_size not in [None, tile_n], ) elif arch // 10 in [10, 11]: fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, qhead_per_kvhead=qhead_per_kvhead, is_causal=causal, is_local=local, is_split_kv=is_split_kv, pack_gqa=pack_gqa, m_block_size=tile_m, n_block_size=tile_n, q_stage=q_stage, is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None and not is_split_kv, 
score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, paged_kv_non_tma=page_size not in [None, tile_n], is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity assert not use_block_sparsity, "Block sparsity not supported on SM 12.0" assert page_table is None, "Paged KV not supported on SM 12.0 in this PR" assert not is_split_kv, "SplitKV not supported on SM 12.0 in this PR" fa_fwd = FlashAttentionForwardSm120( dtype, head_dim, head_dim_v, qhead_per_kvhead, is_causal=causal, is_local=local, pack_gqa=pack_gqa, tile_m=tile_m, tile_n=tile_n, num_stages=1, num_threads=num_threads, Q_in_regs=False, score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, ) else: raise ValueError( f"Unsupported compute capability: {arch}. Supported: 8.x, 9.x, 10.x, 11.x, 12.x" ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, window_size_left, window_size_right, learnable_sink_tensor, sparse_tensors, cute_aux_tensors, current_stream, options="--enable-tvm-ffi", ) # In "fake mode", we will take torch fake tensors as input and the expected behaviors are: # - Use those fake metadata to populate compilation cache # - Return "fake" output tensors, which could be needed in follow-up fake operations # Thus, we skip the actual kernel invocation here. 
    if not is_fake_mode():
        _flash_attn_fwd.compile_cache[compile_key](
            q.detach(),
            k.detach(),
            v.detach(),
            out.detach() if not is_split_kv else out_partial,
            lse_partial if is_split_kv else lse,
            softmax_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            page_table,
            window_size_left,
            window_size_right,
            learnable_sink,
            normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
            aux_tensors,
        )
    if is_split_kv:
        _flash_attn_fwd_combine(
            out_partial,
            lse_partial.transpose(-1, -2),
            out,
            lse.transpose(-1, -2) if lse is not None else None,
            cu_seqlens_q,
            seqused_q,
        )
    return out, lse


_flash_attn_fwd.compile_cache = get_jit_cache("fwd")


def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k):
    sym = cute.sym_int
    # divisibility in elements: assumed_align_bytes = divisibility * dtype.width // 8
    # For 16-byte align: fp16/bf16 → divisibility=8, float32 → divisibility=4
    div = 128 // dtype.width  # 8 for fp16/bf16
    # Shared sym_ints for dimensions that must match across tensors
    b, seqlen_q, seqlen_k, h_q, d, d_v = sym(), sym(), sym(), sym(), sym(), sym()
    h_kv = h_q if not has_gqa else sym()
    seqlen_q_rounded, seqlen_k_rounded = sym(), sym()
    seqlen_q_d_rounded, seqlen_k_d_rounded, seqlen_k_dv_rounded = sym(), sym(), sym()
    total_q, total_k, total_q_rounded, total_k_rounded = sym(), sym(), sym(), sym()
    total_q_d_rounded, total_k_d_rounded, total_k_dv_rounded = sym(), sym(), sym()
    b_seqlenq = (b, seqlen_q) if not varlen_q else (total_q,)
    b_seqlenk = (b, seqlen_k) if not varlen_k else (total_k,)
    mQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div)
    mO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div)
    mdO = fake_tensor(dtype, (*b_seqlenq, h_q, d_v), divisibility=div)
    mK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div)
    mV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div)
    mdQ = fake_tensor(dtype, (*b_seqlenq, h_q, d), divisibility=div)
    mdK = fake_tensor(dtype, (*b_seqlenk, h_kv, d), divisibility=div)
    mdV = fake_tensor(dtype, (*b_seqlenk, h_kv, d_v), divisibility=div)
    if not varlen_q:
        mLSE = fake_tensor(Float32, (b, h_q, seqlen_q), divisibility=1)
        mLSElog2 = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4)
        mPdPsum = fake_tensor(Float32, (b, h_q, seqlen_q_rounded), divisibility=4)
        dQaccum = fake_tensor(Float32, (b, h_q, seqlen_q_d_rounded), divisibility=4)
    else:
        mLSE = fake_tensor(Float32, (h_q, total_q), divisibility=1)
        mLSElog2 = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4)
        mPdPsum = fake_tensor(Float32, (h_q, total_q_rounded), divisibility=4)
        dQaccum = fake_tensor(Float32, (h_q, total_q_d_rounded), divisibility=4)
    if not has_gqa:
        mdKaccum, mdVaccum = None, None
    else:
        if not varlen_k:
            mdKaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_rounded), divisibility=4)
            mdVaccum = fake_tensor(Float32, (b, h_kv, seqlen_k_dv_rounded), divisibility=4)
        else:
            mdKaccum = fake_tensor(Float32, (h_kv, total_k_rounded), divisibility=4)
            mdVaccum = fake_tensor(Float32, (h_kv, total_k_dv_rounded), divisibility=4)
    return mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, dQaccum, mdKaccum, mdVaccum


def _compile_bwd_preprocess(
    dtype,
    head_dim,
    head_dim_v,
    m_block_size,
    has_cuseqlens_q,
    has_seqused_q,
    has_dlse,
):
    """Compile bwd preprocess kernel using cute fake tensors (no real GPU tensors needed)."""
    mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors(
        dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False
    )
    batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int()
    batchp1 = cute.sym_int()
    mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None
    mSequsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None
    mdLSE = fake_tensor(Float32, mLSE.shape, divisibility=1) if has_dlse else None
    fa_bwd_pre = FlashAttentionBackwardPreprocess(dtype, head_dim, head_dim_v, m_block_size)
    return cute.compile(
        fa_bwd_pre,
        mO,
        mdO,
        mPdPsum,
        mLSE,
        mLSElog2,
        mdQaccum,
        mCuSeqlensQ,
        mSequsedQ,
        mdLSE,
        cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
        options="--enable-tvm-ffi",
    )


def _bwd_preprocess(
    out,
    dout,
    dpsum,
    lse,
    lse_log2,
    dq_accum,
    cu_seqlens_q,
    seqused_q,
    dlse,
    dtype,
    head_dim,
    head_dim_v,
    m_block_size,
):
    """Backward preprocess: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum."""
    is_varlen = cu_seqlens_q is not None
    compile_key = (
        dtype,
        head_dim,
        head_dim_v,
        m_block_size,
        is_varlen,
        seqused_q is not None,
        dlse is not None,
    )
    if compile_key not in _bwd_preprocess.compile_cache:
        _bwd_preprocess.compile_cache[compile_key] = _compile_bwd_preprocess(*compile_key)
    if not is_fake_mode():
        _bwd_preprocess.compile_cache[compile_key](
            out, dout, dpsum, lse, lse_log2, dq_accum, cu_seqlens_q, seqused_q, dlse
        )


_bwd_preprocess.compile_cache = get_jit_cache("bwd_pre")


def _compile_bwd_postprocess(
    dtype,
    hdim,
    block_size,
    num_threads,
    atom_layout,
    swap_ab,
    has_cuseqlens_q,
    has_seqused_q,
    use_2cta_instrs,
    cluster_size,
    arch,
):
    """Compile bwd postprocess kernel using cute fake tensors."""
    mQ, mK, mV, mO, mdO, mdQ, mdK, mdV, mLSE, mLSElog2, mPdPsum, mdQaccum, mdKaccum, mdVaccum = make_fake_bwd_tensors(
        dtype, has_gqa=True, varlen_q=has_cuseqlens_q, varlen_k=False
    )
    batch = mQ.shape[0] if not has_cuseqlens_q else cute.sym_int()
    batchp1 = cute.sym_int()
    mCuSeqlensQ = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cuseqlens_q else None
    mSeqUsedQ = fake_tensor(Int32, (batch,), divisibility=1) if has_seqused_q else None
    fa_bwd_post = FlashAttentionBackwardPostprocess(
        dtype,
        hdim,
        arch,
        block_size,
        num_threads,
        atom_layout,
        swap_ab,
        use_2cta_instrs=use_2cta_instrs,
        cluster_size=cluster_size,
    )
    return cute.compile(
        fa_bwd_post,
        mdQaccum,
        mdQ,
        Float32(0.0),
        mCuSeqlensQ,
        mSeqUsedQ,
        cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
        options="--enable-tvm-ffi",
    )


def _bwd_postprocess_convert(
    accum,
    output,
    scale,
    cu_seqlens,
    seqused,
    arch,
    dtype,
    hdim,
    block_size,
    num_threads,
    atom_layout,
    swap_ab,
    use_2cta_instrs=False,
    cluster_size=1,
):
    """Backward postprocess: convert float32 accumulator to bf16/fp16 output."""
    compile_key = (
        dtype,
        hdim,
        block_size,
        num_threads,
        atom_layout,
        swap_ab,
        cu_seqlens is not None,
        seqused is not None,
        use_2cta_instrs,
        cluster_size,
        arch,
    )
    if compile_key not in _bwd_postprocess_convert.compile_cache:
        _bwd_postprocess_convert.compile_cache[compile_key] = _compile_bwd_postprocess(*compile_key)
    if not is_fake_mode():
        _bwd_postprocess_convert.compile_cache[compile_key](
            accum,
            output,
            scale,
            cu_seqlens,
            seqused,
        )


_bwd_postprocess_convert.compile_cache = get_jit_cache("bwd_post")


def _flash_attn_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    dout: torch.Tensor,
    lse: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    softcap: float = 0.0,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    m_block_size: int = 64,
    n_block_size: int = 128,
    num_threads: int = 256,
    pack_gqa: bool = False,
    num_stages_Q: int = 2,
    num_stages_dO: int = 2,
    SdP_swapAB: bool = False,
    dKV_swapAB: bool = False,
    dQ_swapAB: bool = False,
    AtomLayoutMSdP: int = 2,
    AtomLayoutNdKV: int = 2,
    AtomLayoutMdQ: int = 2,
    V_in_regs: bool = False,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    deterministic: bool = False,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    score_mod: Optional[Callable] = None,
    score_mod_bwd: Optional[Callable] = None,
    mask_mod: Optional[Callable] = None,
    aux_tensors: Optional[list[torch.Tensor]] = None,
    block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None,
    dlse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    arch = _get_device_arch()
    assert arch // 10 in [9, 10, 11, 12], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x, 12.x"
    num_head, head_dim = q.shape[-2:]
    head_dim_v = v.shape[-1]
    causal, local, window_size_left, window_size_right = _resolve_causal_local_window(
        causal, window_size_left, window_size_right
    )
    if arch // 10 == 12:
        # SM120: uses SM80 MMA with 99 KB SMEM, 128 threads (4 warps).
        m_block_size = 64
        n_block_size = 64
        if head_dim <= 64:
            num_stages_Q = 2
            num_stages_dO = 2
        else:
            num_stages_Q = 1
            num_stages_dO = 1
        SdP_swapAB = False
        dKV_swapAB = False
        dQ_swapAB = False
        AtomLayoutMSdP = 4
        AtomLayoutNdKV = 4
        AtomLayoutMdQ = 4
        V_in_regs = False
        # Only SM90 uses dQ_single_wg, but it is part of the SM80/SM90/SM120 compile key,
        # so it must be bound here to avoid an UnboundLocalError.
        dQ_single_wg = False
        cluster_size = 1
        use_2cta_instrs = False
        num_threads = 128
        assert block_sparse_tensors is None, "Block sparsity backward not supported on SM 12.0"
        assert score_mod is None and score_mod_bwd is None, "score_mod backward not supported on SM 12.0"
        assert mask_mod is None, "mask_mod backward not supported on SM 12.0"
        assert not deterministic, "deterministic backward not supported on SM 12.0"
    elif arch // 10 == 9:
        cfg = _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local)
        m_block_size = cfg.m_block_size
        n_block_size = cfg.n_block_size
        num_stages_Q = cfg.num_stages_Q
        num_stages_dO = cfg.num_stages_dO
        num_stages_PdS = cfg.num_stages_PdS
        SdP_swapAB = cfg.SdP_swapAB
        dKV_swapAB = cfg.dKV_swapAB
        dQ_swapAB = cfg.dQ_swapAB
        AtomLayoutMSdP = cfg.AtomLayoutMSdP
        AtomLayoutNdKV = cfg.AtomLayoutNdKV
        AtomLayoutMdQ = cfg.AtomLayoutMdQ
        num_threads = (cfg.num_wg + 1) * 128
        dQ_single_wg = cfg.dQ_single_wg
        cluster_size = 1
        use_2cta_instrs = False
        is_varlen = (
            cu_seqlens_q is not None
            or cu_seqlens_k is not None
            or seqused_q is not None
            or seqused_k is not None
        )
    else:
        m_block_size = 128
        n_block_size = 128
        dQ_swapAB = False
        dKV_swapAB = False
        AtomLayoutMdQ = 1
        AtomLayoutNdKV = 1
        disable_2cta = (
            score_mod is not None
            or score_mod_bwd is not None
            or mask_mod is not None
        )
        cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1
        use_2cta_instrs = cluster_size == 2

    q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [
        maybe_contiguous(t)
        for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
    ]
    if cu_seqlens_q is None:
        batch_size, seqlen_q = q.shape[:2]
        total_q = batch_size * seqlen_q
    else:
        batch_size = cu_seqlens_q.shape[0] - 1
        total_q = q.shape[0]
        seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q
    if cu_seqlens_k is None:
        batch_size, seqlen_k = k.shape[:2]
        total_k = batch_size * seqlen_k
    else:
        batch_size = cu_seqlens_k.shape[0] - 1
        total_k = k.shape[0]
        seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k
    num_head_kv = k.shape[-2]
    use_block_sparsity = block_sparse_tensors is not None

    # SM90 block-sparse backward: tile_m=64 is the GCD between an m_block_size that fits,
    # the base block_m of 128 from forward, and block-sparse size for subtiling.
    if arch // 10 == 9 and use_block_sparsity:
        m_block_size = 64
        # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case)
        dQ_swapAB = False

    # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2
    subtile_factor = 2
    seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
    seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
    num_n_blocks = seqlen_k_rounded // n_block_size
    if cluster_size == 2 and num_n_blocks % cluster_size != 0:
        seqlen_k_rounded = seqlen_k_rounded + n_block_size

    if cu_seqlens_k is None:
        assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
        assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
    else:
        assert k.shape == (total_k, num_head_kv, head_dim)
        assert v.shape == (total_k, num_head_kv, head_dim_v)
        assert cu_seqlens_k.shape == (batch_size + 1,), (
            "cu_seqlens_k must have shape (batch_size + 1,)"
        )

    if cu_seqlens_q is not None:
        assert cu_seqlens_q.shape == (batch_size + 1,), (
            "cu_seqlens_q must have shape (batch_size + 1,)"
        )
        assert out.shape == (total_q, num_head, head_dim_v)
        assert dout.shape == (total_q, num_head, head_dim_v)
        assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)"
    else:
        assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v)
        assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v)
        assert lse.shape == (batch_size, num_head, seqlen_q), (
            "lse must have shape (batch_size, num_head, seqlen_q)"
        )

    assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16"
    assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, (
        "inputs must have the same dtype"
    )
    for t in [cu_seqlens_q, cu_seqlens_k]:
        if t is not None:
            assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32"
    assert lse.dtype == torch.float32, "lse must be float32"
    if dlse is not None:
        dlse = maybe_contiguous(dlse)
    if not is_fake_mode():
        assert all(
            t is None or t.is_cuda
            for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
        ), "inputs must be on CUDA device"
    assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
    alignment = 16 // q.element_size()
    if arch // 10 != 12:
        _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment)
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    qhead_per_kvhead = num_head // num_head_kv
    if pack_gqa is None:
        pack_gqa = qhead_per_kvhead > 1
    # pack_gqa is not yet supported in the backward pass
    pack_gqa = False
    if score_mod is not None:
        assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided"
        assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)"
        assert cu_seqlens_q is None and cu_seqlens_k is None, (
            "varlen + score_mod not supported in bwd yet"
        )

    device = q.device
    out_torch_dtype = q.dtype

    if dq is None:
        dq = torch.empty_like(q)
    else:
        _validate_tensor(dq, "dq", q.shape, out_torch_dtype, device)

    if dk is None:
        dk = torch.empty_like(k)
    else:
        _validate_tensor(dk, "dk", k.shape, out_torch_dtype, device)

    if dv is None:
        dv = torch.empty_like(v)
    else:
        _validate_tensor(dv, "dv", v.shape, out_torch_dtype, device)

    head_dim_rounded = (head_dim + 32 - 1) // 32 * 32

    if cu_seqlens_q is None:
        dq_accum = torch.empty(
            batch_size,
            num_head,
            seqlen_q_rounded * head_dim_rounded,
            dtype=torch.float32,
            device=device,
        )
        dpsum = torch.empty(
            batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device
        )
        lse_log2 = torch.empty(
            batch_size, num_head, seqlen_q_rounded, dtype=torch.float32, device=device
        )
    else:
        total_q_rounded_padded = (
            (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size
        )
        dq_accum = torch.empty(
            num_head, total_q_rounded_padded * head_dim_rounded, dtype=torch.float32, device=device
        )
        dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
        lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)

    # GQA (qhead_per_kvhead > 1) needs dK/dV accum+postprocess since multiple Q heads
    # accumulate into the same dK/dV. SM90 varlen_k with qhead_per_kvhead == 1 uses
    # ragged TMA tensors to store dK/dV directly, so it does not need accum+postprocess.
    dKV_postprocess = qhead_per_kvhead > 1
    if dKV_postprocess:
        head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
        if cu_seqlens_k is None:
            dk_accum = torch.zeros(
                batch_size,
                num_head_kv,
                seqlen_k_rounded * head_dim_rounded,
                dtype=torch.float32,
                device=device,
            )
            dv_accum = torch.zeros(
                batch_size,
                num_head_kv,
                seqlen_k_rounded * head_dim_v_rounded,
                dtype=torch.float32,
                device=device,
            )
        else:
            cluster_tile_n = cluster_size * n_block_size
            total_k_rounded_padded = (
                (total_k + cu_seqlens_k.shape[0] * cluster_tile_n - 1) // cluster_tile_n * cluster_tile_n
            )
            dk_accum = torch.zeros(
                num_head_kv,
                total_k_rounded_padded * head_dim_rounded,
                dtype=torch.float32,
                device=device,
            )
            dv_accum = torch.zeros(
                num_head_kv,
                total_k_rounded_padded * head_dim_v_rounded,
                dtype=torch.float32,
                device=device,
            )

    dtype = torch2cute_dtype_map[q.dtype]
    current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)

    if deterministic:
        dQ_semaphore = torch.zeros(
            batch_size, num_head, seqlen_q_rounded // m_block_size, cluster_size,
            dtype=torch.int32, device=device,
        )
    else:
        dQ_semaphore = None

    if deterministic and qhead_per_kvhead > 1:
        dK_semaphore = torch.zeros(
            batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2,
            dtype=torch.int32, device=device,
        )
        dV_semaphore = torch.zeros(
            batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2,
            dtype=torch.int32, device=device,
        )
    else:
        dK_semaphore = None
        dV_semaphore = None

    # Preprocess kernel: compute (o * dout).sum(dim=-1) - dLSE, lse * log2_e, and zero out dq_accum.
    _bwd_preprocess(
        out,
        dout,
        dpsum,
        lse,
        lse_log2,
        dq_accum,
        cu_seqlens_q,
        seqused_q,
        dlse,
        dtype,
        head_dim,
        head_dim_v,
        m_block_size,
    )

    # num_threads: SM90 derives it from BwdConfig.num_wg and SM120 sets it to 128 above;
    # SM100/SM110 always use 384, overriding the function-signature default (256).
    if arch // 10 not in [9, 12]:
        num_threads = 384

    # Backward kernel: compute dk, dv, dq_accum.
    score_mod_hash = utils.hash_callable(score_mod) if score_mod else False
    score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False
    mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False
    num_aux_tensors = len(aux_tensors) if aux_tensors else 0
    cute_aux_tensors = None
    if aux_tensors is not None:
        cute_aux_tensors = [
            to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors
        ]

    block_sparse_broadcast_pattern = None
    normalized_block_sparse_tensors = None
    if block_sparse_tensors is not None:
        (
            normalized_block_sparse_tensors,
            block_sparse_broadcast_pattern,
        ) = normalize_block_sparse_config_bwd(
            block_sparse_tensors,
            batch_size=batch_size,
            num_head=num_head,
            seqlen_q=seqlen_q,
            seqlen_k=seqlen_k,
            block_size=(m_block_size, n_block_size),
            subtile_factor=subtile_factor,
        )

    if arch // 10 in [8, 9, 12]:
        compile_key = (
            arch, dtype, head_dim, head_dim_v, qhead_per_kvhead, causal,
            window_size_left is not None, window_size_right is not None, softcap != 0.0,
            m_block_size, n_block_size, num_threads, pack_gqa, num_stages_Q, num_stages_dO,
            SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ,
            V_in_regs, dQ_single_wg, deterministic,
            cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None,
            score_mod_hash, score_mod_bwd_hash, mask_mod_hash, num_aux_tensors,
            use_block_sparsity, block_sparse_broadcast_pattern,
            get_broadcast_dims(q), get_broadcast_dims(k), get_broadcast_dims(v),
            get_broadcast_dims(dout),
        )
    else:
        compile_key = (
            arch, dtype, head_dim, head_dim_v, qhead_per_kvhead, causal,
            window_size_left is not None, window_size_right is not None, softcap != 0.0,
            m_block_size, n_block_size, num_threads, pack_gqa, cluster_size, use_2cta_instrs,
            deterministic, score_mod_hash, score_mod_bwd_hash, mask_mod_hash, num_aux_tensors,
            use_block_sparsity, block_sparse_broadcast_pattern,
            cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None,
            get_broadcast_dims(q), get_broadcast_dims(k), get_broadcast_dims(v),
            get_broadcast_dims(dout),
        )
    if compile_key not in _flash_attn_bwd.compile_cache:
        q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
            to_cute_tensor(t) for t in (q, k, v, dout, dq, dk, dv)
        ]
        dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
            to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
        ]
        if dKV_postprocess:
            dk_accum_tensor, dv_accum_tensor = [
                to_cute_tensor(t) for t in (dk_accum, dv_accum)
            ]
        cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [
            to_cute_tensor(t, assumed_align=4) if t is not None else None
            for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ]
        dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [
            utils.convert_from_dlpack_leading_static(
                t.detach(), leading_dim=3, alignment=4, stride_order=t.dim_order()
            )
            if t is not None
            else None
            for t in (dQ_semaphore, dK_semaphore, dV_semaphore)
        ]
        if arch // 10 in [8, 12]:
            flash_bwd_obj_cls = FlashAttentionBackwardSm120 if arch // 10 == 12 else FlashAttentionBackwardSm80
            fa_bwd_obj = flash_bwd_obj_cls(
                dtype,
                head_dim,
                head_dim_v,
                qhead_per_kvhead,
                m_block_size,
                n_block_size,
                num_stages_Q,
                num_stages_dO,
                num_threads,
                pack_gqa,
                causal,
                SdP_swapAB,
                dKV_swapAB,
                dQ_swapAB,
                AtomLayoutMSdP,
                AtomLayoutNdKV,
                AtomLayoutMdQ,
                V_in_regs=V_in_regs,
            )
        elif arch // 10 == 9:
            fa_bwd_obj = FlashAttentionBackwardSm90(
                dtype,
                head_dim,
                head_dim_v,
                qhead_per_kvhead,
                causal,
                is_local=local,
                deterministic=deterministic,
                tile_m=m_block_size,
                tile_n=n_block_size,
                Q_stage=num_stages_Q,
                dO_stage=num_stages_dO,
                PdS_stage=num_stages_PdS,
                SdP_swapAB=SdP_swapAB,
                dKV_swapAB=dKV_swapAB,
                dQ_swapAB=dQ_swapAB,
                AtomLayoutMSdP=AtomLayoutMSdP,
                AtomLayoutNdKV=AtomLayoutNdKV,
                AtomLayoutMdQ=AtomLayoutMdQ,
                num_threads=num_threads,
                V_in_regs=V_in_regs,
                score_mod=score_mod,
                score_mod_bwd=score_mod_bwd,
                mask_mod=mask_mod,
                has_aux_tensors=aux_tensors is not None,
                subtile_factor=subtile_factor,
                dQ_single_wg=dQ_single_wg,
            )
        else:
            fa_bwd_obj = FlashAttentionBackwardSm100(
                head_dim,
                head_dim_v,
                is_causal=causal,
                is_local=local,
                qhead_per_kvhead=qhead_per_kvhead,
                tile_m=m_block_size,
                tile_n=n_block_size,
                cluster_size=cluster_size,
                use_2cta_instrs=use_2cta_instrs,
                deterministic=deterministic,
                score_mod=score_mod,
                score_mod_bwd=score_mod_bwd,
                mask_mod=mask_mod,
                has_aux_tensors=aux_tensors is not None,
                subtile_factor=subtile_factor,
            )

        # Block sparse tensors for backward use Q-direction indexing (transposed from forward).
        sparse_tensors_compile = None
        if normalized_block_sparse_tensors is not None:
            sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)

        # TODO: check @can_implement
        _flash_attn_bwd.compile_cache[compile_key] = cute.compile(
            fa_bwd_obj,
            q_tensor,
            k_tensor,
            v_tensor,
            do_tensor,
            lse_log2_tensor,
            dpsum_tensor,
            dq_accum_tensor,
            dk_tensor if not dKV_postprocess else dk_accum_tensor,
            dv_tensor if not dKV_postprocess else dv_accum_tensor,
            softmax_scale,
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            None,  # softcap - not yet supported in backward
            window_size_left,
            window_size_right,
            dQ_semaphore_tensor,
            dK_semaphore_tensor,
            dV_semaphore_tensor,
            cute_aux_tensors,
            sparse_tensors_compile,
            current_stream,
            options="--enable-tvm-ffi",
        )
    if not is_fake_mode():
        _flash_attn_bwd.compile_cache[compile_key](
            q.detach(),
            k.detach(),
            v.detach(),
            dout,
            lse_log2,
            dpsum,
            dq_accum,
            dk if not dKV_postprocess else dk_accum,
            dv if not dKV_postprocess else dv_accum,
            softmax_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            None,  # softcap - not yet supported in backward
            window_size_left,
            window_size_right,
            dQ_semaphore,
            dK_semaphore,
            dV_semaphore,
            aux_tensors,
            normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None,
        )

    if arch // 10 == 9:
        # dQ postprocess: match main kernel's MMA WG count, unless dQ_single_wg
        num_threads_post_dQ = 128 if dQ_single_wg else cfg.num_wg * 128
        num_threads_post_dKV = cfg.num_wg * 128
    else:
        num_threads_post_dQ = 128
        num_threads_post_dKV = 128

    # Postprocess: convert dq_accum from float32 to dq in bf16/fp16
    _bwd_postprocess_convert(
        dq_accum, dq, softmax_scale, cu_seqlens_q, seqused_q,
        arch, dtype, head_dim, m_block_size, num_threads_post_dQ, AtomLayoutMdQ, dQ_swapAB,
        use_2cta_instrs=use_2cta_instrs,
        cluster_size=1,
    )

    if dKV_postprocess:
        # Postprocess: convert dk_accum from float32 to dk in bf16/fp16
        _bwd_postprocess_convert(
            dk_accum, dk, softmax_scale, cu_seqlens_k, seqused_k,
            arch, dtype, head_dim, n_block_size, num_threads_post_dKV, AtomLayoutNdKV, dKV_swapAB,
            cluster_size=cluster_size,
        )
        # Postprocess: convert dv_accum from float32 to dv in bf16/fp16
        _bwd_postprocess_convert(
            dv_accum, dv, 1.0, cu_seqlens_k, seqused_k,
            arch, dtype, head_dim_v, n_block_size, num_threads_post_dKV, AtomLayoutNdKV, dKV_swapAB,
            cluster_size=cluster_size,
        )

    return dq, dk, dv


_flash_attn_bwd.compile_cache = get_jit_cache("bwd")


class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
        learnable_sink: Optional[torch.Tensor] = None,
        softcap: float = 0.0,
        num_splits: int = 1,
        pack_gqa: Optional[bool] = None,
        deterministic: bool = False,
        mask_mod: Optional[Callable] = None,
        full_block_cnt: Optional[torch.Tensor] = None,
        full_block_idx: Optional[torch.Tensor] = None,
        mask_block_cnt: Optional[torch.Tensor] = None,
        mask_block_idx: Optional[torch.Tensor] = None,
        block_size: Optional[Tuple[int, int]] = None,
        return_lse: bool = False,
    ):
        # Only create block sparse tensors if at least one block sparse parameter is provided
        block_sparse_tensors = None
        if any(t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx]):
            block_sparse_tensors = BlockSparseTensorsTorch(
                full_block_cnt=full_block_cnt,
                full_block_idx=full_block_idx,
                mask_block_cnt=mask_block_cnt,
                mask_block_idx=mask_block_idx,
                block_size=block_size,
            )
        out, lse = _flash_attn_fwd(
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            learnable_sink=learnable_sink,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            mask_mod=mask_mod,
            block_sparse_tensors=block_sparse_tensors,
            return_lse=return_lse,
        )
        ctx.save_for_backward(q, k, v, out, lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.return_lse = return_lse
        ctx.set_materialize_grads(False)
        return out, lse

    @staticmethod
    def backward(ctx, dout, dlse):
        q, k, v, out, lse = ctx.saved_tensors
        if not ctx.return_lse:
            dlse = None
        if dout is None:
            dout = torch.zeros_like(out)
        dq, dk, dv = _flash_attn_bwd(
            q, k, v, out, dout, lse,
            ctx.softmax_scale,
            ctx.causal,
            ctx.softcap,
            window_size_left=ctx.window_size[0],
            window_size_right=ctx.window_size[1],
            deterministic=ctx.deterministic,
            dlse=dlse,
        )
        return dq, dk, dv, *((None,) * 20)  # Extra trailing Nones are fine


class FlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: Optional[torch.Tensor],
        cu_seqlens_k: Optional[torch.Tensor],
        seqused_q: Optional[torch.Tensor] = None,
        seqused_k: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_k: Optional[int] = None,
        page_table: Optional[torch.Tensor] = None,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[Optional[int], Optional[int]] = (None, None),
        learnable_sink: Optional[torch.Tensor] = None,
        softcap: float = 0.0,
        num_splits: int = 1,
        pack_gqa: Optional[bool] = None,
        deterministic: bool = False,
        score_mod: Optional[Callable] = None,
        aux_tensors: Optional[list] = None,
        return_lse: bool = False,
    ):
        out, lse = _flash_attn_fwd(
            q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            page_table=page_table,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            learnable_sink=learnable_sink,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            score_mod=score_mod,
            aux_tensors=aux_tensors,
            return_lse=return_lse,
        )
        ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.return_lse = return_lse
        ctx.set_materialize_grads(False)
        return out, lse

    @staticmethod
    def backward(ctx, dout, dlse):
        q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.softcap == 0.0
        if not ctx.return_lse:
            dlse = None
        if dout is None:
            dout = torch.zeros_like(out)
        dq, dk, dv = _flash_attn_bwd(
            q, k, v, out, dout, lse,
            ctx.softmax_scale,
            ctx.causal,
            ctx.softcap,
            window_size_left=ctx.window_size[0],
            window_size_right=ctx.window_size[1],
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_q=seqused_q,
            seqused_k=seqused_k,
            max_seqlen_q=ctx.max_seqlen_q,
            max_seqlen_k=ctx.max_seqlen_k,
            deterministic=ctx.deterministic,
            dlse=dlse,
        )
        return dq, dk, dv, *((None,) * 20)


def flash_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    mask_mod: Optional[Callable] = None,
    full_block_cnt: Optional[torch.Tensor] = None,
    full_block_idx: Optional[torch.Tensor] = None,
    mask_block_cnt: Optional[torch.Tensor] = None,
    mask_block_idx: Optional[torch.Tensor] = None,
    block_size: Optional[Tuple[int, int]] = None,
    return_lse: bool = False,
):
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        softmax_scale,
        causal,
        window_size,
        learnable_sink,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        mask_mod,
        full_block_cnt,
        full_block_idx,
        mask_block_cnt,
        mask_block_idx,
        block_size,
        return_lse,
    )


def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    score_mod: Optional[Callable] = None,
    aux_tensors: Optional[list] = None,
    return_lse: bool = False,
):
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        softmax_scale,
        causal,
        window_size,
        learnable_sink,
        softcap,
        num_splits,
        pack_gqa,
        deterministic,
        score_mod,
        aux_tensors,
        return_lse,
    )


def _compile_fwd_combine(
    dtype,
    dtype_partial,
    head_dim,
    tile_m,
    k_block_size,
    log_max_splits,
    has_cu_seqlens,
    has_seqused,
    has_lse,
    has_varlen_batch_idx,
):
    """Compile fwd combine kernel using cute fake tensors (no real GPU tensors needed)."""
    sym = cute.sym_int
    div = 128 // dtype_partial.width  # 16-byte alignment in elements
    fa_combine = FlashAttentionForwardCombine(
        dtype=dtype,
        dtype_partial=dtype_partial,
        head_dim=head_dim,
        tile_m=tile_m,
        k_block_size=k_block_size,
        log_max_splits=log_max_splits,
    )
    if not fa_combine.can_implement(
        dtype,
        dtype_partial,
        head_dim,
        tile_m,
        k_block_size,
        log_max_splits,
        num_threads=256,
    ):
        raise RuntimeError(
            "FlashAttention combine kernel cannot be implemented with given parameters"
        )
    if has_cu_seqlens:
        # Varlen: (num_splits, total_q, nheads, headdim)
        num_splits, total_q, nheads = sym(), sym(), sym()
        mO_partial = fake_tensor(dtype_partial, (num_splits, total_q, nheads, head_dim), divisibility=div)
        mLSE_partial = fake_tensor(Float32, (num_splits, total_q, nheads), divisibility=1, leading_dim=1)
        mO = fake_tensor(dtype, (total_q, nheads, head_dim), divisibility=div)
        mLSE = fake_tensor(Float32, (total_q, nheads), divisibility=1, leading_dim=0) if has_lse else None
    else:
        # Batched: (num_splits, batch, seqlen, nheads, headdim)
        num_splits, batch, seqlen, nheads = sym(), sym(), sym(), sym()
        mO_partial = fake_tensor(dtype_partial, (num_splits, batch, seqlen, nheads, head_dim), divisibility=div)
        mLSE_partial = fake_tensor(Float32, (num_splits, batch, seqlen, nheads), divisibility=1, leading_dim=2)
        mO = fake_tensor(dtype, (batch, seqlen, nheads, head_dim), divisibility=div)
        mLSE = fake_tensor(Float32, (batch, seqlen, nheads), divisibility=1, leading_dim=1) if has_lse else None
    batch = mO_partial.shape[1]
    batch_for_1d = batch if not has_cu_seqlens else sym()
    batchp1 = sym()
    mCuSeqlens = fake_tensor(Int32, (batchp1,), divisibility=1) if has_cu_seqlens else None
    mSeqused = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_seqused else None
    mNumSplitsDynamic = None  # Not parametrized in compile_key
    mVarlenBatchIdx = fake_tensor(Int32, (batch_for_1d,), divisibility=1) if has_varlen_batch_idx else None
    mSemaphore = None  # Not parametrized in compile_key
    return cute.compile(
        fa_combine,
        mO_partial,
        mLSE_partial,
        mO,
        mLSE,
        mCuSeqlens,
        mSeqused,
        mNumSplitsDynamic,
        mVarlenBatchIdx,
        mSemaphore,
        cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True),
        options="--enable-tvm-ffi",
    )


def _flash_attn_fwd_combine(
    out_partial: torch.Tensor,
    lse_partial: torch.Tensor,
    out: torch.Tensor,
    lse: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    seqused: Optional[torch.Tensor] = None,
    num_splits_dynamic_ptr: Optional[torch.Tensor] = None,
    varlen_batch_idx: Optional[torch.Tensor] = None,
    semaphore_to_reset: Optional[torch.Tensor] = None,
) -> None:
    """Forward combine kernel for split attention computation.

    Combines partial outputs and log-sum-exp values from multiple splits
    of attention computation into final outputs.

    Args:
        out_partial: Partial outputs tensor (num_splits, batch, seqlen, nheads, headdim),
            or (num_splits, total_q, nheads, headdim) if cu_seqlens is given
        lse_partial: Partial LSE tensor (num_splits, batch, seqlen, nheads),
            or (num_splits, total_q, nheads) if cu_seqlens is given
        out: Output tensor (batch, seqlen, nheads, headdim),
            or (total_q, nheads, headdim) if cu_seqlens is given
        lse: Output LSE tensor (batch, seqlen, nheads),
            or (total_q, nheads) if cu_seqlens is given
        cu_seqlens: Cumulative sequence lengths for variable length sequences
        seqused: Used sequence lengths for each batch
        num_splits_dynamic_ptr: Dynamic number of splits per batch
        varlen_batch_idx: Optional mapping from virtual batch index to real batch index
            (int32 tensor of shape (batch_size,))
        semaphore_to_reset: Semaphore for synchronization

    Returns:
        None
    """
    assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], (
        "out_partial must be fp16, bf16, or fp32"
    )
    if not is_fake_mode():
        assert out_partial.is_cuda and lse_partial.is_cuda, "tensors must be on CUDA device"
    # Determine if this is variable length based on dimensions
    is_varlen = out_partial.dim() == 4
    # Validate optional tensors
    for t, name in [
        (cu_seqlens, "cu_seqlens"),
        (seqused, "seqused"),
        (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"),
    ]:
        if t is not None:
            if not is_fake_mode():
                assert t.is_cuda, f"{name} must be on CUDA device"
            assert t.is_contiguous(), f"{name} must be contiguous"
    head_dim = out_partial.shape[-1]
    num_splits = out_partial.shape[0]
    assert num_splits <= 256, "num_splits must be at most 256"
    # If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
    # so that kBlockM is smaller and we have more parallelism.
k_block_size = 64 if head_dim <= 64 else 128 # We want kBlockM to be as small as possible to maximize parallelism. # E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). tile_m = 8 if k_block_size % 128 == 0 else (16 if k_block_size % 64 == 0 else 32) log_max_splits = max(math.ceil(math.log2(num_splits)), 4) if tile_m == 8: # If kBlockM == 8 then the minimum number of splits is 32. # TODO: we can deal w this by using 128 threads instead log_max_splits = max(log_max_splits, 5) # Create combine kernel configuration dtype = torch2cute_dtype_map[out.dtype] dtype_partial = torch2cute_dtype_map[out_partial.dtype] compile_key = ( dtype, dtype_partial, head_dim, tile_m, k_block_size, log_max_splits, cu_seqlens is not None, seqused is not None, lse is not None, varlen_batch_idx is not None, ) if compile_key not in _flash_attn_fwd_combine.compile_cache: _flash_attn_fwd_combine.compile_cache[compile_key] = _compile_fwd_combine( *compile_key ) if not is_fake_mode(): _flash_attn_fwd_combine.compile_cache[compile_key]( out_partial, lse_partial, out, lse, cu_seqlens, seqused, num_splits_dynamic_ptr, varlen_batch_idx, semaphore_to_reset, ) _flash_attn_fwd_combine.compile_cache = get_jit_cache("fwd_combine") def flash_attn_combine( out_partial: torch.Tensor, lse_partial: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = None, cu_seqlens: Optional[torch.Tensor] = None, seqused: Optional[torch.Tensor] = None, varlen_batch_idx: Optional[torch.Tensor] = None, return_lse: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Flash Attention combine function for split attention computation. Combines partial outputs and log-sum-exp values from multiple splits of attention computation into final outputs. This is the main user-facing interface for the combine kernel. 
Args: out_partial: Partial outputs tensor with shape: - (num_splits, batch_size, seqlen, num_heads, head_size) for regular batched input - (num_splits, total_q, num_heads, head_size) for variable length input lse_partial: Partial LSE tensor with shape: - (num_splits, batch_size, seqlen, num_heads) for regular batched input - (num_splits, total_q, num_heads) for variable length input out: Optional output tensor. If None, a new tensor is allocated. out_dtype: Optional output dtype. If None, defaults to out_partial.dtype. Ignored when out is provided. cu_seqlens: Cumulative sequence lengths for variable length sequences seqused: Used sequence lengths for each batch varlen_batch_idx: Optional mapping from virtual batch index to real batch index (int32 tensor of shape (batch_size,)). Used by persistent tile schedulers that reorder batch processing for load balancing. return_lse: Whether to return the combined LSE tensor. Default is True. Returns: Tuple of (out, lse) where: - out: Combined output tensor with shape (batch_size, seqlen, num_heads, head_size) or (total_q, num_heads, head_size) for varlen - lse: Combined log-sum-exp tensor with shape (batch_size, seqlen, num_heads) or (total_q, num_heads) for varlen. None if return_lse=False Note: This function expects the input tensors to be in the format produced by split attention computation, where the first dimension is num_splits. Any permutation from the user layout to the kernel layout is done inside the kernel.
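Example: this is a plain-Python sketch of the standard log-sum-exp merge. It is assumed to match what the kernel computes, not taken from the kernel source. For each query row and head, the combined lse is the logsumexp of lse_partial over the split dimension. Each split's output is then weighted by exp(lse_partial - lse). The helpers combine_splits_ref and attend_ref are hypothetical and work on a single row. Returning zeros for a row where every split is fully masked is this sketch's own convention.

```python
import math


def combine_splits_ref(out_partial, lse_partial):
    # out_partial: num_splits vectors of length head_dim (one query row, one head)
    # lse_partial: num_splits log-sum-exp values, one per split
    lse_max = max(lse_partial)
    if lse_max == -math.inf:
        # Every split was fully masked; this sketch returns zeros and -inf.
        return [0.0] * len(out_partial[0]), -math.inf
    # Numerically stable logsumexp over the split dimension.
    lse = lse_max + math.log(sum(math.exp(x - lse_max) for x in lse_partial))
    # Weight of each split = its share of the full softmax denominator.
    weights = [math.exp(x - lse) for x in lse_partial]
    head_dim = len(out_partial[0])
    out = [sum(w * o[d] for w, o in zip(weights, out_partial)) for d in range(head_dim)]
    return out, lse


def attend_ref(scores, values):
    # Exact softmax attention for one query row; returns (output, lse).
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    out = [sum(ei * v[d] for ei, v in zip(e, values)) / z for d in range(len(values[0]))]
    return out, m + math.log(z)


# Split 5 keys into two chunks, attend to each chunk separately, then merge.
scores = [0.3, -1.2, 2.0, 0.5, 1.1]
values = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0], [0.5, 0.5]]
full_out, full_lse = attend_ref(scores, values)
o1, l1 = attend_ref(scores[:2], values[:2])
o2, l2 = attend_ref(scores[2:], values[2:])
merged_out, merged_lse = combine_splits_ref([o1, o2], [l1, l2])
```

Merging the two per-chunk results reproduces attention over all five keys, up to floating-point rounding.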
""" # Input validation assert out_partial.dim() in [4, 5], "out_partial must have 4 or 5 dimensions" # Determine if this is variable length based on dimensions is_varlen = out_partial.dim() == 4 if is_varlen: # Variable length: (num_splits, total_q, num_heads, head_size) num_splits, total_q, num_heads, head_size = out_partial.shape batch_size = 1 # Treat as single batch for varlen seqlen = total_q else: # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape # Determine output dtype if out_dtype is None: out_dtype = out_partial.dtype # Create output if not provided device = out_partial.device if out is None: if is_varlen: out = torch.empty(total_q, num_heads, head_size, dtype=out_dtype, device=device) else: out = torch.empty( batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device ) # Create lse output only if requested if return_lse: if is_varlen: lse = torch.empty(num_heads, total_q, dtype=torch.float32, device=device) else: lse = torch.empty(batch_size, num_heads, seqlen, dtype=torch.float32, device=device) lse = lse.transpose(-1, -2) else: lse = None _flash_attn_fwd_combine( out_partial, lse_partial, out, lse, cu_seqlens, seqused, varlen_batch_idx=varlen_batch_idx, ) return out, lse ================================================ FILE: flash_attn/cute/mask.py ================================================ # Copyright (c) 2025, Tri Dao. from typing import Optional, Callable, TypeAlias from dataclasses import dataclass import cutlass import cutlass.cute as cute from cutlass import Float32, Int32, Uint32, const_expr from quack import layout_utils import flash_attn.cute.utils as utils from flash_attn.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] MASK_R2P_CHUNK_SIZE: int = 32 @cute.jit def r2p_bitmask_below(limit: Int32, s: int) -> Uint32: """32-bit R2P bitmask keeping positions < limit (exclusive upper bound). 
Positions 0..limit-1 in chunk `s` get bit=1 (keep), the rest bit=0 (mask). Uses inline PTX to avoid shift-by-type-width UB. """ m = max((s + 1) * MASK_R2P_CHUNK_SIZE - limit, 0) return utils.shr_u32(Uint32(0xFFFFFFFF), Uint32(m)) @cute.jit def r2p_bitmask_above(limit: Int32, s: int) -> Uint32: """32-bit R2P bitmask keeping positions >= limit (inclusive lower bound). Positions limit..31 in chunk `s` get bit=1 (keep), the rest bit=0 (mask). Uses inline PTX to avoid shift-by-type-width UB. """ n = max(limit - s * MASK_R2P_CHUNK_SIZE, 0) return utils.shl_u32(Uint32(0xFFFFFFFF), Uint32(n)) @cute.jit def mask_r2p_lambda( X: cute.Tensor, mask_gen_fn: cutlass.Constexpr[MaskGenFn], rank1: bool = False, ) -> None: """Apply R2P masking with a custom bitmask generator. mask_gen_fn(chunk_idx: constexpr int) -> Uint32: Returns a 32-bit bitmask for the chunk. Bit i set means column chunk_idx * chunk_size + i is KEPT; bit i clear means masked to -inf. """ ncol = const_expr(cute.size(X.shape[cute.rank(X) - 1]) if not rank1 else cute.size(X.shape)) # 32-column chunks. The mask_gen_fn returns a Uint32 bitmask (1=keep). CHUNK_SIZE = MASK_R2P_CHUNK_SIZE for s in cutlass.range_constexpr(cute.ceil_div(ncol, CHUNK_SIZE)): mask = mask_gen_fn(s) # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction for i in cutlass.range_constexpr(min(CHUNK_SIZE, ncol - s * CHUNK_SIZE)): in_bound = cutlass.Boolean(mask & (Uint32(1) << i)) c = s * CHUNK_SIZE + i if const_expr(rank1): X[c] = X[c] if in_bound else -Float32.inf else: for r in cutlass.range_constexpr(cute.size(X.shape[0])): X[r, c] = X[r, c] if in_bound else -Float32.inf @cute.jit def sm90_col_to_r2p_idx(col_limit: Int32) -> Int32: """Transform SM90 MMA column coordinate to R2P element index. SM90 MMA accumulator column indices are non-contiguous: 0, 1, 8, 9, 16, 17, ... Element indices are contiguous: 0, 1, 2, 3, 4, 5, ... This converts a column-space threshold to element-space for r2p_bitmask_below/above. 
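Example: this is a plain-Python reference sketch and is not used by the kernel. The *_ref helpers mirror r2p_bitmask_below, r2p_bitmask_above and sm90_col_to_r2p_idx on ordinary ints, so the masks can be checked against a brute-force definition. The explicit clamp in shr_u32_ref and shl_u32_ref models the PTX shift behavior the kernel relies on: shift amounts of 32 or more are clamped and give 0. Python shifts have no such undefined behavior. sm90_elem_to_col is a hypothetical inverse of the column mapping described above.

```python
FULL = 0xFFFFFFFF
CHUNK = 32  # same value as MASK_R2P_CHUNK_SIZE


def shr_u32_ref(x, n):
    # Logical right shift on 32 bits; shift amounts >= 32 give 0.
    return 0 if n >= 32 else (x >> n) & FULL


def shl_u32_ref(x, n):
    # Left shift truncated to 32 bits; shift amounts >= 32 give 0.
    return 0 if n >= 32 else (x << n) & FULL


def bitmask_below_ref(limit, s):
    # Bit i of chunk s is set iff element s * 32 + i < limit.
    return shr_u32_ref(FULL, max((s + 1) * CHUNK - limit, 0))


def bitmask_above_ref(limit, s):
    # Bit i of chunk s is set iff element s * 32 + i >= limit.
    return shl_u32_ref(FULL, max(limit - s * CHUNK, 0))


def sm90_col_to_r2p_idx_ref(col_limit):
    # Thread-local columns are 0, 1, 8, 9, 16, 17, ...; count those < col_limit (col_limit >= 0).
    return col_limit // 8 * 2 + min(col_limit % 8, 2)


def sm90_elem_to_col(e):
    # Element index -> accumulator column (pairs of columns spaced 8 apart).
    return e // 2 * 8 + e % 2


# Local-window style mask keeping elements 3..8 of chunk 0.
window_bits = bitmask_below_ref(9, 0) & bitmask_above_ref(3, 0)
```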
""" return col_limit // 8 * 2 + min(col_limit % 8, 2) @cute.jit def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32: """Convert a row coordinate to an R2P element index in the warp-group interleaved layout. In the SM100 backward pass, 2 warp groups share TMEM. The TMEM load atom distributes rows in an interleaved pattern: elements 0..num_rep-1 map to rows 0..num_rep-1 (warp group 0), elements num_rep..2*num_rep-1 map to rows num_rep*num_wg..num_rep*num_wg+num_rep-1 (warp group 1), and so on. Row-coordinate thresholds (causal limits, window bounds, uih_len) must be converted to element indices before use with r2p_bitmask_above/below. Rows not owned by this thread (in the gap between warp groups) are clamped to the boundary element index, which is safe because R2P thresholds are monotonic. Example with num_rep=16, num_wg=2: row 0 -> elem 0, row 15 -> elem 15, row 16 -> elem 16 (clamped), row 31 -> elem 16 (clamped), row 32 -> elem 16, row 33 -> elem 17, row 47 -> elem 31. """ return x // (num_rep * num_wg) * num_rep + min(x % (num_rep * num_wg), num_rep) @dataclass(frozen=True) class AttentionMask: tile_m: cutlass.Constexpr[int] tile_n: cutlass.Constexpr[int] seqlen_info: SeqlenInfoQK window_size_left: Optional[Int32] = None window_size_right: Optional[Int32] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 # only pass in if we're doing PackGQA swap_AB: cutlass.Constexpr[bool] = False @property def seqlen_q(self) -> Int32: return self.seqlen_info.seqlen_q @property def seqlen_k(self) -> Int32: return self.seqlen_info.seqlen_k @cute.jit def apply_mask( self, acc_S: cute.Tensor, batch_idx: cutlass.Int32, head_idx: cutlass.Int32, m_block: cutlass.Int32, n_block: cutlass.Int32, thr_mma: cute.TiledMma, mask_seqlen: cutlass.Constexpr[bool], mask_causal: cutlass.Constexpr[bool], mask_local: cutlass.Constexpr[bool] = False, mask_mod: cutlass.Constexpr[Optional[Callable]] = None, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), ) -> None: 
assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.swap_AB) acc_shape = (self.tile_m, self.tile_n) cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1]) tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cS), transpose=self.swap_AB) # We use t0ScS as these indices are known at compile time. We then must subtract the # column limit by the thread column offset. t0ScS_mn = layout_utils.reshape_acc_to_mn( thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB ) ROW = 0 if const_expr(not self.swap_AB) else 1 COL = 1 if const_expr(not self.swap_AB) else 0 thr_col_offset = tScS_mn[0][COL] # To handle edge cases of completely masked out rows where n_block_max = 0, # we treat negative n_blocks as 0th n_block # TODO: find more transparent solution if n_block < 0: n_block = 0 seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset if const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): r2p = const_expr(not self.swap_AB) if const_expr(not r2p): # traverse column index. 
for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): oob = t0ScS_mn[0, c][COL] >= seqlenk_col_limit for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): acc_S_mn[r, c] = -Float32.inf if oob else acc_S_mn[r, c] else: seqlenk_col_limit_r2p = sm90_col_to_r2p_idx(seqlenk_col_limit) mask_r2p_lambda(acc_S_mn, lambda s: r2p_bitmask_below(seqlenk_col_limit_r2p, s)) elif const_expr( not mask_causal and not mask_local and mask_mod is not None ): # FlexAttention mask mod nrow = const_expr(cute.size(tScS_mn.shape[0])) ncol = const_expr(cute.size(tScS_mn.shape[1])) has_fastdiv = const_expr( fastdiv_mods is not None and fastdiv_mods[0] is not None and fastdiv_mods[1] is not None ) wrap_aux_indices = const_expr( has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None) ) for r in cutlass.range_constexpr(nrow): # Respect swap_AB: ROW/COL determine which coordinate component corresponds to Q/KV. local_row = tScS_mn[r, 0][ROW] global_row_idx = local_row + m_block * self.tile_m row_for_mod = global_row_idx head_idx_for_mod = head_idx if const_expr(self.qhead_per_kvhead_packgqa != 1): head_offset = global_row_idx % self.qhead_per_kvhead_packgqa head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset row_for_mod = global_row_idx // self.qhead_per_kvhead_packgqa row_for_seqlen = row_for_mod if const_expr(wrap_aux_indices): _, row_for_mod = divmod(row_for_mod, fastdiv_mods[0]) for col in cutlass.range_constexpr(ncol): col_idx_local = t0ScS_mn[0, col][COL] # Convert to absolute column index global_col_idx = thr_col_offset + col_idx_local + n_block * self.tile_n col_for_mod = global_col_idx if const_expr(wrap_aux_indices): _, col_for_mod = divmod(global_col_idx, fastdiv_mods[1]) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32) q_idx_ssa = utils.scalar_to_ssa(row_for_mod, cutlass.Int32) kv_idx_ssa = utils.scalar_to_ssa(col_for_mod, cutlass.Int32) 
mask_value = mask_mod( batch_idx_ssa, head_idx_ssa, q_idx_ssa, kv_idx_ssa, self.seqlen_info, aux_tensors, ) cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) if const_expr(mask_seqlen): out_of_bounds = (row_for_seqlen >= self.seqlen_q) or ( global_col_idx >= self.seqlen_k ) if out_of_bounds: acc_S_mn[r, col] = -cutlass.Float32.inf else: acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf else: acc_S_mn[r, col] = acc_S_mn[r, col] if cond else -cutlass.Float32.inf else: # Causal or local if const_expr(not self.swap_AB): # If PackGQA, we split the work of compute divmod among threads in the same row threads_per_row = thr_mma.tv_layout_C.shape[0][0] mma_m_idx = None if const_expr(self.qhead_per_kvhead_packgqa != 1): assert not self.swap_AB, "swap_AB with PackGQA not supported yet" assert cute.arch.WARP_SIZE % threads_per_row == 0, ( "threads_per_row must divide WARP_SIZE" ) assert cute.size(acc_S_mn.shape[0]) <= threads_per_row tidx = thr_mma.thr_idx mma_m_idx = ( m_block * self.tile_m + tScS_mn[tidx % threads_per_row, 0][0] ) // self.qhead_per_kvhead_packgqa causal_row_offset = ( 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q - thr_col_offset ) if const_expr(mask_causal): r2p = const_expr(not self.swap_AB) # R2P trick, see apply_mask_sm100 for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. if const_expr(self.qhead_per_kvhead_packgqa == 1): row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m else: row_idx = utils.shuffle_sync( mma_m_idx, r % threads_per_row, width=threads_per_row ) col_limit_right = row_idx + causal_row_offset if const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) if const_expr(not r2p): # traverse column index. 
for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): acc_S_mn[r, c] = ( -Float32.inf if t0ScS_mn[0, c][1] >= col_limit_right else acc_S_mn[r, c] ) else: col_limit_r2p = sm90_col_to_r2p_idx(col_limit_right) mask_r2p_lambda( acc_S_mn[r, None], lambda s: r2p_bitmask_below(col_limit_r2p, s), rank1=True, ) else: # Local local_row_offset_right = ( causal_row_offset + self.window_size_right if const_expr(self.window_size_right is not None) else None ) local_row_offset_left = ( causal_row_offset - 1 - self.window_size_left if const_expr(self.window_size_left is not None) else None ) r2p_local = const_expr(not self.swap_AB) for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): if const_expr(self.qhead_per_kvhead_packgqa == 1): row_idx = tScS_mn[r, 0][0] + m_block * self.tile_m else: row_idx = utils.shuffle_sync( mma_m_idx, r % threads_per_row, width=threads_per_row ) if const_expr(self.window_size_right is not None): col_limit_right = row_idx + local_row_offset_right else: col_limit_right = self.tile_n if const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) col_limit_left = ( row_idx + local_row_offset_left if const_expr(self.window_size_left is not None) else 0 ) if const_expr(not r2p_local): # traverse column index. 
for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): col_idx = t0ScS_mn[0, c][1] if col_idx >= col_limit_right or col_idx < col_limit_left: acc_S_mn[r, c] = -Float32.inf else: col_limit_right_r2p = sm90_col_to_r2p_idx(col_limit_right) col_limit_left_r2p = sm90_col_to_r2p_idx(col_limit_left) def mask_gen_fn(s: int) -> Uint32: return r2p_bitmask_below( col_limit_right_r2p, s ) & r2p_bitmask_above(col_limit_left_r2p, s) mask_r2p_lambda(acc_S_mn[r, None], mask_gen_fn, rank1=True) else: # swap_AB assert self.qhead_per_kvhead_packgqa == 1 thr_row_offset = tScS_mn[0][ROW] causal_row_offset = ( seqlenk_col_limit - self.seqlen_q + m_block * self.tile_m + thr_row_offset ) if const_expr(mask_causal): for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): col0 = t0ScS_mn[0, c][COL] # If col0 is beyond the column limit, we want to mask out the entire # column, by setting row limit to be self.tile_m. row_limit_top = ( self.tile_m if col0 >= seqlenk_col_limit and mask_seqlen else col0 - causal_row_offset ) for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): acc_S_mn[r, c] = ( -Float32.inf if t0ScS_mn[r, 0][ROW] < row_limit_top else acc_S_mn[r, c] ) else: for c in cutlass.range(cute.size(tScS_mn.shape[1]), unroll_full=True): col0 = t0ScS_mn[0, c][COL] # If col0 is beyond the column limit, we want to mask out the entire # column, by setting row limit to be self.tile_m. 
row_limit_top = ( self.tile_m if col0 >= seqlenk_col_limit and mask_seqlen else ( col0 - causal_row_offset - self.window_size_right if const_expr(self.window_size_right is not None) else 0 ) ) row_limit_bot = ( col0 - causal_row_offset + self.window_size_left if const_expr(self.window_size_left is not None) else self.tile_m ) for r in cutlass.range(cute.size(tScS_mn.shape[0]), unroll_full=True): row_idx = t0ScS_mn[r, 0][ROW] acc_S_mn[r, c] = ( -Float32.inf if row_idx < row_limit_top or row_idx > row_limit_bot else acc_S_mn[r, c] ) @cute.jit def apply_mask_sm100( self, acc_S: cute.Tensor, m_block: Int32, n_block: Int32, thr_mma: cute.TiledMma, thr_tmem_load: cute.TiledCopy, mask_seqlen: cutlass.Constexpr[bool], mask_causal: cutlass.Constexpr[bool], mask_local: cutlass.Constexpr[bool] = False, mask_mod: cutlass.Constexpr[Optional[Callable]] = None, batch_idx: Int32 = None, head_idx: Int32 = None, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, check_q_boundary: bool = False, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_shape = (self.tile_m, self.tile_n) cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1]) tScS = thr_mma.partition_C(cS) tScS = tScS[(None, None), 0, 0] tScS_t2r = thr_tmem_load.partition_D(tScS) # To handle edge cases of completely masked out rows where n_block_max = 0, # we treat negative n_blocks as 0th n_block # TODO: find more transparent solution if n_block < 0: n_block = 0 seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n r2p = True if const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): if const_expr(not r2p): for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): # if tScS_t2r[i][1] >= seqlenk_col_limit: # acc_S[i] = -Float32.inf # For some reason the 2 lines above generate really bad SASS acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else 
acc_S[i] else: mask_r2p_lambda( acc_S, lambda s: r2p_bitmask_below(seqlenk_col_limit, s), rank1=True, ) elif const_expr(not mask_causal and not mask_local and mask_mod is not None): # Block sparse case w/ mask_mod has_fastdiv = const_expr( fastdiv_mods is not None and fastdiv_mods[0] is not None and fastdiv_mods[1] is not None ) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) ncol = const_expr(cute.size(tScS_t2r.shape)) for i in cutlass.range_constexpr(ncol): row_coord = tScS_t2r[i][0] if not self.swap_AB else tScS_t2r[i][1] col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] global_row = row_coord + m_block * self.tile_m global_col = col_coord + n_block * self.tile_n if const_expr(self.qhead_per_kvhead_packgqa != 1): assert head_divmod is not None mask_row, head_offset = divmod(global_row, head_divmod) head_idx_for_mod = head_idx * self.qhead_per_kvhead_packgqa + head_offset else: head_idx_for_mod = head_idx mask_row = global_row mask_row_for_mod = mask_row if const_expr(has_fastdiv and aux_tensors is not None): if check_q_boundary: _, mask_row_for_mod = divmod(mask_row, fastdiv_mods[0]) global_col_for_mod = global_col if const_expr(has_fastdiv and mask_seqlen and aux_tensors is not None): _, global_col_for_mod = divmod(global_col, fastdiv_mods[1]) head_idx_ssa = utils.scalar_to_ssa(head_idx_for_mod, cutlass.Int32) mask_row_ssa = utils.scalar_to_ssa(mask_row_for_mod, cutlass.Int32) kv_idx_ssa = utils.scalar_to_ssa(global_col_for_mod, cutlass.Int32) mask_value = mask_mod( batch_idx_ssa, head_idx_ssa, mask_row_ssa, kv_idx_ssa, self.seqlen_info, aux_tensors, ) cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) acc_S[i] = acc_S[i] if cond else -Float32.inf if const_expr(mask_seqlen): acc_S[i] = -Float32.inf if global_col >= self.seqlen_k else acc_S[i] if check_q_boundary: acc_S[i] = -Float32.inf if mask_row >= self.seqlen_q else acc_S[i] else: # Causal or local causal_row_offset = self.seqlen_k - n_block * self.tile_n - self.seqlen_q 
row_idx = tScS_t2r[0][0] + m_block * self.tile_m if const_expr(self.qhead_per_kvhead_packgqa != 1): row_idx = row_idx // self.qhead_per_kvhead_packgqa if const_expr(mask_causal): col_limit_right = row_idx + causal_row_offset + 1 if const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) # if cute.arch.thread_idx()[0] % 32 == 0: # cute.printf("tidx = %d, tidx tmem = %d, row_idx = %d, col_limit_right = %d, causal_row_offset = %d\n", cute.arch.thread_idx()[0], thr_tmem_load.thr_idx, row_idx, col_limit_right, causal_row_offset) ncol = const_expr(cute.size(tScS_t2r.shape)) if const_expr(not r2p): for i in cutlass.range(ncol, unroll_full=True): acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= col_limit_right else acc_S[i] else: mask_r2p_lambda( acc_S, lambda s: r2p_bitmask_below(col_limit_right, s), rank1=True, ) else: local_row_offset_right = ( causal_row_offset + 1 + self.window_size_right if const_expr(self.window_size_right is not None) else None ) local_row_offset_left = ( causal_row_offset - self.window_size_left if const_expr(self.window_size_left is not None) else None ) if const_expr(self.window_size_right is not None): col_limit_right = row_idx + local_row_offset_right else: col_limit_right = self.tile_n if const_expr(mask_seqlen): col_limit_right = cutlass.min(col_limit_right, seqlenk_col_limit) col_limit_left = ( row_idx + local_row_offset_left if const_expr(self.window_size_left is not None) else 0 ) if const_expr(not r2p): # if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left) for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): col_idx = tScS_t2r[i][1] acc_S[i] = ( -Float32.inf if col_idx >= col_limit_right or col_idx < col_limit_left else acc_S[i] ) else: # Dual-bound R2P masking for SM100. 
# Masks elements where: NOT (col_limit_left <= col < col_limit_right) def mask_gen_fn(s: int) -> Uint32: return r2p_bitmask_below(col_limit_right, s) & r2p_bitmask_above( col_limit_left, s ) mask_r2p_lambda(acc_S, mask_gen_fn, rank1=True) @cute.jit def apply_mask_sm100_transposed( self, acc_S: cute.Tensor, tScS_t2r: cute.Tensor, t0ScS_t2r: cute.Tensor, m_block: cutlass.Int32, n_block: cutlass.Int32, mask_seqlen: cutlass.Constexpr, mask_causal: cutlass.Constexpr, mask_local: cutlass.Constexpr, mask_mod: cutlass.Constexpr[Optional[Callable]] = None, batch_idx: Int32 = None, head_idx: Int32 = None, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), is_full_block: bool = False, check_m_boundary: bool = True, ) -> None: """ Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. Coordinate convention: - ROW corresponds to Q (m_block) - COL corresponds to KV (n_block) is_full_block: If True, skip mask_mod (all elements valid). Only apply seqlen masking. check_m_boundary: If False, skip seqlen_q boundary check (optimization for non-boundary m_blocks). When iterating m_blocks in forward order, only the last m_block may be partial. """ assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" ROW = 0 if const_expr(not self.swap_AB) else 1 COL = 1 if const_expr(not self.swap_AB) else 0 # assert t0ScS_t2r[0][COL] == 0, "col0 == 0" # tmp comment for 2-cta bwd thr_col_offset = tScS_t2r[0][COL] seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset if const_expr(not mask_causal and not mask_local and mask_mod is not None): # Block sparse case with mask_mod (backward) # # Coordinate convention: ROW → Q (m_block), COL → KV (n_block). # These already account for swap_AB. # # FULL blocks: mask_mod returns True for all elements, so skip it. # Still need seqlen bounds check (elements may be OOB on last m_block). # PARTIAL blocks: apply mask_mod element-wise, then seqlen bounds.
if is_full_block: if const_expr(mask_seqlen): if seqlenk_col_limit <= 0: # Entire tile is OOB for K for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): acc_S[i] = -cutlass.Float32.inf elif check_m_boundary: # Last m_block: check Q and K boundaries ncol = const_expr(cute.size(tScS_t2r.shape)) for i in cutlass.range_constexpr(ncol): row_coord = tScS_t2r[i][ROW] col_coord = tScS_t2r[i][COL] global_q = row_coord + m_block * self.tile_m global_kv = col_coord + n_block * self.tile_n q_out_of_bounds = global_q >= self.seqlen_q kv_out_of_bounds = global_kv >= self.seqlen_k out_of_bounds = q_out_of_bounds or kv_out_of_bounds acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i] else: # Partial block has_fastdiv = const_expr( fastdiv_mods is not None and fastdiv_mods[0] is not None and fastdiv_mods[1] is not None ) wrap_aux_indices = const_expr( has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None) ) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) ncol = const_expr(cute.size(tScS_t2r.shape)) for i in cutlass.range_constexpr(ncol): row_coord = tScS_t2r[i][ROW] col_coord = tScS_t2r[i][COL] global_q = row_coord + m_block * self.tile_m global_kv = col_coord + n_block * self.tile_n q_idx_for_mod = global_q kv_idx_for_mod = global_kv if const_expr(wrap_aux_indices): _, q_idx_for_mod = divmod(global_q, fastdiv_mods[0]) _, kv_idx_for_mod = divmod(global_kv, fastdiv_mods[1]) q_idx_ssa = utils.scalar_to_ssa(q_idx_for_mod, cutlass.Int32) kv_idx_ssa = utils.scalar_to_ssa(kv_idx_for_mod, cutlass.Int32) mask_value = mask_mod( batch_idx_ssa, head_idx_ssa, q_idx_ssa, kv_idx_ssa, self.seqlen_info, aux_tensors, ) cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) acc_S[i] = acc_S[i] if cond else -cutlass.Float32.inf if const_expr(mask_seqlen): # check_m_boundary=False skips q check for non-boundary m_blocks q_out_of_bounds = check_m_boundary and (global_q >= self.seqlen_q) 
kv_out_of_bounds = global_kv >= self.seqlen_k out_of_bounds = q_out_of_bounds or kv_out_of_bounds acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i] elif const_expr(not mask_causal and not mask_local): if const_expr(mask_seqlen): if seqlenk_col_limit <= 0: for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): acc_S[i] = -cutlass.Float32.inf else: # Causal or local thr_row_offset = tScS_t2r[0][ROW] seqlenq_row_limit = self.seqlen_q - m_block * self.tile_m - thr_row_offset causal_offset = seqlenq_row_limit - seqlenk_col_limit if const_expr(mask_causal): # tidx = cute.arch.thread_idx()[0] % 256 # if tidx < 32: # cute.printf("tidx = {}, {} {}, {} {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1]) row_limit_top = causal_offset if const_expr(mask_seqlen): # If col is beyond the column limit, we want to mask out the entire # column, by setting row limit to be self.tile_m. if seqlenk_col_limit <= 0: row_limit_top = self.tile_m r2p = True if const_expr(not r2p): for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): acc_S[i] = ( -cutlass.Float32.inf if t0ScS_t2r[i][ROW] < row_limit_top else acc_S[i] ) else: num_rep = cute.size(tScS_t2r, mode=[0]) # 16 or 32 num_wg = 2 row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg) mask_r2p_lambda( acc_S, lambda s: r2p_bitmask_above(row_limit, s), rank1=True, ) else: if const_expr(self.window_size_right is not None): row_limit_top = causal_offset - self.window_size_right else: row_limit_top = 0 if const_expr(self.window_size_left is not None): row_limit_bot = causal_offset + self.window_size_left if const_expr(mask_seqlen): if seqlenk_col_limit <= 0: row_limit_top = self.tile_m r2p = True if const_expr(not r2p): for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): row_idx = t0ScS_t2r[i][ROW] local_mask = row_idx < row_limit_top if const_expr(self.window_size_left is not None): local_mask |= row_idx > row_limit_bot acc_S[i] = -cutlass.Float32.inf if local_mask 
else acc_S[i] else: def mask_gen_fn(s: int) -> Uint32: num_rep = cute.size(tScS_t2r, mode=[0]) num_wg = 2 row_limit = row_to_r2p_idx(row_limit_top, num_rep, num_wg) mask = r2p_bitmask_above(row_limit, s) if const_expr(self.window_size_left is not None): row_limit_bottom = row_to_r2p_idx(row_limit_bot + 1, num_rep, num_wg) mask = mask & r2p_bitmask_below(row_limit_bottom, s) return mask mask_r2p_lambda( acc_S, mask_gen_fn, rank1=True, ) ================================================ FILE: flash_attn/cute/mma_sm100_desc.py ================================================ # Copyright (c) 2025, Tri Dao. # Ported Cutlass code from C++ to Python: # https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm100_desc.hpp # https://github.com/NVIDIA/cutlass/blob/main/include/cute/atom/mma_traits_sm100.hpp from enum import IntEnum import cutlass import cutlass.cute as cute # --------------------------------------------------------------------------- # Enumerations that match the HW encodings (values MUST stay identical) # --------------------------------------------------------------------------- class Major(IntEnum): # matrix “layout” in the ISA docs K = 0 MN = 1 class ScaleIn(IntEnum): # negate flags One = 0 Neg = 1 class Saturate(IntEnum): False_ = 0 True_ = 1 class CFormat(IntEnum): # 2-bit field (bits 4-5) F16 = 0 F32 = 1 S32 = 2 class F16F32Format(IntEnum): # 3-bit field (A/B element type) F16 = 0 BF16 = 1 TF32 = 2 class S8Format(IntEnum): UINT8 = 0 INT8 = 1 class MXF8F6F4Format(IntEnum): E4M3 = 0 E5M2 = 1 E2M3 = 3 E3M2 = 4 E2M1 = 5 class MaxShift(IntEnum): NoShift = 0 MaxShift8 = 1 MaxShift16 = 2 MaxShift32 = 3 # --------------------------------------------------------------------------- # CUTLASS-type → encoding helpers # --------------------------------------------------------------------------- def to_UMMA_format(cutlass_type) -> int: """ Map a CUTLASS scalar class to the 3-bit encoding for Matrix A/B. 
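Example: this is a hypothetical debugging sketch and is not part of this module. The A/B format codes returned here occupy the a_format field (bits 7-9) and the b_format field (bits 10-12) of the 32-bit instruction descriptor assembled by make_instr_desc. The illustrative helpers decode_instr_desc and encode_instr_desc_ints use the same bit positions that make_instr_desc packs. They take raw integer codes only to show the layout; make_instr_desc itself expects CUTLASS scalar classes. The worked value assumes:

- BF16 A and B (F16F32Format.BF16 = 1)
- an FP32 accumulator (CFormat.F32 = 1)
- M=128 and N=256
- K-major operands

```python
# (name, lowest bit, width) for each field, matching the shifts in make_instr_desc.
FIELDS = [
    ('sparse_id2', 0, 2),
    ('sparse_flag', 2, 1),
    ('saturate', 3, 1),
    ('c_format', 4, 2),
    ('a_format', 7, 3),
    ('b_format', 10, 3),
    ('a_negate', 13, 1),
    ('b_negate', 14, 1),
    ('a_major', 15, 1),
    ('b_major', 16, 1),
    ('n_dim', 17, 6),
    ('m_dim', 24, 5),
    ('max_shift', 30, 2),
]


def encode_instr_desc_ints(**codes):
    # Pack integer field codes; any field not given defaults to 0.
    desc = 0
    for name, lo, width in FIELDS:
        desc |= (codes.get(name, 0) & ((1 << width) - 1)) << lo
    return desc & 0xFFFF_FFFF


def decode_instr_desc(desc):
    # Unpack every field, then recover M and N from their scaled encodings.
    fields = {name: (desc >> lo) & ((1 << width) - 1) for name, lo, width in FIELDS}
    fields['M'] = fields['m_dim'] << 4
    fields['N'] = fields['n_dim'] << 3
    return fields


# BF16 x BF16 -> FP32, M=128, N=256, both operands K-major (major code 0).
desc = encode_instr_desc_ints(c_format=1, a_format=1, b_format=1, m_dim=128 >> 4, n_dim=256 >> 3)
info = decode_instr_desc(desc)
```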
""" if cutlass_type is cutlass.Int8: return S8Format.INT8 # Unsigned 8-bit (if available in your CUTLASS build) if cutlass_type is cutlass.Uint8: return S8Format.UINT8 # FP-16 / BF-16 if cutlass_type is cutlass.Float16: return F16F32Format.F16 if cutlass_type is cutlass.BFloat16: return F16F32Format.BF16 # TensorFloat-32 (8-bit exponent, 10-bit mantissa packed in 19 bits) if cutlass_type is cutlass.TFloat32: return F16F32Format.TF32 # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them if cutlass_type is cutlass.FloatE4M3FN: return MXF8F6F4Format.E4M3 if cutlass_type is cutlass.FloatE5M2: return MXF8F6F4Format.E5M2 raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") def to_C_format(cutlass_type) -> int: """ Map a CUTLASS scalar class to the 2-bit accumulator encoding. """ if cutlass_type is cutlass.Float16: return CFormat.F16 if cutlass_type is cutlass.Float32: return CFormat.F32 if cutlass_type is cutlass.Int32: return CFormat.S32 raise TypeError(f"Unsupported CUTLASS scalar type for accumulator: {cutlass_type!r}") # --------------------------------------------------------------------------- # The constructor – accepts only CUTLASS scalar classes # --------------------------------------------------------------------------- def make_instr_desc( a_type, # CUTLASS scalar class, e.g. cutlass.Int8 b_type, c_type, M: int, # 64, 128 or 256 N: int, # 8 … 256 (multiple of 8) a_major: Major, b_major: Major, a_neg: ScaleIn = ScaleIn.One, b_neg: ScaleIn = ScaleIn.One, c_sat: Saturate = Saturate.False_, is_sparse: bool = False, max_shift: MaxShift = MaxShift.NoShift, ) -> int: """ Build the 32-bit instruction descriptor for Blackwell MMA. All matrix/accumulator **types must be CUTLASS scalar classes** – passing integers is forbidden. 
""" # --- encode element formats ------------------------------------------------- a_fmt = int(to_UMMA_format(a_type)) b_fmt = int(to_UMMA_format(b_type)) c_fmt = int(to_C_format(c_type)) # --- range checks on M/N ----------------------------------------------------- if M not in (64, 128, 256): raise ValueError("M must be 64, 128 or 256") if N < 8 or N > 256 or (N & 7): raise ValueError("N must be a multiple of 8 in the range 8…256") m_dim = M >> 4 # 5-bit field n_dim = N >> 3 # 6-bit field # fmt: off # --- pack the bit-fields ----------------------------------------------------- desc = 0 desc |= (0 & 0x3) << 0 # sparse_id2 (always 0 here) desc |= (int(is_sparse) & 0x1) << 2 # sparse_flag desc |= (int(c_sat) & 0x1) << 3 # saturate desc |= (c_fmt & 0x3) << 4 # c_format desc |= (a_fmt & 0x7) << 7 # a_format desc |= (b_fmt & 0x7) << 10 # b_format desc |= (int(a_neg) & 0x1) << 13 # a_negate desc |= (int(b_neg) & 0x1) << 14 # b_negate desc |= (int(a_major) & 0x1) << 15 # a_major desc |= (int(b_major) & 0x1) << 16 # b_major desc |= (n_dim & 0x3F) << 17 # n_dim (6 bits) desc |= (m_dim & 0x1F) << 24 # m_dim (5 bits) desc |= (int(max_shift) & 0x3) << 30 # max_shift (2 bits) # fmt: on return desc & 0xFFFF_FFFF # ensure 32-bit result def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): return make_instr_desc( op.a_dtype, op.b_dtype, op.acc_dtype, op.shape_mnk[0], op.shape_mnk[1], Major.K if op.a_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, Major.K if op.b_major_mode == cute.nvgpu.tcgen05.mma.OperandMajorMode.K else Major.MN, ) class LayoutType(IntEnum): # occupies the top-3 bits [61:64) SWIZZLE_NONE = 0 # (a.k.a. 
“INTERLEAVE” in older docs) SWIZZLE_128B_BASE32B = 1 SWIZZLE_128B = 2 SWIZZLE_64B = 4 SWIZZLE_32B = 6 # values 3,5,7 are reserved / illegal for UMMA # --------------------------------------------------------------------------- # Helpers – figure out the SWIZZLE_* family from the tensor layout # --------------------------------------------------------------------------- def _layout_type(swizzle: cute.Swizzle) -> LayoutType: B, M, S = swizzle.num_bits, swizzle.num_base, swizzle.num_shift if M == 4: # Swizzle<*,4,3> if S != 3: raise ValueError("Unexpected swizzle shift – want S==3 for M==4") return { 0: LayoutType.SWIZZLE_NONE, 1: LayoutType.SWIZZLE_32B, 2: LayoutType.SWIZZLE_64B, 3: LayoutType.SWIZZLE_128B, }[B] # KeyError ⇒ invalid B→ raise if M == 5: # Swizzle<2,5,2> (the only legal triple for M==5) if (B, S) != (2, 2): raise ValueError("Only Swizzle<2,5,2> supported for 128B_BASE32B") return LayoutType.SWIZZLE_128B_BASE32B # Any other (M,B,S) triple is not a UMMA-legal shared-memory layout raise ValueError("Unsupported swizzle triple for UMMA smem descriptor") def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, major: Major) -> int: """ Convert a 2-D *shared-memory* Cute layout into the Blackwell 64-bit smem-descriptor, without the smem start address. layout must correspond to layout of an uint128 tensor. 
""" # ------------------------------------------------------------------ meta layout_type = _layout_type(swizzle) # resolve SWIZZLE_* family VERSION = 1 # bits 46–47 LBO_MODE = 0 # bit 52 BASE_OFFSET = 0 # bits 49–51 (CUTLASS always 0) # ---------------------------------------------------------- strides (units: uint128_t = 16 B) swizzle_atom_mn_size = { LayoutType.SWIZZLE_NONE: 1, LayoutType.SWIZZLE_32B: 2, LayoutType.SWIZZLE_64B: 4, LayoutType.SWIZZLE_128B: 8, LayoutType.SWIZZLE_128B_BASE32B: 8, }[layout_type] if major is Major.MN: swizzle_atom_k_size = 4 if layout_type is LayoutType.SWIZZLE_128B_BASE32B else 8 canonical_layout = cute.logical_divide(layout, (swizzle_atom_mn_size, swizzle_atom_k_size)) if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): raise ValueError("Not a canonical UMMA_MN Layout: Expected profile failure.") stride_00 = canonical_layout.stride[0][0] if layout_type is not LayoutType.SWIZZLE_NONE and stride_00 != 1: raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") stride_10 = canonical_layout.stride[1][0] if stride_10 != swizzle_atom_mn_size: raise ValueError("Not a canonical UMMA_MN Layout: Expected stride failure.") stride_01, stride_11 = canonical_layout.stride[0][1], canonical_layout.stride[1][1] if layout_type is LayoutType.SWIZZLE_NONE: stride_byte_offset, leading_byte_offset = stride_01, stride_11 else: stride_byte_offset, leading_byte_offset = stride_11, stride_01 else: if layout_type == LayoutType.SWIZZLE_128B_BASE32B: raise ValueError("SWIZZLE_128B_BASE32B is invalid for Major-K") if not cute.size(layout.shape[0]) % 8 == 0: raise ValueError("Not a canonical UMMA_K Layout: Expected MN-size multiple of 8.") canonical_layout = cute.logical_divide(layout, (8, 2)) if not cute.is_congruent(canonical_layout, ((1, 1), (1, 1))): raise ValueError("Not a canonical UMMA_K Layout: Expected profile failure.") stride_00 = canonical_layout.stride[0][0] if stride_00 != swizzle_atom_mn_size: raise ValueError("Not a 
canonical UMMA_K Layout: Expected stride failure.") stride_10 = canonical_layout.stride[1][0] if layout_type is not LayoutType.SWIZZLE_NONE and stride_10 != 1: raise ValueError("Not a canonical UMMA_K Layout: Expected stride failure.") stride_01 = canonical_layout.stride[0][1] stride_byte_offset, leading_byte_offset = stride_01, stride_10 # ------------------------------------------------------------------ pack desc = 0 # leading_byte_offset_ [16:30) desc |= (leading_byte_offset & 0x3FFF) << 16 # stride_byte_offset_ [32:46) desc |= (stride_byte_offset & 0x3FFF) << 32 # version_ [46:48) desc |= (VERSION & 0x3) << 46 # base_offset_ [49:52) desc |= (BASE_OFFSET & 0x7) << 49 # lbo_mode_ [52:53) desc |= (LBO_MODE & 0x1) << 52 # layout_type_ [61:64) desc |= (int(layout_type) & 0x7) << 61 return desc & 0xFFFF_FFFF_FFFF_FFFF # force 64-bit width def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: # 14 bits, remove 4 LSB (bits 0-13 in desc) return (start_addr.toint() & 0x3FFFF) >> 4 def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int: sA_swizzle = sA.iterator.type.swizzle_type return make_smem_desc_base( cute.recast_layout(128, sA.element_type.width, sA.layout[0]), sA_swizzle, major, ) ================================================ FILE: flash_attn/cute/named_barrier.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
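import enum


# Illustrative sketch (hypothetical enum, not used by any kernel): enum.auto() in an
# IntEnum numbers members 1, 2, 3, ..., so the barrier enums below can use member
# values directly as named-barrier IDs while leaving ID 0 free for the CTA-wide
# barrier used by sync_threads(), as their comments state.
class _ExampleBarrier(enum.IntEnum):
    Epilogue = enum.auto()  # value 1
    TmemPtr = enum.auto()  # value 2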
import enum class NamedBarrierFwd(enum.IntEnum): Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() WarpSchedulerWG1 = enum.auto() WarpSchedulerWG2 = enum.auto() WarpSchedulerWG3 = enum.auto() PFull = enum.auto() PEmpty = enum.auto() class NamedBarrierFwdSm100(enum.IntEnum): Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() TmemPtr = enum.auto() SoftmaxStatsW0 = enum.auto() SoftmaxStatsW1 = enum.auto() SoftmaxStatsW2 = enum.auto() SoftmaxStatsW3 = enum.auto() SoftmaxStatsW4 = enum.auto() SoftmaxStatsW5 = enum.auto() SoftmaxStatsW6 = enum.auto() SoftmaxStatsW7 = enum.auto() class NamedBarrierBwd(enum.IntEnum): Epilogue = enum.auto() WarpSchedulerWG1 = enum.auto() WarpSchedulerWG2 = enum.auto() WarpSchedulerWG3 = enum.auto() PdS = enum.auto() dQFullWG0 = enum.auto() dQFullWG1 = enum.auto() dQFullWG2 = enum.auto() dQEmptyWG0 = enum.auto() dQEmptyWG1 = enum.auto() dQEmptyWG2 = enum.auto() class NamedBarrierBwdSm100(enum.IntEnum): EpilogueWG1 = enum.auto() EpilogueWG2 = enum.auto() Compute = enum.auto() dQaccReduce = enum.auto() TmemPtr = enum.auto() ================================================ FILE: flash_attn/cute/pack_gqa.py ================================================ # Copyright (c) 2025, Tri Dao. from typing import Union, Tuple import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync from quack import layout_utils import flash_attn.cute.utils as utils def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): """Reshape a tensor to fold qhead_per_kvhead into the seqlen dimension (mode 0). The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept as-is (e.g. batch). For Q/O tensors (head_idx=2): (seqlen_q, headdim, nheads, batch, ...) -> ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) For LSE tensors (head_idx=1): (seqlen_q, nheads, batch, ...) 
-> ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) """ head_stride = T.stride[head_idx] shape_packed = ( (qhead_per_kvhead, T.shape[0]), *[T.shape[i] for i in range(1, head_idx)], nheads_kv, *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], ) stride_packed = ( (head_stride, T.stride[0]), *[T.stride[i] for i in range(1, head_idx)], head_stride * qhead_per_kvhead, *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], ) return cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed)) def make_packgqa_tiled_tma_atom( op: cute.atom.CopyOp, gmem_tensor: cute.Tensor, smem_layout: Union[cute.Layout, cute.ComposedLayout], cta_tiler: Tuple[int, int], qhead_per_kvhead: int, head_idx: int, ): # This packing and unpacking of the layout is so that we keep the same TMA dimension as usual. # e.g. for (seqlen, d, nheads, b) layout, we still have 4D TMA after packing to # ((nheads, seqlen), d, b). # If we instead pack directly to ((qhead_per_kvhead, seqlen), d, nheads_kv, b) we'd have 5D TMA. 
# Pack the nheads and seqlen dims into a single mode: (seqlen, d, nheads, b) -> ((nheads, seqlen), d, b) gmem_tensor = layout_utils.select( gmem_tensor, [head_idx, *range(head_idx), *range(head_idx + 1, cute.rank(gmem_tensor))] ) gmem_tensor = cute.group_modes(gmem_tensor, 0, 2) assert cta_tiler[0] % qhead_per_kvhead == 0, ( "CTA tile size in the seqlen dimension must be divisible by qhead_per_kvhead" ) tma_atom, tma_tensor = cpasync.make_tiled_tma_atom( op, gmem_tensor, smem_layout, ((qhead_per_kvhead, cta_tiler[0] // qhead_per_kvhead), cta_tiler[1]), # No mcast ) # Unpack from ((nheads, seqlen), d, b) -> ((qhead_per_kvhead, seqlen), d, nheads_kv, b) T = tma_tensor shape_packed = ( (qhead_per_kvhead, T.shape[0][1]), *[T.shape[i] for i in range(1, head_idx)], T.shape[0][0] // qhead_per_kvhead, *[T.shape[i] for i in range(head_idx, len(T.shape))], ) stride_packed = ( *[T.stride[i] for i in range(head_idx)], T.stride[0][0] * qhead_per_kvhead, *[T.stride[i] for i in range(head_idx, len(T.shape))], ) tma_tensor = cute.make_tensor(T.iterator, cute.make_layout(shape_packed, stride=stride_packed)) return tma_atom, tma_tensor def unpack_gqa_layout(T, qhead_per_kvhead, head_idx): """Reverse of pack_gqa_layout: unfold qhead_per_kvhead from the seqlen dimension (mode 0). The head dimension is at mode ``head_idx``. Modes before it (1..head_idx-1) are kept as-is (e.g. headdim for Q/O tensors), and modes after it are kept as-is (e.g. batch). For Q/O tensors (head_idx=2): ((qhead_per_kvhead, seqlen_q), headdim, nheads_kv, batch, ...) -> (seqlen_q, headdim, nheads, batch, ...) For LSE tensors (head_idx=1): ((qhead_per_kvhead, seqlen_q), nheads_kv, batch, ...) -> (seqlen_q, nheads, batch, ...)
""" seqlen_stride = T.stride[0][1] head_stride = T.stride[0][0] shape_unpacked = ( T.shape[0][1], *[T.shape[i] for i in range(1, head_idx)], T.shape[head_idx] * qhead_per_kvhead, *[T.shape[i] for i in range(head_idx + 1, len(T.shape))], ) stride_unpacked = ( seqlen_stride, *[T.stride[i] for i in range(1, head_idx)], head_stride, *[T.stride[i] for i in range(head_idx + 1, len(T.shape))], ) return cute.make_tensor(T.iterator, cute.make_layout(shape_unpacked, stride=stride_unpacked)) class PackGQA: def __init__( self, m_block_size: cutlass.Constexpr[int], head_dim_padded: cutlass.Constexpr[int], check_hdim_oob: cutlass.Constexpr[bool], qhead_per_kvhead: cutlass.Constexpr[int], ): self.m_block_size = m_block_size self.head_dim_padded = head_dim_padded self.check_hdim_oob = check_hdim_oob self.qhead_per_kvhead = qhead_per_kvhead @cute.jit def compute_ptr( self, tensor: cute.Tensor, cRows: cute.Tensor, tidx: cutlass.Int32, block: cutlass.Int32, threads_per_row: cutlass.Constexpr[int], num_threads: cutlass.Constexpr[int], ): num_ptr_per_thread = cute.ceil_div(cute.size(cRows), threads_per_row) tPrPtr = cute.make_fragment(num_ptr_per_thread, cutlass.Int64) for i in cutlass.range_constexpr(num_ptr_per_thread): row = i * num_threads + cRows[tidx % threads_per_row][0] idx = block * self.m_block_size + row m_idx = idx // self.qhead_per_kvhead h_idx = idx - m_idx * self.qhead_per_kvhead tPrPtr[i] = utils.elem_pointer(tensor, ((h_idx, m_idx),)).toint() return tPrPtr @cute.jit def load_Q( self, mQ: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) sQ: cute.Tensor, # (m_block_size, head_dim_padded) gmem_tiled_copy: cute.TiledCopy, tidx: cutlass.Int32, block: cutlass.Int32, seqlen: cutlass.Int32, ): gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) cQ = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) tQsQ = gmem_thr_copy.partition_D(sQ) tQcQ = gmem_thr_copy.partition_S(cQ) t0QcQ = gmem_thr_copy.get_slice(0).partition_S(cQ) tQpQ = utils.predicate_k(tQcQ,
limit=mQ.shape[1]) tQcQ_row = tQcQ[0, None, 0] threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" num_threads = gmem_tiled_copy.size tPrQPtr = self.compute_ptr(mQ[None, 0], tQcQ_row, tidx, block, threads_per_row, num_threads) for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): q_ptr_i64 = utils.shuffle_sync( tPrQPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row ) q_gmem_ptr = cute.make_ptr( mQ.element_type, q_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 ) if ( t0QcQ[0, m, 0][0] < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tQcQ_row[0][0] ): mQ_cur = cute.make_tensor(q_gmem_ptr, (self.head_dim_padded,)) elems_per_load = cute.size(tQsQ.shape[0][0]) mQ_cur_copy = cute.tiled_divide(mQ_cur, (elems_per_load,)) for k in cutlass.range_constexpr(cute.size(tQsQ.shape[2])): ki = tQcQ[0, 0, k][1] // elems_per_load cute.copy( gmem_thr_copy, mQ_cur_copy[None, ki], tQsQ[None, m, k], pred=tQpQ[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, ) # We don't need to clear the sQ smem tiles since we'll only write out the valid outputs @cute.jit def store_LSE( self, mLSE: cute.Tensor, # (qhead_per_kvhead, seqlen_q) tLSErLSE: cute.Tensor, # (m_block_size, head_dim_padded) tiled_mma: cute.TiledMma, tidx: cutlass.Int32, block: cutlass.Int32, seqlen: cutlass.Int32, ): thr_mma = tiled_mma.get_slice(tidx) caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) taccOcO = thr_mma.partition_C(caccO) taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0] assert cute.size(tLSErLSE) == cute.size(taccOcO_row) threads_per_row = tiled_mma.tv_layout_C.shape[0][0] assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" assert cute.size(tLSErLSE) <= threads_per_row num_threads = tiled_mma.size tPrLSEPtr = self.compute_ptr(mLSE, taccOcO_row, tidx, block, threads_per_row, 
num_threads) for m in cutlass.range_constexpr(cute.size(tLSErLSE)): lse_ptr_i64 = utils.shuffle_sync( tPrLSEPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row, ) lse_gmem_ptr = cute.make_ptr( mLSE.element_type, lse_ptr_i64, cute.AddressSpace.gmem, assumed_align=4 ) row = block * self.m_block_size + taccOcO_row[m][0] # Only the thread corresponding to column 0 writes out the lse to gmem if taccOcO[0][1] == 0 and row < seqlen * self.qhead_per_kvhead: mLSE_copy = cute.make_tensor(lse_gmem_ptr, (1,)) mLSE_copy[0] = tLSErLSE[m] @cute.jit def store_O( self, mO: cute.Tensor, # ((qhead_per_kvhead, seqlen_q), headdim) tOrO: cute.Tensor, # (m_block_size, head_dim_padded) split across threads according to gmem_tiled_copy gmem_tiled_copy: cute.TiledCopy, tidx: cutlass.Int32, block: cutlass.Int32, seqlen: cutlass.Int32, ): gmem_thr_copy = gmem_tiled_copy.get_slice(tidx) cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) tOcO = gmem_thr_copy.partition_S(cO) t0OcO = gmem_thr_copy.get_slice(0).partition_S(cO) tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) tOcO_row = tOcO[0, None, 0] threads_per_row = gmem_tiled_copy.layout_tv_tiled.shape[0][0] assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE" num_threads = gmem_tiled_copy.size tPrOPtr = self.compute_ptr(mO[None, 0], tOcO_row, tidx, block, threads_per_row, num_threads) for m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): o_ptr_i64 = utils.shuffle_sync( tPrOPtr[m // threads_per_row], m % threads_per_row, width=threads_per_row ) o_gmem_ptr = cute.make_ptr( mO.element_type, o_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 ) if ( t0OcO[0, m, 0][0] < seqlen * self.qhead_per_kvhead - block * self.m_block_size - tOcO_row[0][0] ): mO_cur = cute.make_tensor(o_gmem_ptr, (self.head_dim_padded,)) elems_per_load = cute.size(tOrO.shape[0][0]) mO_cur_copy = cute.tiled_divide(mO_cur, (elems_per_load,)) for k in 
cutlass.range_constexpr(cute.size(tOrO.shape[2])): ki = tOcO[0, 0, k][1] // elems_per_load cute.copy( gmem_thr_copy, tOrO[None, m, k], mO_cur_copy[None, ki], pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, ) ================================================ FILE: flash_attn/cute/paged_kv.py ================================================ from typing import Type from dataclasses import dataclass import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync from cutlass import Int32, const_expr from flash_attn.cute import utils from quack.cute_dsl_utils import ParamsBase from cutlass.cute import FastDivmodDivisor import math @dataclass class PagedKVManager(ParamsBase): mPageTable: cute.Tensor mK_paged: cute.Tensor mV_paged: cute.Tensor thread_idx: Int32 page_size_divmod: FastDivmodDivisor seqlen_k: Int32 leftpad_k: Int32 n_block_size: Int32 num_threads: cutlass.Constexpr[Int32] head_dim_padded: cutlass.Constexpr[Int32] head_dim_v_padded: cutlass.Constexpr[Int32] arch: cutlass.Constexpr[Int32] v_gmem_transposed: cutlass.Constexpr[bool] gmem_threads_per_row: cutlass.Constexpr[Int32] page_entry_per_thread: Int32 async_copy_elems: Int32 gmem_tiled_copy_KV: cute.TiledCopy gmem_thr_copy_KV: cute.TiledCopy tPrPage: cute.Tensor tPrPageOffset: cute.Tensor tKpK: cute.Tensor tVpV: cute.Tensor @staticmethod def create( mPageTable: cute.Tensor, mK_paged: cute.Tensor, mV_paged: cute.Tensor, page_size_divmod: FastDivmodDivisor, bidb: Int32, bidh: Int32, thread_idx: Int32, seqlen_k: Int32, leftpad_k: Int32, n_block_size: cutlass.Constexpr[Int32], head_dim_padded: cutlass.Constexpr[Int32], head_dim_v_padded: cutlass.Constexpr[Int32], num_threads: cutlass.Constexpr[Int32], dtype: Type[cutlass.Numeric], arch: cutlass.Constexpr[int] = 100, ): # SM100 transposes V in gmem to (dv, page_size, num_pages); # SM90 keeps V as (page_size, dv, num_pages), same layout as K. 
v_gmem_transposed = arch != 90 universal_copy_bits = 128 async_copy_elems = universal_copy_bits // dtype.width dtype_bytes = dtype.width // 8 gmem_k_block_size = math.gcd( head_dim_padded, head_dim_v_padded, 128 // dtype_bytes, ) assert gmem_k_block_size % async_copy_elems == 0 gmem_threads_per_row = gmem_k_block_size // async_copy_elems assert cute.arch.WARP_SIZE % gmem_threads_per_row == 0 atom_async_copy = cute.make_copy_atom( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), dtype, num_bits_per_copy=universal_copy_bits, ) thr_layout = cute.make_ordered_layout( (num_threads // gmem_threads_per_row, gmem_threads_per_row), order=(1, 0), ) val_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx) page_entry_per_thread = n_block_size // num_threads tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32) tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32) mPageTable = mPageTable[bidb, None] mK_paged = mK_paged[None, None, bidh, None] mV_paged = mV_paged[None, None, bidh, None] cK = cute.make_identity_tensor((n_block_size, head_dim_padded)) tKcK = gmem_thr_copy_KV.partition_S(cK) tKpK = utils.predicate_k(tKcK, limit=mK_paged.shape[1]) if const_expr(head_dim_padded == head_dim_v_padded): tVpV = tKpK else: cV = cute.make_identity_tensor((n_block_size, head_dim_v_padded)) tVcV = gmem_thr_copy_KV.partition_S(cV) # When V is transposed in gmem, dv is shape[0]; otherwise dv is shape[1] (same as K) tVpV = utils.predicate_k(tVcV, limit=mV_paged.shape[0 if v_gmem_transposed else 1]) return PagedKVManager( mPageTable, mK_paged, mV_paged, thread_idx, page_size_divmod, seqlen_k, leftpad_k, n_block_size, num_threads, head_dim_padded, head_dim_v_padded, arch, v_gmem_transposed, gmem_threads_per_row, page_entry_per_thread, async_copy_elems, gmem_tiled_copy_KV, gmem_thr_copy_KV, tPrPage, tPrPageOffset, tKpK, 
tVpV, ) @cute.jit def load_page_table(self, n_block: Int32): for i in cutlass.range(self.page_entry_per_thread, unroll=1): row = ( i * self.num_threads + (self.thread_idx % self.gmem_threads_per_row) * (self.num_threads // self.gmem_threads_per_row) + (self.thread_idx // self.gmem_threads_per_row) ) row_idx = n_block * self.n_block_size + row page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod) is_valid = ( (i + 1) * self.num_threads <= self.n_block_size or row < self.n_block_size ) and row_idx < self.seqlen_k page = self.mPageTable[page_idx] if is_valid else 0 self.tPrPage[i] = page self.tPrPageOffset[i] = page_offset @cute.jit def compute_X_ptr(self, K_or_V: str): tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64) mX = self.mK_paged if const_expr(K_or_V == "K") else self.mV_paged # K is always (page_size, d, num_pages). V matches K when not transposed, # but is (dv, page_size, num_pages) when transposed (SM100). transposed = const_expr(K_or_V == "V" and self.v_gmem_transposed) for i in cutlass.range(self.page_entry_per_thread, unroll=1): page = self.tPrPage[i] page_offset = self.tPrPageOffset[i] if const_expr(transposed): tPrXPtr[i] = utils.elem_pointer(mX, (0, page_offset, page)).toint() else: tPrXPtr[i] = utils.elem_pointer(mX, (page_offset, 0, page)).toint() return tPrXPtr @cute.jit def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): assert K_or_V in ("K", "V") tPrXPtr = self.compute_X_ptr(K_or_V) if const_expr(self.arch == 90): # SM90: sX is already stage-sliced by caller (sK[None, None, stage]). # Flatten hierarchical modes to get (n_block_size, head_dim). sX_pi = cute.group_modes(sX, 0, 1) # SM90 does NOT transpose V here (it's transposed via utils.transpose_view before MMA) else: # SM100: Finesse sX layout to be (M, N). 
sX_pi = cute.make_tensor( sX.iterator, cute.make_layout( (sX.shape[0][0], (sX.shape[0][1], sX.shape[2])), stride=(sX.stride[0][0], (sX.stride[0][1], sX.stride[2])), ), ) if const_expr(K_or_V == "V"): # Transpose smem V to match transposed gmem layout sX_pi = cute.make_tensor(sX_pi.iterator, cute.select(sX_pi.layout, mode=[1, 0])) head_dim = self.head_dim_v_padded if const_expr(K_or_V == "V") else self.head_dim_padded cX = cute.make_identity_tensor((self.n_block_size, head_dim)) tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi) tXcX = self.gmem_thr_copy_KV.partition_S(cX) tXc0X = self.gmem_thr_copy_KV.get_slice(0).partition_S(cX) seqlenk_row_limit = ( self.seqlen_k - n_block * self.n_block_size - tXcX[0][0] if n_block >= 0 else 0 ) for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])): row_valid = tXc0X[0, m, 0][0] < seqlenk_row_limit should_load = cute.make_fragment_like(tXsX[(0, None), m, 0], cute.Boolean) should_load.fill(row_valid) x_ptr_i64 = utils.shuffle_sync( tPrXPtr[m // self.gmem_threads_per_row], m % self.gmem_threads_per_row, width=self.gmem_threads_per_row, ) x_gmem_ptr = cute.make_ptr( self.mK_paged.element_type, x_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 ) mX_paged_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,))) mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,)) for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])): ki = tXcX[0, 0, k][1] // self.async_copy_elems mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki] tXsX_k = tXsX[None, m, k] mX_paged_cur_copy_ki = cute.make_tensor( mX_paged_cur_copy_ki.iterator, tXsX_k.layout ) cute.copy( self.gmem_tiled_copy_KV, mX_paged_cur_copy_ki, tXsX_k, pred=should_load, ) ================================================ FILE: flash_attn/cute/pipeline.py ================================================ # Copyright (c) 2025, Tri Dao. 
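# Illustrative sketch (hypothetical helper, not used by this module): a pure-Python
# model of the single-integer encoding that PipelineStateSimple below uses. For
# stages > 1 one counter n packs both values: slot index = n % stages and
# phase = n // stages, and advancing is n += 1. For stages == 1 the index is always 0
# and advancing flips n between 0 and 1. Producers start at n = stages (slot 0,
# phase 1) and consumers at n = 0. This sketch reduces the phase to its parity bit
# (mod 2); the class below passes the unreduced phase, which its comments note
# works in practice.
def _sketch_pipeline_positions(stages: int, start: int, steps: int) -> list:
    """Return the (index, parity) pair seen at each of `steps` positions from `start`."""
    positions = []
    n = start
    for _ in range(steps):
        if stages == 1:
            positions.append((0, n & 1))
            n ^= 1  # one stage: advancing only flips the phase bit
        else:
            positions.append((n % stages, (n // stages) % 2))
            n += 1  # several stages: one increment moves to the next slot
    return positions


# Example: a 3-stage consumer wraps back to slot 0 with flipped parity after 3 advances.
# _sketch_pipeline_positions(3, 0, 4) -> [(0, 0), (1, 0), (2, 0), (0, 1)]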
# import math from typing import Optional from dataclasses import dataclass import cutlass.cute as cute from cutlass import Boolean, Int32, const_expr from cutlass.cutlass_dsl import if_generate, dsl_user_op from cutlass.pipeline import PipelineState from cutlass.pipeline import PipelineUserType from cutlass.pipeline import NamedBarrier as NamedBarrierOg from cutlass.pipeline import PipelineAsync as PipelineAsyncOg from cutlass.pipeline import PipelineTmaAsync as PipelineTmaAsyncOg from cutlass.pipeline import PipelineTmaUmma as PipelineTmaUmmaOg from cutlass.pipeline import PipelineUmmaAsync as PipelineUmmaAsyncOg from cutlass.pipeline import PipelineAsyncUmma as PipelineAsyncUmmaOg class PipelineStateSimple: """ Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. Use a single Int32 to store both the index and phase bit, then we use divmod to get the index and phase. If stages is a power of 2, divmod turns into bit twiddling. """ def __init__(self, stages: int, phase_index: Int32): # assert stages < 2**16 # self._log_stages = int(math.log2(stages)) # assert 1 << self._log_stages == stages, "Number of stages must be a power of 2." self._stages = stages self._phase_index = phase_index def clone(self) -> "PipelineStateSimple": return PipelineStateSimple(self.stages, self._phase_index) @property def stages(self) -> int: # return 1 << self._log_stages return self._stages @property def index(self) -> Int32: # return self._phase_index & 0xFFFF # return self._phase_index & ((1 << self._log_stages) - 1) if const_expr(self._stages == 1): return Int32(0) else: return self._phase_index % self._stages @property def phase(self) -> Int32: # return self._phase_index >> 16 # PTX docs say that the phase parity needs to be 0 or 1, so by right we need to # take modulo 2. But in practice just passing the phase in without modulo works fine. 
# return (self._phase_index >> self._log_stages) % 2 # return self._phase_index >> self._log_stages if const_expr(self._stages == 1): return self._phase_index else: return self._phase_index // self._stages def advance(self): if const_expr(self._stages == 1): self._phase_index ^= 1 else: self._phase_index += 1 # def then_body(phase_index): # # XOR the phase bit and set the index to 0 # return (phase_index & 0xFFFF0000) ^ (1 << 16) # def else_body(phase_index): # return phase_index # self._phase_index = if_generate( # (self._phase_index & 0xFFFF) == self.stages, # then_body, # else_body, # [self._phase_index], # [Int32], # ) def __extract_mlir_values__(self): phase_index = self._phase_index return [phase_index.ir_value()] def __new_from_mlir_values__(self, values): return PipelineStateSimple(self.stages, Int32(values[0])) def make_pipeline_state(type: PipelineUserType, stages: int): """ Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. """ if type is PipelineUserType.Producer: # return PipelineStateSimple(stages, Int32(1 << 16)) return PipelineStateSimple(stages, Int32(stages)) elif type is PipelineUserType.Consumer: return PipelineStateSimple(stages, Int32(0)) else: assert False, "Error: invalid PipelineUserType specified for make_pipeline_state." @dataclass(frozen=True) class NamedBarrier(NamedBarrierOg): @staticmethod def create(*args, **kwargs): obj = NamedBarrierOg.create(*args, **kwargs) # Can't assign to __class__ directly since the dataclass is frozen object.__setattr__(obj, "__class__", NamedBarrier) return obj @dsl_user_op def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None: """ The aligned flavor of arrive is used when all threads in the CTA will execute the same instruction. See PTX documentation. 
""" cute.arch.barrier_arrive( barrier_id=self.barrier_id + index, number_of_threads=self.num_threads, loc=loc, ip=ip, ) @dsl_user_op def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) -> None: cute.arch.barrier( barrier_id=self.barrier_id + index, number_of_threads=self.num_threads, loc=loc, ip=ip, ) @dataclass(frozen=True) class PipelineAsync(PipelineAsyncOg): @staticmethod def create(*args, **kwargs): obj = PipelineAsyncOg.create(*args, **kwargs) # Can't assign to __class__ directly since the dataclass is frozen # obj.__class__ = PipelineAsync object.__setattr__(obj, "__class__", PipelineAsync) return obj @dsl_user_op def producer_acquire_w_index_phase( self, index: Int32, phase: Int32, try_acquire_token: Optional[Boolean] = None, *, loc=None, ip=None, ): if_generate( try_acquire_token is None or try_acquire_token == 0, lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), loc=loc, ip=ip, ) @dsl_user_op def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) @dsl_user_op def consumer_wait_w_index_phase( self, index: Int32, phase: Int32, try_wait_token: Optional[Boolean] = None, *, loc=None, ip=None, ): if_generate( try_wait_token is None or try_wait_token == 0, lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), loc=loc, ip=ip, ) @dsl_user_op def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) @dataclass(frozen=True) class PipelineTmaAsync(PipelineTmaAsyncOg): """ Override producer_acquire to take in extra_tx_count parameter. 
""" @staticmethod def create(*args, **kwargs): obj = PipelineTmaAsyncOg.create(*args, **kwargs) # Can't assign to __class__ directly since the dataclass is frozen object.__setattr__(obj, "__class__", PipelineTmaAsync) return obj @dsl_user_op def producer_acquire( self, state: PipelineState, try_acquire_token: Optional[Boolean] = None, extra_tx_count: int = 0, *, loc=None, ip=None, ): """ TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. """ if_generate( try_acquire_token is None or try_acquire_token == 0, lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), loc=loc, ip=ip, ) if const_expr(extra_tx_count == 0): self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) else: tx_count = self.sync_object_full.tx_count + extra_tx_count self.sync_object_full.arrive_and_expect_tx(state.index, tx_count, loc=loc, ip=ip) @dsl_user_op def producer_acquire_w_index_phase( self, index: Int32, phase: Int32, try_acquire_token: Optional[Boolean] = None, *, loc=None, ip=None, ): if_generate( try_acquire_token is None or try_acquire_token == 0, lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), loc=loc, ip=ip, ) self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) @dsl_user_op def consumer_wait_w_index_phase( self, index: Int32, phase: Int32, try_wait_token: Optional[Boolean] = None, *, loc=None, ip=None, ): if_generate( try_wait_token is None or try_wait_token == 0, lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), loc=loc, ip=ip, ) @dsl_user_op def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): """ TMA consumer release conditionally signals the empty buffer to the producer. 
""" if_generate( self.is_signalling_thread, lambda: self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip), ) @dataclass(frozen=True) class PipelineTmaUmma(PipelineTmaUmmaOg): """ Override producer_acquire to take in extra_tx_count parameter. """ @staticmethod def create(*args, **kwargs): obj = PipelineTmaUmmaOg.create(*args, **kwargs) # Can't assign to __class__ directly since the dataclass is frozen # obj.__class__ = PipelineTmaUmma object.__setattr__(obj, "__class__", PipelineTmaUmma) return obj @dsl_user_op def producer_acquire( self, state: PipelineState, try_acquire_token: Optional[Boolean] = None, extra_tx_count: int = 0, *, loc=None, ip=None, ): """ TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. """ if_generate( try_acquire_token is None or try_acquire_token == 0, lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), loc=loc, ip=ip, ) if const_expr(extra_tx_count == 0): if_generate( self.is_leader_cta, lambda: self.sync_object_full.arrive( state.index, self.producer_mask, loc=loc, ip=ip ), loc=loc, ip=ip, ) else: tx_count = self.sync_object_full.tx_count + extra_tx_count if_generate( self.is_leader_cta, lambda: self.sync_object_full.arrive_and_expect_tx( state.index, tx_count, loc=loc, ip=ip ), loc=loc, ip=ip, ) @dsl_user_op def producer_acquire_w_index_phase( self, index: Int32, phase: Int32, try_acquire_token: Optional[Boolean] = None, *, loc=None, ip=None, ): """ TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. 
""" if_generate( try_acquire_token is None or try_acquire_token == 0, lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), loc=loc, ip=ip, ) if_generate( self.is_leader_cta, lambda: self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip), loc=loc, ip=ip, ) @dsl_user_op def consumer_wait_w_index_phase( self, index: Int32, phase: Int32, try_wait_token: Optional[Boolean] = None, *, loc=None, ip=None, ): if_generate( try_wait_token is None or try_wait_token == 0, lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), loc=loc, ip=ip, ) @dsl_user_op def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): """ UMMA consumer release buffer empty, cta_group needs to be provided. """ self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) @dataclass(frozen=True) class PipelineUmmaAsync(PipelineUmmaAsyncOg): @staticmethod def create(*args, **kwargs): obj = PipelineUmmaAsyncOg.create(*args, **kwargs) # Can't assign to __class__ directly since the dataclass is frozen object.__setattr__(obj, "__class__", PipelineUmmaAsync) return obj @dsl_user_op def producer_acquire_w_index_phase( self, index: Int32, phase: Int32, try_acquire_token: Optional[Boolean] = None, *, loc=None, ip=None, ): if_generate( try_acquire_token is None or try_acquire_token == 0, lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), loc=loc, ip=ip, ) @dsl_user_op def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): """ UMMA producer commit buffer full, cta_group needs to be provided. 
""" self.sync_object_full.arrive(index, self.producer_mask, self.cta_group, loc=loc, ip=ip) @dsl_user_op def consumer_wait_w_index_phase( self, index: Int32, phase: Int32, try_wait_token: Optional[Boolean] = None, *, loc=None, ip=None, ): if_generate( try_wait_token is None or try_wait_token == 0, lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), loc=loc, ip=ip, ) @dsl_user_op def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): self.sync_object_empty.arrive(index, self.consumer_mask, loc=loc, ip=ip) @dataclass(frozen=True) class PipelineAsyncUmma(PipelineAsyncUmmaOg): @staticmethod def create(*args, **kwargs): obj = PipelineAsyncUmmaOg.create(*args, **kwargs) # Can't assign to __class__ directly since the dataclass is frozen object.__setattr__(obj, "__class__", PipelineAsyncUmma) return obj @dsl_user_op def producer_acquire_w_index_phase( self, index: Int32, phase: Int32, try_acquire_token: Optional[Boolean] = None, *, loc=None, ip=None, ): if_generate( try_acquire_token is None or try_acquire_token == 0, lambda: self.sync_object_empty.wait(index, phase, loc=loc, ip=ip), loc=loc, ip=ip, ) @dsl_user_op def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): self.sync_object_full.arrive(index, self.producer_mask, loc=loc, ip=ip) @dsl_user_op def consumer_wait_w_index_phase( self, index: Int32, phase: Int32, try_wait_token: Optional[Boolean] = None, *, loc=None, ip=None, ): if_generate( try_wait_token is None or try_wait_token == 0, lambda: self.sync_object_full.wait(index, phase, loc=loc, ip=ip), loc=loc, ip=ip, ) @dsl_user_op def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): """ UMMA consumer release buffer empty, cta_group needs to be provided. 
""" self.sync_object_empty.arrive(index, self.consumer_mask, self.cta_group, loc=loc, ip=ip) ================================================ FILE: flash_attn/cute/pyproject.toml ================================================ [build-system] requires = ["setuptools>=75", "setuptools-scm>=8"] build-backend = "setuptools.build_meta" [project] name = "flash-attn-4" dynamic = ["version"] description = "Flash Attention CUTE (CUDA Template Engine) implementation" readme = "README.md" requires-python = ">=3.10" license = {text = "BSD 3-Clause License"} authors = [ {name = "Tri Dao"}, ] classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: BSD License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", ] dependencies = [ "nvidia-cutlass-dsl>=4.4.2", "torch", "einops", "typing_extensions", "apache-tvm-ffi>=0.1.5,<0.2", "torch-c-dlpack-ext", "quack-kernels>=0.3.3", ] [project.optional-dependencies] dev = [ "pytest", "ruff", ] [project.urls] Homepage = "https://github.com/Dao-AILab/flash-attention" Repository = "https://github.com/Dao-AILab/flash-attention" [tool.setuptools] packages = ["flash_attn.cute"] package-dir = {"flash_attn.cute" = "."} [tool.setuptools_scm] root = "../.." 
tag_regex = "^fa4-v(?P<version>.+)$" git_describe_command = "git describe --dirty --tags --long --match 'fa4-v*'" fallback_version = "0.0.0" [tool.ruff] line-length = 100 [tool.ruff.lint] ignore = [ "E731", # do not assign a lambda expression, use a def "E741", # Do not use variables named 'I', 'O', or 'l' "F841", # local variable is assigned to but never used "D102", # Missing docstring in public methods ] ================================================ FILE: flash_attn/cute/seqlen_info.py ================================================ from typing import Optional from dataclasses import dataclass import cutlass import cutlass.cute as cute from cutlass import Int32, const_expr from quack import copy_utils """ This consolidates all the info related to sequence length. This is so that we can do all the gmem reads once at the beginning of each tile, rather than having to repeat these reads to compute various things like n_block_min, n_block_max, etc. """ @dataclass(frozen=True) class SeqlenInfo: offset: Int32 offset_padded: Int32 seqlen: Int32 has_cu_seqlens: cutlass.Constexpr[bool] = False @staticmethod def create( batch_idx: Int32, seqlen_static: Int32, cu_seqlens: Optional[cute.Tensor] = None, seqused: Optional[cute.Tensor] = None, tile: cutlass.Constexpr[int] = 128, ): offset = 0 if const_expr(cu_seqlens is None) else cu_seqlens[batch_idx] offset_padded = ( 0 if const_expr(cu_seqlens is None) # Add divby so that the compiler knows the alignment when moving by offset_padded else cute.assume((offset + batch_idx * tile) // tile * tile, divby=tile) ) if const_expr(seqused is not None): seqlen = seqused[batch_idx] elif const_expr(cu_seqlens is not None): seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] else: seqlen = seqlen_static return SeqlenInfo(offset, offset_padded, seqlen, has_cu_seqlens=cu_seqlens is not None) def offset_batch( self, mT: cute.Tensor, batch_idx: Int32, dim: int, padded: cutlass.Constexpr[bool] = False, multiple: int = 1, ) -> cute.Tensor:
"""Offset a tensor by batch index. batch dim is at position `dim`, seqlen is at dim=0.""" if const_expr(not self.has_cu_seqlens): idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mT) - 1 - dim) return mT[idx] else: off = multiple * (self.offset if const_expr(not padded) else self.offset_padded) offset = off if const_expr(cute.rank(mT.shape[0]) == 1) else (0, off) idx = (offset,) + (None,) * (cute.rank(mT) - 1) return cute.domain_offset(idx, mT) @dataclass(frozen=True) class SeqlenInfoQK: offset_q: Int32 offset_k: Int32 padded_offset_q: Int32 padded_offset_k: Int32 seqlen_q: Int32 seqlen_k: Int32 has_cu_seqlens_q: cutlass.Constexpr[bool] has_cu_seqlens_k: cutlass.Constexpr[bool] has_seqused_q: cutlass.Constexpr[bool] has_seqused_k: cutlass.Constexpr[bool] @staticmethod def create( batch_idx: Int32, seqlen_q_static: Int32, seqlen_k_static: Int32, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, tile_m: cutlass.Constexpr[Int32] = 128, tile_n: cutlass.Constexpr[Int32] = 128, ): offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx] padded_offset_q = ( 0 if const_expr(mCuSeqlensQ is None) else cute.assume((offset_q + batch_idx * tile_m) // tile_m * tile_m, divby=tile_m) ) padded_offset_k = ( 0 if const_expr(mCuSeqlensK is None) else cute.assume((offset_k + batch_idx * tile_n) // tile_n * tile_n, divby=tile_n) ) if const_expr(mSeqUsedQ is not None): seqlen_q = mSeqUsedQ[batch_idx] else: seqlen_q = ( seqlen_q_static if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx + 1] - offset_q ) if const_expr(mSeqUsedK is not None): seqlen_k = mSeqUsedK[batch_idx] else: seqlen_k = ( seqlen_k_static if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx + 1] - offset_k ) return SeqlenInfoQK( offset_q, offset_k, padded_offset_q, padded_offset_k, 
seqlen_q, seqlen_k, has_cu_seqlens_q=mCuSeqlensQ is not None, has_cu_seqlens_k=mCuSeqlensK is not None, has_seqused_q=mSeqUsedQ is not None, has_seqused_k=mSeqUsedK is not None, ) def offset_batch_Q( self, mQ: cute.Tensor, batch_idx: Int32, dim: int, padded: cutlass.Constexpr[bool] = False, ragged: cutlass.Constexpr[bool] = False, ) -> cute.Tensor: """Seqlen must be the first dimension of mQ""" if const_expr(not ragged): if const_expr(not self.has_cu_seqlens_q): idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) return mQ[idx] else: offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q offset_q = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (None, offset_q) idx = (offset_q,) + (None,) * (cute.rank(mQ) - 1) return cute.domain_offset(idx, mQ) else: if const_expr(not self.has_cu_seqlens_q): offset_q = 0 idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim) mQ = mQ[idx] else: offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q if const_expr(cute.rank(mQ.shape[0]) == 1): return copy_utils.offset_ragged_tensor( mQ, offset_q, self.seqlen_q, ragged_dim=0, ptr_shift=True ) else: # PackGQA assert cute.rank(mQ.shape[0]) == 2 # Unpack before calling offset_ragged_tensor, then pack idx = ((None, None),) + (None,) * (cute.rank(mQ) - 1) mQ = mQ[idx] mQ = copy_utils.offset_ragged_tensor( mQ, offset_q, self.seqlen_q, ragged_dim=1, ptr_shift=True ) return cute.group_modes(mQ, 0, 2) def offset_batch_K( self, mK: cute.Tensor, batch_idx: Int32, dim: int, padded: cutlass.Constexpr[bool] = False, ragged: cutlass.Constexpr[bool] = False, multiple: int = 1, ) -> cute.Tensor: """Seqlen must be the first dimension of mK""" if const_expr(not ragged): if const_expr(not self.has_cu_seqlens_k): idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) return mK[idx] else: offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k offset_k *= multiple idx = 
(offset_k,) + (None,) * (cute.rank(mK) - 1) return cute.domain_offset(idx, mK) else: if const_expr(not self.has_cu_seqlens_k): offset_k = 0 idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim) mK = mK[idx] else: offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k offset_k *= multiple return copy_utils.offset_ragged_tensor( mK, offset_k, self.seqlen_k, ragged_dim=0, ptr_shift=True ) @dataclass(frozen=True) class SeqlenInfoQKNewK: """Sequence length info for append-KV with left-padding and new K support. Extends SeqlenInfoQK with: - leftpad_k: left padding for K (tokens to skip at the start of the KV cache) - offset_k_new: offset into the new K tensor - seqlen_k_og: original K length (before appending new K), excluding leftpad - seqlen_k_new: length of new K to append - seqlen_k: total K length (seqlen_k_og + seqlen_k_new) - seqlen_rotary: position for rotary embedding computation """ leftpad_k: Int32 offset_q: Int32 offset_k: Int32 offset_k_new: Int32 seqlen_q: Int32 seqlen_k_og: Int32 seqlen_k_new: Int32 seqlen_k: Int32 seqlen_rotary: Int32 @staticmethod def create( batch_idx: Int32, seqlen_q_static: Int32, seqlen_k_static: Int32, shape_K_new_0: Int32, mCuSeqlensQ: Optional[cute.Tensor] = None, mCuSeqlensK: Optional[cute.Tensor] = None, mCuSeqlensKNew: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, mLeftpadK: Optional[cute.Tensor] = None, mSeqlensRotary: Optional[cute.Tensor] = None, ): leftpad_k = 0 if const_expr(mLeftpadK is None) else mLeftpadK[batch_idx] offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx] if const_expr(mCuSeqlensK is not None): offset_k = mCuSeqlensK[batch_idx] + leftpad_k else: offset_k = leftpad_k if const_expr(mCuSeqlensQ is not None) else 0 offset_k_new = 0 if const_expr(mCuSeqlensKNew is None) else mCuSeqlensKNew[batch_idx] # seqlen_q if const_expr(mSeqUsedQ is not None): seqlen_q = mSeqUsedQ[batch_idx] elif 
const_expr(mCuSeqlensQ is not None): seqlen_q = mCuSeqlensQ[batch_idx + 1] - mCuSeqlensQ[batch_idx] else: seqlen_q = seqlen_q_static # seqlen_k_og: original K length (excluding leftpad) if const_expr(mSeqUsedK is not None): seqlen_k_og = mSeqUsedK[batch_idx] - leftpad_k elif const_expr(mCuSeqlensK is not None): seqlen_k_og = mCuSeqlensK[batch_idx + 1] - mCuSeqlensK[batch_idx] - leftpad_k else: seqlen_k_og = ( seqlen_k_static - leftpad_k if const_expr(mCuSeqlensQ is not None) else seqlen_k_static ) # seqlen_k_new if const_expr(mCuSeqlensKNew is None): seqlen_k_new = 0 if const_expr(mCuSeqlensQ is None) else shape_K_new_0 else: seqlen_k_new = mCuSeqlensKNew[batch_idx + 1] - mCuSeqlensKNew[batch_idx] seqlen_k = seqlen_k_og if const_expr(mCuSeqlensQ is None) else seqlen_k_og + seqlen_k_new # seqlen_rotary: defaults to seqlen_k_og + leftpad_k unless explicitly provided if const_expr(mSeqlensRotary is not None): seqlen_rotary = mSeqlensRotary[batch_idx] else: seqlen_rotary = seqlen_k_og + leftpad_k return SeqlenInfoQKNewK( leftpad_k, offset_q, offset_k, offset_k_new, seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k, seqlen_rotary, ) ================================================ FILE: flash_attn/cute/sm90_config_search.py ================================================ """Search feasible SM90 fwd/bwd attention configs for given (head_dim, head_dim_v). Enumerates tile sizes, swap modes, atom layouts, and staging options. Checks GMMA divisibility, register budget, and shared memory budget. 
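Worked example of the register check (illustrative arithmetic using the formulas in _acc_regs and _check_bwd_config below): for the backward pass with 2 warpgroups and tile_m = tile_n = hdim = hdimv = 128, the S, dP, dK, dV and dQ accumulators each need 128 * 128 // (2 * 128) = 64 registers per thread. The peak is therefore max(2 * 64, 64) + 64 + 64 = 256, which exceeds REG_LIMITS[2] = 216, so every such config is rejected regardless of swap modes or atom layouts.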
Usage: python flash_attn/cute/sm90_config_search.py --headdim 128 python flash_attn/cute/sm90_config_search.py --mode fwd --headdim 192-128 python flash_attn/cute/sm90_config_search.py --mode bwd --headdim 192 --tile-n 64,96 """ import math # H100 hardware limits SMEM_LIMIT = 224 * 1024 # 227 KB max per block minus ~3 KB for LSE, dPsum, mbarriers REG_LIMITS = {2: 216, 3: 128} # per-WG budget: 2WG=240-24, 3WG=160-32 THREADS_PER_WG = 128 def _divisors(n): return [d for d in range(1, n + 1) if n % d == 0] def _acc_regs(M, N, num_wg): """Accumulator registers per thread per WG.""" return M * N // (num_wg * THREADS_PER_WG) def _check_mma(M, N, num_wg, atom_layout_m, swap_AB): """Check MMA feasibility. Returns accumulator regs per thread (per WG), or None if infeasible. GMMA atom M=64. Swap exchanges (M, N) and atom layout. Requires: M divisible by (atom_layout_m * 64), N by (atom_layout_n * 8). """ if swap_AB: M, N = N, M atom_layout_m = num_wg // atom_layout_m atom_layout_n = num_wg // atom_layout_m if M % (atom_layout_m * 64) != 0 or N % (atom_layout_n * 8) != 0: return None return _acc_regs(M, N, num_wg) def _mma_traffic(M_eff, N_eff, K_red, num_wg, wg_n, is_rs=False): """Total SMEM read traffic for one MMA (all WGs combined). num_instr = (M_eff / 64) * wg_n instructions total. Each reads A(64, K_red) and B(N_eff/wg_n, K_red) from smem (bf16).
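Worked example (illustrative arithmetic following the formula above): M_eff = N_eff = K_red = 128 with wg_n = 1 gives (128 // 64) * 1 = 2 instructions. Each instruction reads 64 * 128 * 2 = 16384 bytes of A and 128 * 128 * 2 = 32768 bytes of B, for 2 * 49152 = 98304 bytes in total. With is_rs=True, A comes from registers, and the total drops to 2 * 32768 = 65536 bytes.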
""" num_instr = (M_eff // 64) * wg_n A_per = 64 * K_red * 2 if not is_rs else 0 B_per = (N_eff // wg_n) * K_red * 2 return num_instr * (A_per + B_per) # ============================================================================ # Backward # ============================================================================ def _check_bwd_config( hdim, hdimv, tile_m, tile_n, num_wg, SdP_swapAB, dKV_swapAB, dQ_swapAB, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, ): reg_limit = REG_LIMITS[num_wg] # MMA feasibility regs_SdP = _check_mma(tile_m, tile_n, num_wg, AtomLayoutMSdP, SdP_swapAB) regs_dK = _check_mma(tile_n, hdim, num_wg, AtomLayoutNdKV, dKV_swapAB) regs_dV = _check_mma(tile_n, hdimv, num_wg, AtomLayoutNdKV, dKV_swapAB) regs_dQ = _check_mma(tile_m, hdim, num_wg, AtomLayoutMdQ, dQ_swapAB) if any(r is None for r in (regs_SdP, regs_dK, regs_dV, regs_dQ)): return None # Peak regs: max(S+dP, dQ) + dK + dV total_regs = max(2 * regs_SdP, regs_dQ) + regs_dK + regs_dV if total_regs > reg_limit: return None # SMEM mma_dkv_is_rs = ( AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_wg and SdP_swapAB and not dKV_swapAB ) Q_stage, PdS_stage = 2, 1 for dO_stage in (2, 1): sQ = tile_m * hdim * 2 * Q_stage sK = tile_n * hdim * 2 sV = tile_n * hdimv * 2 sdO = tile_m * hdimv * 2 * dO_stage sPdS = tile_m * tile_n * 2 * PdS_stage sP = sPdS if not mma_dkv_is_rs else 0 sdQaccum = tile_m * hdim * 4 smem = sQ + sK + sV + sdO + sP + sPdS + sdQaccum if smem <= SMEM_LIMIT: break else: return None # SMEM traffic def _swap(a, b, s): return (b, a) if s else (a, b) def _wg_n(al_m, s): return al_m if s else num_wg // al_m M_s, N_s = _swap(tile_m, tile_n, SdP_swapAB) wn_SdP = _wg_n(AtomLayoutMSdP, SdP_swapAB) traffic_S = _mma_traffic(M_s, N_s, hdim, num_wg, wn_SdP) traffic_dP = _mma_traffic(M_s, N_s, hdimv, num_wg, wn_SdP) wn_dKV = _wg_n(AtomLayoutNdKV, dKV_swapAB) M_dv, N_dv = _swap(tile_n, hdimv, dKV_swapAB) traffic_dV = _mma_traffic(M_dv, N_dv, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs) M_dk, 
N_dk = _swap(tile_n, hdim, dKV_swapAB) traffic_dK = _mma_traffic(M_dk, N_dk, tile_m, num_wg, wn_dKV, is_rs=mma_dkv_is_rs) M_dq, N_dq = _swap(tile_m, hdim, dQ_swapAB) wn_dQ = _wg_n(AtomLayoutMdQ, dQ_swapAB) traffic_dQ = _mma_traffic(M_dq, N_dq, tile_n, num_wg, wn_dQ) traffic_P_store = tile_m * tile_n * 2 if not mma_dkv_is_rs else 0 traffic_dS_store = tile_m * tile_n * 2 traffic_dQ_smem = tile_m * hdim * 4 * 2 # store + TMA load smem_traffic = ( traffic_S + traffic_dP + traffic_dV + traffic_dK + traffic_dQ + traffic_P_store + traffic_dS_store + traffic_dQ_smem ) return dict( tile_m=tile_m, tile_n=tile_n, num_wg=num_wg, Q_stage=Q_stage, dO_stage=dO_stage, PdS_stage=PdS_stage, SdP_swapAB=SdP_swapAB, dKV_swapAB=dKV_swapAB, dQ_swapAB=dQ_swapAB, AtomLayoutMSdP=AtomLayoutMSdP, AtomLayoutNdKV=AtomLayoutNdKV, AtomLayoutMdQ=AtomLayoutMdQ, mma_dkv_is_rs=mma_dkv_is_rs, regs_SdP=regs_SdP, regs_dK=regs_dK, regs_dV=regs_dV, regs_dQ=regs_dQ, total_regs=total_regs, reg_limit=reg_limit, smem_bytes=smem, smem_kb=smem / 1024, smem_traffic=smem_traffic, smem_traffic_kb=smem_traffic / 1024, smem_traffic_per_block=smem_traffic / (tile_m * tile_n), ) def find_feasible_bwd_configs( head_dim, head_dim_v=None, tile_m_choices=(64, 80, 96, 112, 128), tile_n_choices=(64, 80, 96, 112, 128), ): if head_dim_v is None: head_dim_v = head_dim hdim = int(math.ceil(head_dim / 32) * 32) hdimv = int(math.ceil(head_dim_v / 32) * 32) results = [] for num_wg in (2, 3): divs = _divisors(num_wg) for tile_m in tile_m_choices: for tile_n in tile_n_choices: for SdP_swap in (False, True): if (tile_n if SdP_swap else tile_m) % 64 != 0: continue for dKV_swap in (False, True): if not dKV_swap and tile_n % 64 != 0: continue if dKV_swap and (hdim % 64 != 0 or hdimv % 64 != 0): continue for dQ_swap in (False, True): if (hdim if dQ_swap else tile_m) % 64 != 0: continue for a1 in divs: for a2 in divs: for a3 in divs: cfg = _check_bwd_config( hdim, hdimv, tile_m, tile_n, num_wg, SdP_swap, dKV_swap, dQ_swap, a1, a2, a3, ) 
if cfg is not None: results.append(cfg) results.sort(key=lambda c: (-c["tile_n"], -c["tile_m"], c["smem_traffic_per_block"])) return results def print_bwd_configs(configs, max_results=20): if not configs: print("No feasible configs found!") return n = min(len(configs), max_results) print(f"Found {len(configs)} feasible configs (showing top {n}):\n") hdr = ( f"{'wg':>2} {'tm':>3} {'tn':>3} " f"{'SdP':>3} {'dKV':>3} {'dQ':>3} " f"{'aSdP':>4} {'adKV':>4} {'adQ':>4} " f"{'Qs':>2} {'dOs':>3} " f"{'rS':>3} {'rdK':>3} {'rdV':>3} {'rdQ':>3} {'tot':>4}/{'':<3} " f"{'smem':>5} {'traffic':>7} {'tr/blk':>6}" ) print(hdr) print("-" * len(hdr)) B = lambda b: "T" if b else "F" for c in configs[:max_results]: print( f"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3} " f"{B(c['SdP_swapAB']):>3} {B(c['dKV_swapAB']):>3} {B(c['dQ_swapAB']):>3} " f"{c['AtomLayoutMSdP']:>4} {c['AtomLayoutNdKV']:>4} {c['AtomLayoutMdQ']:>4} " f"{c['Q_stage']:>2} {c['dO_stage']:>3} " f"{c['regs_SdP']:>3} {c['regs_dK']:>3} {c['regs_dV']:>3} {c['regs_dQ']:>3} " f"{c['total_regs']:>4}/{c['reg_limit']:<3} " f"{c['smem_kb']:>4.0f}K " f"{c['smem_traffic_kb']:>6.0f}K " f"{c['smem_traffic_per_block']:>6.1f}" ) # ============================================================================ # Forward # ============================================================================ def _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg): reg_limit = REG_LIMITS[num_wg] tile_m = num_wg * 64 if tile_n % 8 != 0: return None regs_S = _acc_regs(tile_m, tile_n, num_wg) regs_O = _acc_regs(tile_m, hdimv, num_wg) regs_P = regs_S // 2 # bf16 = half of f32 if overlap_wg: total_regs = regs_S + regs_P + regs_O else: total_regs = regs_S + regs_O if total_regs > reg_limit: return None # SMEM: 1 stage Q, 2 stages K/V, O overlaps Q, sP if not RS sQ = tile_m * hdim * 2 sK = tile_n * hdim * 2 * 2 sV = tile_n * hdimv * 2 * 2 sO = tile_m * hdimv * 2 sP = tile_m * tile_n * 2 if not pv_is_rs else 0 smem = max(sQ, sO) + sK + 
sV + sP if smem > SMEM_LIMIT: return None # SMEM traffic: num_instr = num_wg (all WGs in M, wg_n=1) traffic_S = num_wg * (64 * hdim * 2 + tile_n * hdim * 2) A_pv = 64 * tile_n * 2 if not pv_is_rs else 0 traffic_O = num_wg * (A_pv + hdimv * tile_n * 2) traffic_P_store = tile_m * tile_n * 2 if not pv_is_rs else 0 smem_traffic = traffic_S + traffic_O + traffic_P_store return dict( tile_m=tile_m, tile_n=tile_n, num_wg=num_wg, pv_is_rs=pv_is_rs, overlap_wg=overlap_wg, regs_S=regs_S, regs_O=regs_O, regs_P=regs_P, total_regs=total_regs, reg_limit=reg_limit, smem_bytes=smem, smem_kb=smem / 1024, smem_traffic=smem_traffic, smem_traffic_kb=smem_traffic / 1024, smem_traffic_per_block=smem_traffic / (tile_m * tile_n), ) def find_feasible_fwd_configs( head_dim, head_dim_v=None, tile_n_choices=(64, 80, 96, 112, 128, 144, 160, 176, 192) ): if head_dim_v is None: head_dim_v = head_dim hdim = int(math.ceil(head_dim / 32) * 32) hdimv = int(math.ceil(head_dim_v / 32) * 32) results = [] for num_wg in (2, 3): for tile_n in tile_n_choices: for pv_is_rs in (True, False): for overlap_wg in (True, False): cfg = _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg) if cfg is not None: results.append(cfg) results.sort(key=lambda c: (-c["tile_n"], c["smem_traffic_per_block"])) return results def print_fwd_configs(configs, max_results=20): if not configs: print("No feasible configs found!") return n = min(len(configs), max_results) print(f"Found {len(configs)} feasible configs (showing top {n}):\n") hdr = ( f"{'wg':>2} {'tm':>3} {'tn':>3} " f"{'RS':>2} {'olap':>4} " f"{'rS':>3} {'rP':>3} {'rO':>3} {'tot':>4}/{'':<3} " f"{'smem':>5} {'traffic':>7} {'tr/blk':>6}" ) print(hdr) print("-" * len(hdr)) B = lambda b: "T" if b else "F" for c in configs[:max_results]: print( f"{c['num_wg']:>2} {c['tile_m']:>3} {c['tile_n']:>3} " f"{B(c['pv_is_rs']):>2} {B(c['overlap_wg']):>4} " f"{c['regs_S']:>3} {c['regs_P']:>3} {c['regs_O']:>3} " f"{c['total_regs']:>4}/{c['reg_limit']:<3} " 
f"{c['smem_kb']:>4.0f}K " f"{c['smem_traffic_kb']:>6.0f}K " f"{c['smem_traffic_per_block']:>6.1f}" ) # ============================================================================ # CLI # ============================================================================ if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Search feasible SM90 MMA configs") parser.add_argument("--mode", choices=["fwd", "bwd", "both"], default="both") parser.add_argument( "--headdim", type=str, default="128", help="Head dim, or hdim-hdimv (e.g. 192-128)" ) parser.add_argument("--tile-m", type=str, default="64,80,96,112,128", help="Bwd tile_m choices") parser.add_argument( "--tile-n", type=str, default=None, help="tile_n choices (default: fwd up to 192, bwd up to 128)", ) parser.add_argument("-n", "--num-results", type=int, default=30) args = parser.parse_args() parts = args.headdim.split("-") hdim = int(parts[0]) hdimv = int(parts[1]) if len(parts) > 1 else hdim TN_FWD = "64,80,96,112,128,144,160,176,192" TN_BWD = "64,80,96,112,128" if args.mode in ("fwd", "both"): tn = tuple(int(x) for x in (args.tile_n or TN_FWD).split(",")) print(f"=== FWD configs: hdim={hdim}, hdimv={hdimv} ===\n") print_fwd_configs(find_feasible_fwd_configs(hdim, hdimv, tn), args.num_results) print() if args.mode in ("bwd", "both"): tm = tuple(int(x) for x in args.tile_m.split(",")) tn = tuple(int(x) for x in (args.tile_n or TN_BWD).split(",")) print(f"=== BWD configs: hdim={hdim}, hdimv={hdimv} ===\n") print_bwd_configs(find_feasible_bwd_configs(hdim, hdimv, tm, tn), args.num_results) ================================================ FILE: flash_attn/cute/softmax.py ================================================ # Copyright (c) 2025, Tri Dao. 
import math import operator from typing import Tuple from dataclasses import dataclass import cutlass import cutlass.cute as cute from cutlass import Float32 from quack import layout_utils import flash_attn.cute.utils as utils from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.seqlen_info import SeqlenInfoQK @dataclass class Softmax(ParamsBase): scale_log2: Float32 num_rows: cutlass.Constexpr[int] row_max: cute.Tensor row_sum: cute.Tensor arch: cutlass.Constexpr[int] = 80 softmax_scale: Float32 | None = None @staticmethod def create( scale_log2: Float32, num_rows: cutlass.Constexpr[int], arch: cutlass.Constexpr[int] = 80, softmax_scale: Float32 | None = None, ): row_max = cute.make_rmem_tensor(num_rows, Float32) row_sum = cute.make_rmem_tensor(num_rows, Float32) return Softmax(scale_log2, num_rows, row_max, row_sum, arch, softmax_scale) def reset(self) -> None: self.row_max.fill(-Float32.inf) self.row_sum.fill(0.0) def _compute_row_max( self, acc_S_row: cute.TensorSSA, init_val: float | Float32 | None = None ) -> Float32: return utils.fmax_reduce(acc_S_row, init_val, arch=self.arch) def _compute_row_sum( self, acc_S_row_exp: cute.TensorSSA, init_val: float | Float32 | None = None ) -> Float32: return utils.fadd_reduce(acc_S_row_exp, init_val, arch=self.arch) @cute.jit def online_softmax( self, acc_S: cute.Tensor, is_first: cutlass.Constexpr[bool] = False, check_inf: cutlass.Constexpr[bool] = True, ) -> cute.Tensor: """Apply online softmax and return the row_scale to rescale O. :param acc_S: acc_S tensor :type acc_S: cute.Tensor :param is_first: is first n_block :type is_first: cutlass.Constexpr """ # Change acc_S to M,N layout view. 
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) row_scale = cute.make_fragment_like(self.row_max, Float32) row_max = self.row_max row_sum = self.row_sum scale_log2 = self.scale_log2 arch = self.arch # Each iteration processes one row of acc_S for r in cutlass.range(cute.size(row_max), unroll_full=True): acc_S_row = acc_S_mn[r, None].load() # (n_block_size) row_max_cur = utils.fmax_reduce( acc_S_row, init_val=row_max[r] if cutlass.const_expr(not is_first) else None, arch=arch, ) row_max_cur = cute.arch.warp_reduction_max(row_max_cur, threads_in_group=4) # Update row_max before changing row_max_cur to safe value for -inf row_max_prev = row_max[r] row_max[r] = row_max_cur if cutlass.const_expr(check_inf): row_max_cur = 0.0 if row_max_cur == -Float32.inf else row_max_cur if cutlass.const_expr(is_first): row_max_cur_scaled = row_max_cur * scale_log2 acc_S_row_exp = cute.math.exp2( acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True ) acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch) row_scale[r] = 1.0 else: row_max_cur_scaled = row_max_cur * scale_log2 acc_S_row_exp = cute.math.exp2( acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True ) # row_scale[r] = cute.math.exp2(row_max_prev * self.scale_log2 - row_max_cur_scaled) row_scale[r] = cute.math.exp2( (row_max_prev - row_max_cur) * scale_log2, fastmath=True ) acc_S_row_sum = utils.fadd_reduce( acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch ) row_sum[r] = acc_S_row_sum acc_S_mn[r, None].store(acc_S_row_exp) return row_scale @cute.jit def finalize( self, final_scale: Float32 = 1.0, sink_val: Float32 | cute.Tensor | None = None ) -> cute.Tensor: """Finalize the online softmax by computing the scale and logsumexp.""" if cutlass.const_expr(sink_val is not None and isinstance(sink_val, cute.Tensor)): assert cute.size(sink_val) == cute.size(self.row_sum) row_sum = self.row_sum row_max = self.row_max scale_log2 = self.scale_log2 # quad reduction for row_sum as we didn't do it 
during each iteration of online softmax row_sum.store(utils.warp_reduce(row_sum.load(), operator.add, width=4)) row_scale = cute.make_fragment_like(row_max, Float32) for r in cutlass.range(cute.size(row_sum), unroll_full=True): if cutlass.const_expr(sink_val is not None): sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] LOG2_E = math.log2(math.e) row_sum[r] += cute.math.exp2( sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True ) # if row_sum is zero or nan, use 1.0 as the denominator (so row_scale = final_scale) and set the LSE to -inf acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] row_scale[r] = ( cute.arch.rcp_approx(row_sum[r] if not acc_O_mn_row_is_zero_or_nan else 1.0) ) * final_scale row_sum_cur = row_sum[r] LN2 = math.log(2.0) row_sum[r] = ( (row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2 if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) return row_scale @cute.jit def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: """Scale each row of acc_O by the given scale tensor.
:param acc_O: input tensor :type acc_O: cute.Tensor :param row_scale: row_scale tensor :type row_scale: cute.Tensor """ acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O) assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0]) for r in cutlass.range(cute.size(row_scale), unroll_full=True): acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r]) @dataclass class SoftmaxSm100(Softmax): rescale_threshold: cutlass.Constexpr[float] = 0.0 @staticmethod def create( scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0, softmax_scale: Float32 | None = None, ): num_rows = 1 arch = 100 row_max = cute.make_rmem_tensor(num_rows, Float32) row_sum = cute.make_rmem_tensor(num_rows, Float32) return SoftmaxSm100( scale_log2, num_rows, row_max, row_sum, arch, softmax_scale, rescale_threshold=rescale_threshold, ) @cute.jit def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: if cutlass.const_expr(is_first): row_max_new = self._compute_row_max(acc_S_row) row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 acc_scale = 0.0 else: row_max_old = self.row_max[0] row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 acc_scale = cute.math.exp2(acc_scale_, fastmath=True) if cutlass.const_expr(self.rescale_threshold > 0.0): if acc_scale_ >= -self.rescale_threshold: row_max_new = row_max_old row_max_safe = row_max_old acc_scale = 1.0 self.row_max[0] = row_max_new return row_max_safe, acc_scale def update_row_sum( self, acc_S_row_exp: cute.TensorSSA, row_scale: Float32, is_first: int = False ) -> None: init_val = self.row_sum[0] * row_scale if cutlass.const_expr(not is_first) else None # self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=self.row_sum[0] * row_scale) self.row_sum[0] = self._compute_row_sum(acc_S_row_exp, init_val=init_val) # tmp 
= self._compute_row_sum(acc_S_row_exp) # self.row_sum[0] = self.row_sum[0] * row_scale + tmp @cute.jit def scale_subtract_rowmax( self, acc_S_row: cute.Tensor, row_max: Float32, ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" row_max_scaled = row_max * self.scale_log2 for i in cutlass.range(0, cute.size(acc_S_row.shape), 2, unroll_full=True): acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), (-row_max_scaled, -row_max_scaled), ) @cute.jit def apply_exp2_convert( self, acc_S_row: cute.Tensor, acc_S_row_converted: cute.Tensor, ex2_emu_freq: cutlass.Constexpr[int] = 0, ex2_emu_res: cutlass.Constexpr[int] = 4, ex2_emu_start_frg: cutlass.Constexpr[int] = 0, ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" frg_tile = 32 assert frg_tile % 2 == 0 frg_cnt = cute.size(acc_S_row) // frg_tile assert cute.size(acc_S_row) % frg_tile == 0 acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) acc_S_row_converted_frg = cute.logical_divide( acc_S_row_converted, cute.make_layout(frg_tile) ) for j in cutlass.range_constexpr(frg_cnt): for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): # acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) # acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) if cutlass.const_expr(ex2_emu_freq == 0): acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) else: if cutlass.const_expr( k % ex2_emu_freq < ex2_emu_freq - ex2_emu_res or j >= frg_cnt - 1 or j < ex2_emu_start_frg ): acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) acc_S_row_frg[k + 1, j] = cute.math.exp2( acc_S_row_frg[k + 1, j], fastmath=True ) else: # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = 
utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]) acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2( acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] ) acc_S_row_converted_frg[None, j].store( acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) ) @cute.jit def scale_apply_exp2_convert( self, acc_S_row: cute.Tensor, row_max: Float32, acc_S_row_converted: cute.Tensor, ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" minus_row_max_scaled = -row_max * self.scale_log2 for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), (minus_row_max_scaled, minus_row_max_scaled), ) # for i in cutlass.range_constexpr(0, cute.size(acc_S_row.shape), 2): # acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( # (acc_S_row[i], acc_S_row[i + 1]), # (self.scale_log2, self.scale_log2), # (minus_row_max_scaled, minus_row_max_scaled), # ) # acc_S_row[i] = cute.math.exp2(acc_S_row[i], fastmath=True) # acc_S_row[i + 1] = cute.math.exp2(acc_S_row[i + 1], fastmath=True) frg_tile = 32 assert frg_tile % 2 == 0 frg_cnt = cute.size(acc_S_row) // frg_tile assert cute.size(acc_S_row) % frg_tile == 0 acc_S_row_frg = cute.logical_divide(acc_S_row, cute.make_layout(frg_tile)) acc_S_row_converted_frg = cute.logical_divide( acc_S_row_converted, cute.make_layout(frg_tile) ) for j in cutlass.range_constexpr(frg_cnt): for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2): # acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = ( # cute.arch.fma_packed_f32x2( # (acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j]), # (self.scale_log2, self.scale_log2), # (minus_row_max_scaled, minus_row_max_scaled), # ) # ) # acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) # acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) 
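# Note on the exp2 formulation (a sketch of the reasoning, assuming the caller set
# scale_log2 = softmax_scale * log2(e), which is the usual setup when no score_mod
# rescales the scores): the packed FMA at the top of this function produced
# s * scale_log2 - m * scale_log2 for each score s and row max m. Because
# exp(x) = exp2(x * log2(e)), applying exp2 to that value gives exp(softmax_scale * (s - m)),
# so a single fast exp2 per element replaces exp plus a separate multiply.
# Worked example: softmax_scale = 0.125, s - m = -8
#   exp2(-8 * 0.125 * 1.4427) = exp2(-1.4427) ~= 0.3679 = exp(-1)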
acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True) acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True) acc_S_row_converted_frg[None, j].store( acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type) ) @cute.jit def floor_if_packed( q_idx, qhead_per_kvhead: cutlass.Constexpr[int], ) -> cute.Tensor: """Convert a Pack-GQA packed q_idx to the logical q_idx (no-op when qhead_per_kvhead == 1).""" if cutlass.const_expr(qhead_per_kvhead == 1): return q_idx return q_idx // qhead_per_kvhead @cute.jit def apply_score_mod_inner( score_tensor, index_tensor, score_mod: cutlass.Constexpr, batch_idx, head_idx, softmax_scale, vec_size: cutlass.Constexpr, qk_acc_dtype: cutlass.Constexpr, aux_tensors, fastdiv_mods, seqlen_info: SeqlenInfoQK, constant_q_idx: cutlass.Constexpr, qhead_per_kvhead: cutlass.Constexpr[int] = 1, transpose_indices: cutlass.Constexpr[bool] = False, ): """Shared implementation for applying score modification. Args: score_tensor: The scores to modify (acc_S for flash_fwd, tSrS_t2r for sm100) index_tensor: Index positions (tScS for flash_fwd, tScS_t2r for sm100) score_mod: The score modification function to apply batch_idx: Batch index head_idx: Head index softmax_scale: Scale to apply vec_size: Vector size for processing elements qk_acc_dtype: Data type for accumulator aux_tensors: Optional aux_tensors for FlexAttention fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping seqlen_info: Sequence length info constant_q_idx: If provided, use this constant for all q_idx values. If None, compute q_idx per element. qhead_per_kvhead: Pack-GQA replication factor. When greater than 1, a packed q_idx is divided by this factor to get the logical q_idx and, if constant_q_idx is None, the remainder selects the query head within head_idx's group (head_idx * qhead_per_kvhead + remainder), so score mods see logical indices. E.g. with qhead_per_kvhead = 4 and head_idx = 2, packed q_idx 13 maps to q_idx 3 and query head 9.
transpose_indices: If True, swap q_idx/kv_idx in index_tensor (for bwd kernel where S is transposed) """ # Index positions in the index_tensor tuple # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx if cutlass.const_expr(transpose_indices): q_idx_pos = cutlass.const_expr(1) kv_idx_pos = cutlass.const_expr(0) else: q_idx_pos = cutlass.const_expr(0) kv_idx_pos = cutlass.const_expr(1) n_vals = cutlass.const_expr(cute.size(score_tensor.shape)) score_vec = cute.make_rmem_tensor(vec_size, qk_acc_dtype) kv_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) # SSA values for batch (constant across all elements) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,)) # Handle q_idx based on whether it's constant q_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) # For Pack-GQA with non-constant q_idx, we need per-element head indices # since a thread may process multiple query head indices if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): head_idx_vec = cute.make_rmem_tensor(vec_size, cutlass.Int32) for i in cutlass.range(0, n_vals, vec_size, unroll_full=True): for j in cutlass.range(vec_size, unroll_full=True): score_vec[j] = score_tensor[i + j] * softmax_scale # Extract head offset from packed q_idx for Pack-GQA if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): q_idx_packed = index_tensor[i + j][q_idx_pos] # Building up the logical q_head idx: final_q_head = kv_head * qhead_per_kvhead + (q_physical % qhead_per_kvhead) q_idx_logical = q_idx_packed // qhead_per_kvhead head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset # If aux tensors will be loaded, wrap the indices with a modulo so reads stay in bounds if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None): if cutlass.const_expr(constant_q_idx is None): seqlen_q_divmod,
seqlen_k_divmod = fastdiv_mods q_idx_floored = floor_if_packed( index_tensor[i + j][q_idx_pos], qhead_per_kvhead ) _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod) q_idx_vec[j] = q_idx_wrapped else: _, seqlen_k_divmod = fastdiv_mods _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod) kv_idx_vec[j] = kv_idx_wrapped else: # No bounds checking - direct indexing if constant_q_idx is None: q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead) kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos] # Convert to SSA for score_mod call score_ssa = score_vec.load() kv_idx_ssa = kv_idx_vec.load() if cutlass.const_expr(constant_q_idx is None): q_idx_ssa = q_idx_vec.load() else: # NB we do not apply Pack-GQA division here, as constant_q_idx is assumed to already be logical q_idx_const = constant_q_idx q_idx_ssa = utils.scalar_to_ssa(q_idx_const, cutlass.Int32).broadcast_to((vec_size,)) # Compute head_idx_ssa: per-element for Pack-GQA with non-constant q_idx, constant otherwise if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): head_idx_ssa = head_idx_vec.load() else: head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,)) aux_args = [] if cutlass.const_expr(aux_tensors is not None): aux_args = aux_tensors post_mod_scores = score_mod( score_ssa, batch_idx_ssa, head_idx_ssa, q_idx=q_idx_ssa, kv_idx=kv_idx_ssa, seqlen_info=seqlen_info, aux_tensors=aux_args, ) # Write back modified scores score_vec.store(post_mod_scores) for j in cutlass.range(vec_size, unroll_full=True): score_tensor[i + j] = score_vec[j] @cute.jit def apply_score_mod_bwd_inner( grad_tensor, score_tensor, index_tensor, score_mod_bwd: cutlass.Constexpr, batch_idx, head_idx, softmax_scale, vec_size: cutlass.Constexpr, qk_acc_dtype: cutlass.Constexpr, aux_tensors, fastdiv_mods, seqlen_info, constant_q_idx: cutlass.Constexpr, qhead_per_kvhead: cutlass.Constexpr[int] = 1, transpose_indices: cutlass.Constexpr[bool] = 
False, ): """Apply backward score modification (joint graph). Args: grad_tensor: in/out: dlogits rewritten in-place with d(scaled_scores) score_tensor: pre-mod scores (unscaled QK tile), scaled by softmax_scale internally index_tensor: Index positions (same as forward) score_mod_bwd: The backward score modification function (joint graph) batch_idx: Batch index head_idx: Head index softmax_scale: Scale to apply to score_tensor vec_size: Vector size for processing elements qk_acc_dtype: Data type for accumulator aux_tensors: Optional aux_tensors for FlexAttention fastdiv_mods: Tuple of (seqlen_q_divmod, seqlen_k_divmod) for wrapping seqlen_info: Sequence length info constant_q_idx: If provided, use this constant for all q_idx values qhead_per_kvhead: Pack-GQA replication factor transpose_indices: If True, swap q_idx/kv_idx in index_tensor """ # Index positions in the index_tensor tuple # Forward: index_tensor[...][0] = q_idx, index_tensor[...][1] = kv_idx # Backward (transposed): index_tensor[...][0] = kv_idx, index_tensor[...][1] = q_idx if cutlass.const_expr(transpose_indices): q_idx_pos = cutlass.const_expr(1) kv_idx_pos = cutlass.const_expr(0) else: q_idx_pos = cutlass.const_expr(0) kv_idx_pos = cutlass.const_expr(1) n_vals = cutlass.const_expr(cute.size(grad_tensor.shape)) grad_vec = cute.make_fragment(vec_size, qk_acc_dtype) score_vec = cute.make_fragment(vec_size, qk_acc_dtype) kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32).broadcast_to((vec_size,)) q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) # For Pack-GQA with non-constant q_idx, we need per-element head indices if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): head_idx_vec = cute.make_fragment(vec_size, cutlass.Int32) for i in cutlass.range(0, n_vals, vec_size, unroll_full=True): for j in cutlass.range(vec_size, unroll_full=True): grad_vec[j] = grad_tensor[i + j] # Scale score so joint graph sees same 
value as forward score_mod score_vec[j] = score_tensor[i + j] * softmax_scale if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): q_idx_packed = index_tensor[i + j][q_idx_pos] q_idx_logical = q_idx_packed // qhead_per_kvhead head_offset = q_idx_packed - q_idx_logical * qhead_per_kvhead head_idx_vec[j] = head_idx * qhead_per_kvhead + head_offset if cutlass.const_expr(aux_tensors is not None and fastdiv_mods is not None): if cutlass.const_expr(constant_q_idx is None): seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods q_idx_floored = floor_if_packed( index_tensor[i + j][q_idx_pos], qhead_per_kvhead ) _, q_idx_wrapped = divmod(q_idx_floored, seqlen_q_divmod) q_idx_vec[j] = q_idx_wrapped else: _, seqlen_k_divmod = fastdiv_mods _, kv_idx_wrapped = divmod(index_tensor[i + j][kv_idx_pos], seqlen_k_divmod) kv_idx_vec[j] = kv_idx_wrapped else: # No bounds checking - direct indexing if constant_q_idx is None: q_idx_vec[j] = floor_if_packed(index_tensor[i + j][q_idx_pos], qhead_per_kvhead) kv_idx_vec[j] = index_tensor[i + j][kv_idx_pos] grad_ssa = grad_vec.load() score_ssa = score_vec.load() kv_idx_ssa = kv_idx_vec.load() if cutlass.const_expr(constant_q_idx is None): q_idx_ssa = q_idx_vec.load() else: q_idx_ssa = utils.scalar_to_ssa(constant_q_idx, cutlass.Int32).broadcast_to((vec_size,)) if cutlass.const_expr(qhead_per_kvhead > 1 and constant_q_idx is None): head_idx_ssa = head_idx_vec.load() else: head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32).broadcast_to((vec_size,)) aux_args = [] if cutlass.const_expr(aux_tensors is not None): aux_args = aux_tensors grad_out_ssa = score_mod_bwd( grad_ssa, score_ssa, batch_idx_ssa, head_idx_ssa, q_idx=q_idx_ssa, kv_idx=kv_idx_ssa, seqlen_info=seqlen_info, aux_tensors=aux_args, ) grad_vec.store(grad_out_ssa) for j in cutlass.range(vec_size, unroll_full=True): grad_tensor[i + j] = grad_vec[j] ================================================ FILE: flash_attn/cute/testing.py 
================================================ import math from contextlib import nullcontext from functools import wraps from typing import Optional import torch import torch.nn.functional as F from einops import rearrange, repeat from torch._guards import active_fake_mode from torch._subclasses.fake_tensor import FakeTensorMode class IndexFirstAxis(torch.autograd.Function): @staticmethod def forward(ctx, input, indices): ctx.save_for_backward(indices) assert input.ndim >= 2 ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] second_dim = other_shape.numel() return torch.gather( rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim), ).reshape(-1, *other_shape) @staticmethod def backward(ctx, grad_output): (indices,) = ctx.saved_tensors assert grad_output.ndim >= 2 other_shape = grad_output.shape[1:] grad_output = rearrange(grad_output, "b ... -> b (...)") grad_input = torch.zeros( [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype, ) grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) return grad_input.reshape(ctx.first_axis_dim, *other_shape), None index_first_axis = IndexFirstAxis.apply class IndexPutFirstAxis(torch.autograd.Function): @staticmethod def forward(ctx, values, indices, first_axis_dim): ctx.save_for_backward(indices) assert indices.ndim == 1 assert values.ndim >= 2 output = torch.zeros( first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype ) output[indices] = values return output @staticmethod def backward(ctx, grad_output): (indices,) = ctx.saved_tensors grad_values = grad_output[indices] return grad_values, None, None index_put_first_axis = IndexPutFirstAxis.apply def unpad_input(hidden_states, attention_mask, unused_mask=None): all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) used_seqlens_in_batch = 
attention_mask.sum(dim=-1, dtype=torch.int32) in_fake_mode = active_fake_mode() is not None if not in_fake_mode: indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() else: # torch.nonzero and .item() are not supported in FakeTensorMode batch_size, seqlen = attention_mask.shape indices = torch.arange(batch_size * seqlen, device=hidden_states.device) max_seqlen_in_batch = seqlen cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), indices, cu_seqlens, max_seqlen_in_batch, used_seqlens_in_batch, ) def pad_input(hidden_states, indices, batch, seqlen): output = index_put_first_axis(hidden_states, indices, batch * seqlen) return rearrange(output, "(b s) ... -> b s ...", b=batch) def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): assert mode in ["full", "random", "third"] if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": lengths = torch.randint( max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device, ) else: lengths = torch.randint( max(0 if zero_lengths else 1, max_seqlen // 3), max_seqlen + 1, (batch_size, 1), device=device, ) if zero_lengths: for i in range(batch_size): if i % 5 == 0: lengths[i] = 0 lengths[-1] = 0 padding_mask = ( repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths ) return padding_mask def generate_qkv( q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False, query_unused_mask=None, key_unused_mask=None, ): assert not (kvpacked and qkvpacked) batch_size, seqlen_q, nheads, d = q.shape d_v = v.shape[-1] _, seqlen_k, nheads_k, _ = k.shape assert k.shape == (batch_size, seqlen_k, nheads_k, d) assert v.shape == (batch_size, seqlen_k, 
nheads_k, d_v) if query_unused_mask is not None or key_unused_mask is not None: assert not kvpacked assert not qkvpacked if query_padding_mask is not None: q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( q, query_padding_mask, query_unused_mask ) output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None else: q_unpad = rearrange(q, "b s h d -> (b s) h d") cu_seqlens_q = torch.arange( 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device ) seqused_q = None max_seqlen_q = seqlen_q output_pad_fn = lambda output_unpad: rearrange( output_unpad, "(b s) h d -> b s h d", b=batch_size ) qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None if key_padding_mask is not None: k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( k, key_padding_mask, key_unused_mask ) v_unpad, *_ = unpad_input(v, key_padding_mask, key_unused_mask) else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") cu_seqlens_k = torch.arange( 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device ) seqused_k = None max_seqlen_k = seqlen_k if qkvpacked: assert (query_padding_mask == key_padding_mask).all() assert nheads == nheads_k qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) qkv = torch.stack([q, k, v], dim=2) if query_padding_mask is not None: dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) else: dqkv_pad_fn = lambda dqkv_unpad: rearrange( dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size ) return ( qkv_unpad.detach().requires_grad_(), cu_seqlens_q, max_seqlen_q, qkv.detach().requires_grad_(), output_pad_fn, dqkv_pad_fn, ) elif kvpacked: kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) kv = torch.stack([k, v], dim=2) dq_pad_fn = output_pad_fn if 
key_padding_mask is not None: dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) else: dkv_pad_fn = lambda dkv_unpad: rearrange( dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size ) return ( q_unpad.detach().requires_grad_(), kv_unpad.detach().requires_grad_(), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q.detach().requires_grad_(), kv.detach().requires_grad_(), output_pad_fn, dq_pad_fn, dkv_pad_fn, ) else: dq_pad_fn = output_pad_fn if key_padding_mask is not None: dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) else: dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) return ( q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(), v_unpad.detach().requires_grad_(), qv_unpad.detach() if qv is not None else None, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, q.detach().requires_grad_(), k.detach().requires_grad_(), v.detach().requires_grad_(), qv.detach() if qv is not None else None, output_pad_fn, dq_pad_fn, dk_pad_fn, ) def construct_local_mask( seqlen_q, seqlen_k, window_size=(None, None), sink_token_length=0, query_padding_mask=None, key_padding_mask=None, key_leftpad=None, device=None, ): row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) if key_leftpad is not None: key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) sk = ( seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") ) sq = ( seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") ) if window_size[0] is None: return col_idx > row_idx + sk - sq + window_size[1] else: sk = torch.full_like(col_idx, seqlen_k) if 
key_padding_mask is None else sk if window_size[1] is None: local_mask_left = col_idx > sk else: local_mask_left = col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk) return torch.logical_or( local_mask_left, torch.logical_and( col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length ), ) def construct_chunk_mask( seqlen_q, seqlen_k, attention_chunk, query_padding_mask=None, key_padding_mask=None, key_leftpad=None, device=None, ): row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) if key_leftpad is not None: key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) sk = ( seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") ) sq = ( seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") ) sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk return torch.logical_or( col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk ) def attention_ref( q, k, v, query_padding_mask=None, key_padding_mask=None, key_leftpad=None, attn_bias=None, dropout_p=0.0, dropout_mask=None, causal=False, qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(None, None), attention_chunk=0, sink_token_length=0, learnable_sink: Optional[torch.Tensor] = None, softcap=0.0, upcast=True, reorder_ops=False, intermediate_dtype=None, ): if causal: window_size = (window_size[0], 0) dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() qv = qv.float() if qv is not None else None if q_descale is not None: q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) q = 
(q.float() * q_descale).to(q.dtype) qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None if k_descale is not None: k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) if v_descale is not None: v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) seqlen_q, seqlen_k = q.shape[1], k.shape[1] k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) d = q.shape[-1] dv = v.shape[-1] softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) if not reorder_ops: scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) else: scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) if qv is not None: scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) if softcap > 0: scores = torch.tanh(scores / softcap) * softcap if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) local_mask = None if window_size[0] is not None or window_size[1] is not None: local_mask = construct_local_mask( seqlen_q, seqlen_k, window_size, sink_token_length, query_padding_mask, key_padding_mask, key_leftpad=key_leftpad, device=q.device, ) if attention_chunk > 0: chunk_mask = construct_chunk_mask( seqlen_q, seqlen_k, attention_chunk, query_padding_mask, key_padding_mask, key_leftpad=key_leftpad, device=q.device, ) local_mask = ( torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask ) if local_mask is not None: scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias if learnable_sink is None: attention = torch.softmax(scores, dim=-1).to(v.dtype) else: scores_fp32 = scores.to(torch.float32) logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True) learnable_sink = rearrange(learnable_sink, "h -> h 1 1") logits_or_sinks_max = torch.maximum(learnable_sink, logits_max) unnormalized_scores = 
torch.exp(scores_fp32 - logits_or_sinks_max) normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp( learnable_sink - logits_or_sinks_max ) attention = (unnormalized_scores / normalizer).to(v.dtype) if query_padding_mask is not None: attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) if key_padding_mask is not None: attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) if local_mask is not None: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: attention_drop = attention if intermediate_dtype is not None: attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) def maybe_fake_tensor_mode(fake: bool = True): """ One way to populate/pre-compile cache is to use torch fake tensor mode, which does not allocate actual GPU tensors but retains tensor shape/dtype metadata for cute.compile. """ def decorator(fn): @wraps(fn) def wrapper(*args, **kwargs): with FakeTensorMode() if fake else nullcontext(): return fn(*args, **kwargs) return wrapper return decorator def is_fake_mode() -> bool: return active_fake_mode() is not None ================================================ FILE: flash_attn/cute/tile_scheduler.py ================================================ # Copyright (c) 2025, Tri Dao. 
from typing import Optional, Tuple from dataclasses import dataclass try: from typing import override except ImportError: # Python < 3.12 from typing_extensions import override import cutlass from cutlass._mlir import ir import cutlass.cute as cute from cutlass import Int32, const_expr from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase import flash_attn.cute.utils as utils from flash_attn.cute.fast_math import clz class WorkTileInfo(cutlass.utils.WorkTileInfo): """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" @override def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": assert len(values) == 5 new_tile_idx = cutlass.new_from_mlir_values(self._tile_idx, values[:-1]) new_is_valid_tile = cutlass.new_from_mlir_values(self._is_valid_tile, [values[-1]]) return WorkTileInfo(new_tile_idx, new_is_valid_tile) @dataclass class TileSchedulerArguments(ParamsBase): num_block: Int32 num_head: Int32 num_batch: Int32 num_splits: Int32 seqlen_k: Int32 headdim: Int32 headdim_v: Int32 total_q: Int32 tile_shape_mn: cutlass.Constexpr[Tuple[int, int]] cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) mCuSeqlensQ: Optional[cute.Tensor] = None mSeqUsedQ: Optional[cute.Tensor] = None qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1 element_size: cutlass.Constexpr[int] = 2 is_persistent: cutlass.Constexpr[bool] = False lpt: cutlass.Constexpr[bool] = False is_split_kv: cutlass.Constexpr[bool] = False head_swizzle: cutlass.Constexpr[bool] = False class SingleTileScheduler: @dataclass class Params(ParamsBase): num_block: Int32 num_head: Int32 num_batch: Int32 num_splits: Int32 num_splits_divmod: FastDivmodDivisor is_split_kv: cutlass.Constexpr[bool] = False cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) @staticmethod def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileScheduler.Params": return SingleTileScheduler.Params( args.num_block, 
args.num_head, args.num_batch, args.num_splits, FastDivmodDivisor(args.num_splits), args.is_split_kv, args.cluster_shape_mn, ) def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): self.params = params self._blk_coord = blk_coord self._is_first_block = True self._loc = loc self._ip = ip @staticmethod def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": # if const_expr(cute.size(params.cluster_shape_mn) == 1): # blk_coord = cute.arch.block_idx() # else: # # All CTAs in a cluster must get the same block coordinate # blk_coord = cute.arch.cluster_idx() # Temporarily set to block_idx until we sort out the best way to handle clusters blk_coord = cute.arch.block_idx() return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) # called by host @staticmethod def get_grid_shape( params: Params, *, loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" return ( cute.round_up(params.num_block, params.cluster_shape_mn[0]), params.num_head * params.num_splits, params.num_batch, ) def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: block_idx, head_idx, batch_idx = self._blk_coord if const_expr(self.params.is_split_kv): head_idx, split_idx = divmod(head_idx, self.params.num_splits_divmod) else: split_idx = Int32(0) return WorkTileInfo( (block_idx, head_idx, batch_idx, split_idx), self._is_first_block, ) def initial_work_tile_info(self, *, loc=None, ip=None): return self.get_current_work(loc=loc, ip=ip) def prefetch_next_work(self, *, loc=None, ip=None): pass def advance_to_next_work(self, *, loc=None, ip=None): self._is_first_block = False def __extract_mlir_values__(self): values,
self._values_pos = [], []
        for obj in [self.params, self._blk_coord]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip([self.params, self._blk_coord], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return SingleTileScheduler(*(tuple(obj_list)), loc=self._loc)


class StaticPersistentTileScheduler:
    @dataclass
    class Params(ParamsBase):
        num_block_cluster_divmod: FastDivmodDivisor
        num_head_divmod: FastDivmodDivisor
        total_blocks_cluster: Int32
        cluster_shape_m: cutlass.Constexpr[int] = 1

        @staticmethod
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "StaticPersistentTileScheduler.Params":
            num_block_cluster = cute.ceil_div(args.num_block, cute.size(args.cluster_shape_mn))
            total_blocks_cluster = num_block_cluster * args.num_head * args.num_batch
            return StaticPersistentTileScheduler.Params(
                FastDivmodDivisor(num_block_cluster),
                FastDivmodDivisor(args.num_head),
                total_blocks_cluster,
                cluster_shape_m=args.cluster_shape_mn[0],
            )

    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler":
        if const_expr(cute.size(params.cluster_shape_m) == 1):
            tile_idx = cute.arch.block_idx()[0]
        else:
            tile_idx = cute.arch.cluster_idx()[0]
        return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        hardware_info = cutlass.utils.HardwareInfo()
        sm_count = hardware_info.get_device_multiprocessor_count()
        # Grid must be a multiple of cluster_shape_m for CUDA cluster launch.
        max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m
        grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m)
        return (grid_x, Int32(1), Int32(1))

    # @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod)
        batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod)
        is_valid = self._tile_idx < self.params.total_blocks_cluster
        # if cute.arch.thread_idx()[0] == 0:
        #     cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid)
        return WorkTileInfo(
            (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid
        )

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        if const_expr(self.params.cluster_shape_m == 1):
            self._tile_idx += cute.arch.grid_dim()[0]
        else:
            self._tile_idx += cute.arch.cluster_dim()[0]

    def __extract_mlir_values__(self):
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip(
            [self.params, self._tile_idx],
            self._values_pos,
        ):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return StaticPersistentTileScheduler(*(tuple(obj_list)), loc=self._loc)


class SingleTileLPTScheduler:
    @dataclass
    class Params(ParamsBase):
        total_blocks: Int32
        num_splits: Int32
        num_block: Int32
        l2_minor: Int32
        num_block_divmod: FastDivmodDivisor
        num_head_divmod: FastDivmodDivisor
        l2_minor_divmod: FastDivmodDivisor
        l2_major_divmod: FastDivmodDivisor
        l2_minor_residual_divmod: FastDivmodDivisor
        num_hb_quotient: Int32
        is_split_kv: cutlass.Constexpr[bool] = False

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileLPTScheduler.Params":
            # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size)
            size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
            size_one_head = size_one_kv_head
            size_l2 = 50 * 1024 * 1024  # 50 MB for K & V
            # Swizzle is the size of each "section". Round swizzle to a power of 2
            # Need to be careful about the case where only one head will fit
            # swizzle is how many heads can fit in L2
            # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
            # Seems faster if swizzle is a power of 2
            log2_floor = lambda n: 31 - clz(n)
            swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
            # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head)
            # If we're in the last section (called residual), we don't want to divide by
            # swizzle. Instead we want to divide by the remainder.
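            # Worked example (hypothetical shapes, for illustration only): seqlen_k = 8192,
            # headdim = headdim_v = 128 and element_size = 2 (bf16) give
            # size_one_head = 8192 * 256 * 2 = 4 MiB, so size_l2 // size_one_head = 12 and
            # swizzle = 8 (the largest power of 2 <= 12). With num_head * num_batch = 20,
            # num_hb_quotient = 20 // 8 = 2 full sections and num_hb_remainder = 20 % 8 = 4
            # heads fall into the residual section.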
            num_hb_quotient = (args.num_head * args.num_batch) // swizzle
            num_hb_remainder = (args.num_head * args.num_batch) % swizzle
            return SingleTileLPTScheduler.Params(
                total_blocks=args.num_block * args.num_head * args.num_batch,
                num_block=args.num_block,
                l2_minor=Int32(swizzle),
                num_block_divmod=FastDivmodDivisor(args.num_block),
                num_head_divmod=FastDivmodDivisor(args.num_head),
                l2_minor_divmod=FastDivmodDivisor(swizzle),
                l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block),
                l2_minor_residual_divmod=FastDivmodDivisor(
                    max(num_hb_remainder, 1)
                ),  # don't divide by 0
                num_hb_quotient=Int32(num_hb_quotient),
                num_splits=args.num_splits,
                is_split_kv=args.is_split_kv,
            )

    def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._split_idx = split_idx
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler":
        tile_idx, split_idx, _ = cute.arch.block_idx()
        return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        return (params.total_blocks, params.num_splits, Int32(1))

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        params = self.params
        # Implement LPT scheduling coordinate calculation
        bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod)
        # If we're in the last section (called residual), we don't want to divide by
        # swizzle. Instead we want to divide by the remainder.
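        # Worked example (hypothetical values): swizzle = 8, num_block = 4 and
        # num_head * num_batch = 20, so num_hb_quotient = 2 and l2_major = 8 * 4 = 32.
        # tile_idx = 70 gives bidhb, l2_mod = divmod(70, 32) = (2, 6). bidhb is not
        # < num_hb_quotient, so the tile lies in the residual section of 20 % 8 = 4 heads:
        # block, bidhb_residual = divmod(6, 4) = (1, 2) and bidhb_actual = 2 * 8 + 2 = 18.
        # The LPT reversal below then maps block 1 to num_block - 1 - 1 = 2.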
        block, bidhb_residual = 0, 0
        if bidhb < params.num_hb_quotient:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
        else:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
        bidhb_actual = bidhb * params.l2_minor + bidhb_residual
        batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
        # Longest-processing-time-first
        block = params.num_block - 1 - block
        is_valid = self._tile_idx < params.total_blocks
        return WorkTileInfo(
            (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid
        )

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._tile_idx = self.params.total_blocks

    def __extract_mlir_values__(self):
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx, self._split_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*(tuple(obj_list)), loc=self._loc)


class SingleTileLPTBwdScheduler:
    @dataclass
    class Params(ParamsBase):
        total_blocks: Int32
        num_block: Int32
        l2_minor: Int32
        num_head_divmod: FastDivmodDivisor
        l2_minor_divmod: FastDivmodDivisor
        l2_major_divmod: FastDivmodDivisor
        l2_minor_residual_divmod: FastDivmodDivisor
        num_hb_quotient: Int32
        cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
        spt: cutlass.Constexpr[bool] = True

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileLPTBwdScheduler.Params":
            size_l2 = 50 * 1024 * 1024
            size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
            # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
            size_one_dqaccum_head = 0
            size_one_head = size_one_qdo_head + size_one_dqaccum_head
            log2_floor = lambda n: 31 - clz(n)
            swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
            # swizzle = 8
            # If we're in the last section (called residual), we don't want to divide by
            # swizzle. Instead we want to divide by the remainder.
            num_hb_quotient = (args.num_head * args.num_batch) // swizzle
            num_hb_remainder = (args.num_head * args.num_batch) % swizzle
            num_block = cute.ceil_div(args.num_block, args.cluster_shape_mn[0])
            return SingleTileLPTBwdScheduler.Params(
                total_blocks=(num_block * args.cluster_shape_mn[0]) * args.num_head * args.num_batch,
                num_block=num_block,
                l2_minor=Int32(swizzle),
                num_head_divmod=FastDivmodDivisor(args.num_head),
                l2_minor_divmod=FastDivmodDivisor(swizzle),
                l2_major_divmod=FastDivmodDivisor(swizzle * num_block),
                l2_minor_residual_divmod=FastDivmodDivisor(
                    max(num_hb_remainder, 1)
                ),  # don't divide by 0
                num_hb_quotient=Int32(num_hb_quotient),
                cluster_shape_mn=args.cluster_shape_mn,
                spt=args.lpt,
            )

    def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    @cute.jit
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTBwdScheduler":
        tile_idx = cute.arch.block_idx()[0]
        return SingleTileLPTBwdScheduler(params, tile_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        return (params.total_blocks, Int32(1), Int32(1))

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo:
        cluster_idx = self._tile_idx // self.params.cluster_shape_mn[0]
        params = self.params
        # Implement LPT scheduling coordinate calculation
        bidhb, l2_mod = divmod(cluster_idx, params.l2_major_divmod)
        # If we're in the last section (called residual), we don't want to divide by
        # swizzle. Instead we want to divide by the remainder.
        block, bidhb_residual = 0, 0
        if bidhb < params.num_hb_quotient:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_divmod)
        else:
            block, bidhb_residual = divmod(l2_mod, params.l2_minor_residual_divmod)
        bidhb_actual = bidhb * params.l2_minor + bidhb_residual
        batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod)
        if cutlass.const_expr(params.spt):
            block = params.num_block - 1 - block
        if cutlass.const_expr(params.cluster_shape_mn[0] > 1):
            bidx_in_cluster = cute.arch.block_in_cluster_idx()
            block = block * params.cluster_shape_mn[0] + bidx_in_cluster[0]
        is_valid = self._tile_idx < params.total_blocks
        return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Single tile scheduler - set to invalid tile_idx to indicate no more work
        self._tile_idx = self.params.total_blocks

    def __extract_mlir_values__(self):
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip([self.params, self._tile_idx], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*(tuple(obj_list)), loc=self._loc)


class SingleTileVarlenScheduler:
    @dataclass
    class Params(ParamsBase):
        num_head: Int32
        num_batch: Int32
        total_q: Int32
        num_splits: Int32
        max_kvblock_in_l2: Int32
        tile_shape_mn: cutlass.Constexpr[Tuple[int, int]]
        mCuSeqlensQ: Optional[cute.Tensor] = None
        mSeqUsedQ: Optional[cute.Tensor] = None
        qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
        lpt: cutlass.Constexpr[bool] = False
        is_split_kv: cutlass.Constexpr[bool] = False
        head_swizzle: cutlass.Constexpr[bool] = False
        cluster_shape_m: cutlass.Constexpr[int] = 1

        @staticmethod
        @cute.jit
        def create(
            args: TileSchedulerArguments, *, loc=None, ip=None
        ) -> "SingleTileVarlenScheduler.Params":
            size_l2 = 50 * 1024 * 1024  # 50 MB for K & V
            max_kvblock_in_l2 = size_l2 // (
                (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
            )
            assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
                "At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
            )
            assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
            return SingleTileVarlenScheduler.Params(
                num_head=args.num_head,
                num_batch=args.num_batch,
                total_q=args.total_q,
                num_splits=args.num_splits,
                max_kvblock_in_l2=max_kvblock_in_l2,
                tile_shape_mn=args.tile_shape_mn,
                mCuSeqlensQ=args.mCuSeqlensQ,
                mSeqUsedQ=args.mSeqUsedQ,
                qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa,
                lpt=args.lpt,
                is_split_kv=args.is_split_kv,
                head_swizzle=args.head_swizzle,
                cluster_shape_m=args.cluster_shape_mn[0],
            )

    def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
        self.params = params
        self._tile_idx = tile_idx
        self._split_idx = split_idx
        self._is_first_block = True
        self._loc = loc
        self._ip = ip

    @staticmethod
    def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params:
        return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip)

    @staticmethod
    def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler":
        tile_idx, split_idx, _ = cute.arch.block_idx()
        return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip)

    # called by host
    @staticmethod
    def get_grid_shape(
        params: Params,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32, Int32]:
        total_blocks_max = (
            params.total_q + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1)
        ) // params.tile_shape_mn[0]
        # round down to nearest multiple of cluster since odd excess is always padding
        total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m
        return (total_blocks_max * params.num_head, params.num_splits, Int32(1))

    @cute.jit
    def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32:
        params = self.params
        batch_idx = lane + bidb_start
        if cutlass.const_expr(params.mSeqUsedQ is not None):
            seqlen = Int32(0)
            if batch_idx < params.num_batch:
                seqlen = params.mSeqUsedQ[batch_idx]
        else:
            assert params.mCuSeqlensQ is not None
            cur_cu_seqlen = Int32(0)
            if batch_idx <= params.num_batch:
                cur_cu_seqlen = params.mCuSeqlensQ[batch_idx]
            next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1)
            seqlen = next_cu_seqlen - cur_cu_seqlen
        if cutlass.const_expr(params.qhead_per_kvhead_packgqa > 1):
            seqlen *= params.qhead_per_kvhead_packgqa
        return (
            cute.ceil_div(cute.ceil_div(seqlen, params.tile_shape_mn[0]), params.cluster_shape_m)
            if batch_idx < params.num_batch and lane < cute.arch.WARP_SIZE - 1
            else Int32(0)
        )

    @cute.jit
    def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        params = self.params
        lane_idx = cute.arch.lane_idx()
        num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0)
        num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
        # Total number of blocks for the next 31 batches
        m_blocks_in_group = cute.arch.shuffle_sync(num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1)
        # Same for all lanes
        group_end_tile = m_blocks_in_group * params.num_head
        # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d", self._tile_idx, group_end_tile, num_m_blocks, num_m_blocks_cumulative, m_blocks_in_group)
        block, head_idx, batch_idx = Int32(0), Int32(0), Int32(0)
        next_tile_idx = self._tile_idx // params.cluster_shape_m
        while group_end_tile <= next_tile_idx:
            batch_idx += cute.arch.WARP_SIZE - 1
            if batch_idx >= params.num_batch:
                batch_idx = Int32(params.num_batch)
                group_end_tile = next_tile_idx + 1
            else:
                num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=batch_idx)
                num_m_blocks_cumulative = utils.warp_prefix_sum(num_m_blocks, lane_idx)
                m_blocks_in_group = cute.arch.shuffle_sync(
                    num_m_blocks_cumulative, cute.arch.WARP_SIZE - 1
                )
                group_end_tile += m_blocks_in_group * params.num_head
        is_valid = False
        if batch_idx >= params.num_batch:
            block, head_idx, batch_idx = Int32(0), Int32(0), Int32(params.num_batch)
        else:
            group_start_tile = group_end_tile - m_blocks_in_group * params.num_head
            # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, group_end_tile = %d, num_m_blocks=%d, batch_idx = %d", self._tile_idx, group_end_tile, num_m_blocks, batch_idx)
            # The batch containing this tile is the first one whose ending tile position is
            # greater than the tile index; the ballot counts the batches that end at or before it.
            batch_idx_in_group = cute.arch.popc(
                cute.arch.vote_ballot_sync(
                    group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx
                )
            )
            batch_idx += batch_idx_in_group
            num_m_blocks_prev_lane = (
                0
                if batch_idx_in_group == 0
                else cute.arch.shuffle_sync(num_m_blocks_cumulative, batch_idx_in_group - 1)
            )
            num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group)
            mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head
            if cutlass.const_expr(params.lpt or params.head_swizzle):
                # This is a version of the SingleTileLPTScheduler, complicated by the fact that
                # the seqlen can vary per batch.
                # TODO: is there any case where num_m_blocks is 0?
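                # Worked example of the section arithmetic below (hypothetical values):
                # tile_shape_mn = (128, 128), qhead_per_kvhead_packgqa = 1,
                # headdim = headdim_v = 128 and element_size = 2 give
                # max_kvblock_in_l2 = 50 MiB // (256 * 2 * 128) = 800. For a batch with
                # num_m_blocks = 64, num_n_blocks = 64; 64 * 16 > 800 but 64 * 8 <= 800, so
                # nheads_in_l2 = 8. With num_head = 12, heads 0-7 form the first section and
                # the tail section holds the remaining 12 - 8 = 4 heads.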
                # TODO: by right we should read the seqlen_kv but we're assuming seqlen_q == seqlen_k here
                num_n_blocks = (
                    num_m_blocks
                    * params.tile_shape_mn[0]
                    // params.qhead_per_kvhead_packgqa
                    // params.tile_shape_mn[1]
                )
                # nheads_in_l2 = min(max(self.max_kvblock_in_l2 // num_n_blocks, 1), self.num_head)
                # Seems faster to have this be a power of 2
                nheads_in_l2 = (
                    16
                    if num_n_blocks * 16 <= params.max_kvblock_in_l2
                    else (
                        8
                        if num_n_blocks * 8 <= params.max_kvblock_in_l2
                        else (
                            4
                            if num_n_blocks * 4 <= params.max_kvblock_in_l2
                            else (2 if num_n_blocks * 2 <= params.max_kvblock_in_l2 else 1)
                        )
                    )
                )
                nheads_in_l2 = min(nheads_in_l2, params.num_head)
                mh_in_l2 = nheads_in_l2 * num_m_blocks
                section_idx = mh_block // mh_in_l2
                l2_mod = mh_block - section_idx * mh_in_l2
                # Deal with tail section
                nheads_in_this_section = (
                    nheads_in_l2
                    if nheads_in_l2 * (section_idx + 1) <= params.num_head
                    else params.num_head - section_idx * nheads_in_l2
                )
                block = l2_mod // nheads_in_this_section
                head_idx_residual = l2_mod - block * nheads_in_this_section
                head_idx = section_idx * nheads_in_l2 + head_idx_residual
                if cutlass.const_expr(params.lpt):
                    block = num_m_blocks - 1 - block
            else:
                head_idx = mh_block // num_m_blocks
                block = mh_block - head_idx * num_m_blocks
            is_valid = self._is_first_block and batch_idx < params.num_batch
            if cutlass.const_expr(params.cluster_shape_m > 1):
                bidx_in_cluster = cute.arch.block_in_cluster_idx()
                block = block * params.cluster_shape_m + bidx_in_cluster[0]
        # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, head_idx=%d, block=%d, is_valid = %d", self._tile_idx, batch_idx, head_idx, block, is_valid)
        split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0)
        return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid)

    def initial_work_tile_info(self, *, loc=None, ip=None):
        return self.get_current_work(loc=loc, ip=ip)

    def prefetch_next_work(self, *, loc=None, ip=None):
        pass

    def advance_to_next_work(self, *, loc=None, ip=None):
        # Single tile scheduler: mark the only tile as consumed so the next get_current_work is invalid
        self._is_first_block = False

    def __extract_mlir_values__(self):
        values, self._values_pos = [], []
        for obj in [self.params, self._tile_idx, self._split_idx]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip(
            [self.params, self._tile_idx, self._split_idx],
            self._values_pos,
        ):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc)

================================================
FILE: flash_attn/cute/utils.py
================================================
# Copyright (c) 2025, Tri Dao.

import math
import hashlib
import inspect
from typing import Type, Callable, Optional, Tuple, overload

import cutlass
import cutlass.cute as cute

from cutlass import Float32, const_expr
from cutlass.cute import FastDivmodDivisor
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import nvvm, llvm
from cutlass.cute.runtime import from_dlpack

import quack.activation

_MIXER_ATTRS = ("__vec_size__",)

# Obtained from sollya:
# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative);
POLY_EX2 = {
    0: (1.0),
    1: (
        1.0,
        0.922497093677520751953125,
    ),
    2: (
        1.0,
        0.6657850742340087890625,
        0.330107033252716064453125,
    ),
    3: (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    ),
    4: (
        1.0,
        0.693042695522308349609375,
        0.2412912547588348388671875,
        5.2225358784198760986328125e-2,
        1.3434938155114650726318359375e-2,
    ),
    5: (
        1.0,
        0.693151414394378662109375,
        0.24016360938549041748046875,
        5.5802188813686370849609375e-2,
        9.01452265679836273193359375e-3,
        1.86810153536498546600341796875e-3,
    ),
}


def _compute_base_hash(func: Callable) -> str:
    """Compute hash from source code or bytecode and closure values."""
    try:
        data = inspect.getsource(func).encode()
    except (OSError, TypeError):
        if hasattr(func, "__code__") and func.__code__ is not None:
            data = func.__code__.co_code
        else:
            data = repr(func).encode()
    hasher = hashlib.sha256(data)
    if hasattr(func, "__closure__") and func.__closure__ is not None:
        for cell in func.__closure__:
            hasher.update(repr(cell.cell_contents).encode())
    return hasher.hexdigest()


def hash_callable(
    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True
) -> str:
    """Hash a callable based on the source code or bytecode and closure values.

    Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__``
    attribute, that value is used directly as the base hash; the metadata dunders
    listed in ``mixer_attrs`` are then mixed in to produce the final dict-key hash.

    set_cute_hash: whether to cache the computed base hash on the (unwrapped)
    function as ``__cute_hash__``
    """
    # Resolve base hash
    if hasattr(func, "__cute_hash__"):
        base_hash = func.__cute_hash__
    else:
        # Unwrap decorated functions (e.g., cute.jit wrappers).
        base_func = getattr(func, "__wrapped__", func)
        if hasattr(base_func, "__cute_hash__"):
            base_hash = base_func.__cute_hash__
        else:
            base_hash = _compute_base_hash(base_func)
            if set_cute_hash:
                base_func.__cute_hash__ = base_hash

    # Mix in mutable metadata dunders
    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)

    if all(v is None for v in mixer_values):
        return base_hash

    hasher = hashlib.sha256(base_hash.encode())

    # Pair each value with the attribute name it was read from (mixer_attrs, not the default)
    for attr, val in zip(mixer_attrs, mixer_values):
        hasher.update(f"{attr}={val!r}".encode())

    return hasher.hexdigest()


def create_softcap_scoremod(softcap_val):
    inv_softcap = 1.0 / softcap_val

    @cute.jit
    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        scores = acc_S_SSA * inv_softcap
        return scores * cute.math.tanh(scores, fastmath=True)

    return scoremod_premask_fn


LOG2_E = math.log2(math.e)


def compute_softmax_scale_log2(softmax_scale, score_mod):
    """Compute softmax_scale_log2 and adjusted softmax_scale based on whether score_mod is used.

    When score_mod is None, fold the log2(e) factor into softmax_scale_log2 and set softmax_scale to None.
    When score_mod is present, keep softmax_scale separate so it can be applied before the score_mod,
    and set softmax_scale_log2 to just the change-of-base constant.

    Returns (softmax_scale_log2, softmax_scale).
    """
    if const_expr(score_mod is None):
        return softmax_scale * LOG2_E, None
    else:
        return LOG2_E, softmax_scale


def compute_fastdiv_mods(mQ, mK, qhead_per_kvhead, pack_gqa, aux_tensors, mPageTable=None):
    """Compute FastDivmodDivisor pairs for aux_tensors index computation.

    Returns a (seqlen_q_divmod, seqlen_k_divmod) tuple, or None if aux_tensors is None.
    """
    if const_expr(aux_tensors is None):
        return None
    seqlen_q = cute.size(mQ.shape[0]) // (qhead_per_kvhead if const_expr(pack_gqa) else 1)
    seqlen_k = (
        cute.size(mK.shape[0])
        if const_expr(mPageTable is None)
        else mK.shape[0] * mPageTable.shape[1]
    )
    return (FastDivmodDivisor(seqlen_q), FastDivmodDivisor(seqlen_k))


def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
    return (
        from_dlpack(x, assumed_align=alignment)
        .mark_layout_dynamic(leading_dim=leading_dim)
        .mark_compact_shape_dynamic(
            mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility
        )
    )


def convert_from_dlpack_leading_static(
    x, leading_dim, alignment=16, static_modes=None, stride_order=None
) -> cute.Tensor:
    if stride_order is None:
        stride_order = x.dim_order()
    x_ = from_dlpack(x, assumed_align=alignment)
    for i in range(x.ndim):
        if i != leading_dim and (static_modes is None or i not in static_modes):
            x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order)
    return x_


def make_tiled_copy_A(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    if const_expr(swapAB):
        return cute.make_tiled_copy_B(copy_atom, tiled_mma)
    else:
        return cute.make_tiled_copy_A(copy_atom, tiled_mma)


def make_tiled_copy_B(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    if const_expr(swapAB):
        return cute.make_tiled_copy_A(copy_atom, tiled_mma)
    else:
        return cute.make_tiled_copy_B(copy_atom, tiled_mma)


def mma_make_fragment_A(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    if const_expr(swapAB):
        return mma_make_fragment_B(smem, thr_mma)
    else:
        return thr_mma.make_fragment_A(thr_mma.partition_A(smem))


def mma_make_fragment_B(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    if const_expr(swapAB):
        return mma_make_fragment_A(smem, thr_mma)
    else:
        return thr_mma.make_fragment_B(thr_mma.partition_B(smem))


def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    if const_expr(arch < 90 or element_type.width != 16):
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            num_bits_per_copy=2 * element_type.width,
        )
    else:
        return cute.make_copy_atom(
            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )


@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    if const_expr(isinstance(val, cute.TensorSSA)):
        res = cute.make_fragment(val.shape, val.dtype)
        res.store(val)
        for i in cutlass.range_constexpr(cute.size(val.shape)):
            res[i] = warp_reduce(res[i], op, width)
        return res.load()
    else:
        for i in cutlass.range_constexpr(int(math.log2(width))):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
    return val


@dsl_user_op
def fmax(
    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
) -> Float32:
    from cutlass import CUDA_VERSION

    # Select the NVVM call signature based on the CUDA (NVVM) version
    if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:
        # Old API: requires explicit result type as first positional argument
        return Float32(
            nvvm.fmax(
                T.f32(),
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
                loc=loc,
                ip=ip,
            )
        )
    else:
        # New API: infers result type automatically
        return Float32(
            nvvm.fmax(
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
                loc=loc,
                ip=ip,
            )
        )


@cute.jit
def fmax_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        # if const_expr(init_val is None):
        #     init_val = -cutlass.Float32.inf
        # return x.reduce(cute.ReductionOp.MAX, init_val, 0)
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # local_max = [res[0], res[1]]
        # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2):
        #     local_max[0] = fmax(local_max[0], res[i + 0])
        #     local_max[1] = fmax(local_max[1], res[i + 1])
        # local_max[0] = fmax(local_max[0], local_max[1])
        # return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
        local_max = [res[0], res[1], res[2], res[3]]
        for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
            local_max[0] = fmax(local_max[0], res[i + 0])
            local_max[1] = fmax(local_max[1], res[i + 1])
            local_max[2] = fmax(local_max[2], res[i + 2])
            local_max[3] = fmax(local_max[3], res[i + 3])
        local_max[0] = fmax(local_max[0], local_max[1])
        local_max[2] = fmax(local_max[2], local_max[3])
        local_max[0] = fmax(local_max[0], local_max[2])
        return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
    else:
        # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max
        # We instead force the 3-input max.
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        local_max_0 = (
            fmax(init_val, res[0], res[1])
            if const_expr(init_val is not None)
            else fmax(res[0], res[1])
        )
        local_max = [
            local_max_0,
            fmax(res[2], res[3]),
            fmax(res[4], res[5]),
            fmax(res[6], res[7]),
        ]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_max[0] = fmax(local_max[0], res[i], res[i + 1])
            local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])
            local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])
            local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])
        local_max[0] = fmax(local_max[0], local_max[1])
        return fmax(local_max[0], local_max[2], local_max[3])


@cute.jit
def fadd_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        if const_expr(init_val is None):
            init_val = Float32.zero
        return x.reduce(cute.ReductionOp.ADD, init_val, 0)
        # res = cute.make_fragment(x.shape, Float32)
        # res.store(x)
        # local_sum = [res[0], res[1], res[2], res[3]]
        # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
        #     local_sum[0] += res[i + 0]
        #     local_sum[1] += res[i + 1]
        #     local_sum[2] += res[i + 2]
        #     local_sum[3] += res[i + 3]
        # local_sum[0] += local_sum[1]
        # local_sum[2] += local_sum[3]
        # local_sum[0] += local_sum[2]
        # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val
    else:
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        local_sum_0 = (
            cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1]))
            # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1]))
            if const_expr(init_val is not None)
            else (res[0], res[1])
        )
        local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))
            local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))
            local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))
            local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])
        local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])
        return local_sum[0][0] + local_sum[0][1]


@dsl_user_op
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    # # cache_hint = cutlass.Int64(0x12F0000000000000)
    # llvm.inline_asm(
    #     None,
    #     [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)],
    #     # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()],
    #     "red.global.add.f32 [$0], $1;",
    #     # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;",
    #     # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;",
    #     "l,f",
    #     # "l,f,l",
    #     has_side_effects=True,
    #     is_align_stack=False,
    #     asm_dialect=llvm.AsmDialect.AD_ATT,
    # )
    nvvm.atomicrmw(
        res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value()
    )


@dsl_user_op
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)


@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_fragment(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        cutlass.Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA


def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
    warp_group_idx = cute.arch.thread_idx()[0] // 128
    if const_expr(sync):
        warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx)
    return warp_group_idx


# @dsl_user_op
# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean:
#     mask = cutlass.Int32(-1)
#     return cutlass.Boolean(
#         llvm.inline_asm(
#             T.i32(),
#             [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)],
#             ".pred p1, p2;\n"
#             "setp.lt.f32 p1, $1, $2;\n"
#             "vote.sync.any.pred p2, p1, $3;\n"
#             "selp.u32 $0, 1, 0, p2;",
#             # "selp.u32 $0, 1, 0, p1;",
#             "=r,f,f,r",
#             has_side_effects=False,
#             is_align_stack=False,
#             asm_dialect=llvm.AsmDialect.AD_ATT,
#         )
#     )


@cute.jit
def shuffle_sync(
    value: cute.Numeric,
    offset: cute.typing.Int,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Numeric:
    assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp
    # important: need stride 1 and not 0 for recast_tensor to work
    val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
    val[0] = value
    val_i32 = cute.recast_tensor(val, cutlass.Int32)
    for i in cutlass.range_constexpr(cute.size(val_i32)):
        val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp)
    return val[0]
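

# Illustrative sketch, not part of the library: hypothetical pure-Python helpers that model
# the warp-level lane arithmetic used above, assuming a 32-lane warp. They run on the host
# only and are meant for checking index math, not for use inside kernels.
# ``_sketch_mask_and_clamp`` mirrors how ``shuffle_sync`` packs the segment mask (bits 12:8)
# and the clamp lane (bits 4:0) into the c operand of PTX ``shfl.sync``.
# ``_sketch_butterfly_reduce`` mirrors the scalar path of ``warp_reduce``: log2(width) rounds
# of XOR (butterfly) exchanges, after which every lane holds op() over its width-sized segment.
def _sketch_mask_and_clamp(width, warp_size=32):
    # width is assumed to be a power of 2 in [1, warp_size]
    mask = warp_size - width  # lanes that agree on these bits share a segment
    clamp = warp_size - 1
    return (mask << 8) | clamp


def _sketch_butterfly_reduce(lane_vals, op, width=32):
    vals = list(lane_vals)
    offset = 1
    while offset < width:
        # Each lane reads the value held by lane ^ offset (like shfl.sync.bfly) and combines it.
        # Because offset < width, lane ^ offset stays inside the same width-aligned segment.
        vals = [op(vals[lane], vals[lane ^ offset]) for lane in range(len(vals))]
        offset <<= 1
    return vals


# Usage: a full-warp sum, and an 8-lane max where each group of 8 lanes reduces on its own.
# _sketch_butterfly_reduce(range(32), lambda a, b: a + b)[0] == 496
# _sketch_butterfly_reduce(range(32), max, width=8)[8] == 15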


@dsl_user_op
def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """
    Left-shift val by shift bits using PTX shl.b32 (sign-agnostic).

    Named ``shl_u32`` (not ``shl_b32``) because Python type annotations
    distinguish signed/unsigned.

    PTX semantics (§9.7.8.8): "Shift amounts greater than the register width
    N are clamped to N." So ``shl.b32 d, a, 32`` is well-defined and yields 0.

    This differs from C/C++ and LLVM IR, where shifting by >= the type width
    is undefined behavior. CuTeDSL compiles through MLIR -> LLVM IR, so a
    plain Python-level ``Uint32(x) << Uint32(n)`` inherits LLVM's UB: the
    optimizer may treat the result as poison and eliminate dependent code.

    Inline PTX bypasses the LLVM IR shift entirely — the instruction is
    emitted verbatim into PTX where clamping makes it safe for all shift
    amounts.
    """
    return cutlass.Uint32(
        llvm.inline_asm(
            T.i32(),
            [
                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
            ],
            "shl.b32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@dsl_user_op
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """
    Unsigned right-shift val by shift bits using PTX shr.u32 (zero-fills).

    See ``shl_u32`` docstring for why inline PTX is used instead of plain
    CuTeDSL shift operators (LLVM shift-by-type-width UB).
    """
    return cutlass.Uint32(
        llvm.inline_asm(
            T.i32(),
            [
                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
            ],
            "shr.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@cute.jit
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
    if const_expr(lane is None):
        lane = cute.arch.lane_idx()
    # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0:
    #     cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val)
    for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
        offset = 1 << i
        # Very important that we set mask_and_clamp to 0: for shfl.up the clamp field is a
        # lower bound on the source lane, so 0 lets lane i read lane i - offset. A clamp of
        # 31 (as used for idx shuffles) would make every lane read its own value instead.
        partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
        if lane >= offset:
            val += partial_sum
        # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0:
        #     cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val)
    return val


@dsl_user_op
def cvt_f16x2_f32(
    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
    return cutlass.Int32(
        llvm.inline_asm(
            T.i32(),
            [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)],
            # cvt.rn.{bf16x2,f16x2}.f32 packs its first source into the upper half of the
            # result, so b ($2) is passed first and a ($1) lands in the low 16 bits.
            f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;",
            "=r,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...


@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...


@cute.jit
def cvt_f16(src: cute.Tensor, dst_or_dtype):
    """Convert Float32 tensor to Float16/BFloat16.

    Args:
        src: Source tensor with Float32 element type
        dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16)

    Returns:
        None if dst is a tensor, or a new tensor if dtype is provided
    """
    if const_expr(isinstance(dst_or_dtype, type)):
        # dtype variant: create new tensor and call the tensor variant
        dtype = dst_or_dtype
        dst = cute.make_fragment(src.shape, dtype)
        cvt_f16(src, dst)
        return dst
    else:
        # tensor variant: write to dst
        dst = dst_or_dtype
        assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
        assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
        assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
            "dst must be BFloat16 or Float16"
        )
        assert src.element_type is Float32, "src must be Float32"
        dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
        assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
        for i in cutlass.range_constexpr(cute.size(dst_i32)):
            dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)


@dsl_user_op
@cute.jit
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
    deg = len(poly) - 1
    out = poly[deg]
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = out * x + poly[i]
    return out


@dsl_user_op
@cute.jit
def evaluate_polynomial_2(
    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    deg = len(poly) - 1
    out = (poly[deg], poly[deg])
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i]))
    return out


@dsl_user_op
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
    # There's probably a way to call llvm or nvvm to do this instead of ptx
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
            "add.rm.ftz.f32 $0, $1, $2;",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@dsl_user_op
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [
                Float32(x_rounded).ir_value(loc=loc, ip=ip),
                Float32(frac_ex2).ir_value(loc=loc, ip=ip),
            ],
            "{\n\t"
            ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
            "mov.b32 x_rounded_i, $1;\n\t"
            "mov.b32 frac_ex_i, $2;\n\t"
            "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
            # add.u32 generates IMAD instruction and add.s32 generates LEA instruction
            # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
            "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
            "mov.b32 $0, out_i;\n\t"
            "}\n",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@dsl_user_op
def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:
    assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
    # We assume x <= 127.0
    fp32_round_int = float(2**23 + 2**22)
    x_clamped = cute.arch.fmax(x, -127.0)
    # We want to round down here, so that the fractional part is in [0, 1)
    x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip)
    # The integer floor of x is now in the last 8 bits of x_rounded
    # We assume the next 2 ops round to nearest even. The rounding mode is important.
    x_rounded_back = x_rounded - fp32_round_int
    x_frac = x_clamped - x_rounded_back
    x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip)


# TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version
@dsl_user_op
def ex2_emulation_2(
    x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    # We assume x <= 127.0 and y <= 127.0
    fp32_round_int = float(2**23 + 2**22)
    xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
    # We want to round down here, so that the fractional part is in [0, 1)
    xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm")
    # The integer floor of x & y are now in the last 8 bits of xy_rounded
    # We want the next 2 ops to round to nearest even. The rounding mode is important.
    xy_rounded_back = quack.activation.sub_packed_f32x2(
        xy_rounded, (fp32_round_int, fp32_round_int)
    )
    xy_frac = quack.activation.sub_packed_f32x2(xy_clamped, xy_rounded_back)
    xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip)
    y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip)
    return x_out, y_out


@dsl_user_op
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    out_f32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()],
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b32 $0, r7;\n\t"
        "mov.b32 $1, r8;\n\t"
        "}\n",
        "=r,=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
    return out0, out1


@dsl_user_op
def domain_offset_aligned(
    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    assert isinstance(tensor.iterator, cute.Pointer)
    # We assume that applying the offset does not change the pointer alignment
    new_ptr = cute.make_ptr(
        tensor.element_type,
        elem_pointer(tensor, coord).toint(),
        tensor.memspace,
        assumed_align=tensor.iterator.alignment,
    )
    return cute.make_tensor(new_ptr, tensor.layout)


@cute.jit
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
    """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype"""
    vec = cute.make_fragment(1, dtype)
    vec[0] = a
    return vec.load()


def ssa_to_scalar(val):
    """Could inline but nice for reflecting the above api"""
    return val[0]

================================================
FILE: flash_attn/flash_attn_interface.py
================================================
# Copyright (c) 2023, Tri Dao.
from typing import Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import os
import warnings

# isort: off
# We need to import the CUDA kernels after importing torch
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if not USE_TRITON_ROCM and getattr(torch.version, 'hip', None) is not None:
    try:
        import flash_attn_2_cuda
    except ImportError:
        warnings.warn("flash_attn_2_cuda (which has ROCm/HIP kernels) not found, falling back to Triton implementation")
        USE_TRITON_ROCM = True
if USE_TRITON_ROCM:
    from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu
else:
    import flash_attn_2_cuda as flash_attn_gpu

# isort: on


def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _get_block_size_n(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128
    if head_dim <= 64:
        return 128 if not is_dropout else 64
    elif head_dim <= 96:
        return 64
    elif head_dim <= 128:
        if is_sm8x:
            return 64 if (not is_dropout and is_causal) else 32
        else:
            return 64 if not is_dropout else 32
    elif head_dim <= 192:
        return 64
    elif head_dim <= 224:
        return 64
    elif head_dim <= 256:
        return 64


def round_multiple(x, m):
    return (x + m - 1) // m * m


# torch.compile() support is only enabled for pytorch >= 2.4
# The reason for this is that we are using the new custom_op and register_fake
# APIs, which support inplace modification of inputs in the function itself
if torch.__version__ >= "2.4.0":
    _torch_custom_op_wrapper = torch.library.custom_op
    _torch_register_fake_wrapper = torch.library.register_fake
else:
    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None,
                               schema=None):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn

    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn

    _torch_custom_op_wrapper = noop_custom_op_wrapper
    _torch_register_fake_wrapper = noop_register_fake_wrapper


@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
        q, k, v, None, alibi_slopes, dropout_p, softmax_scale, causal,
        window_size_left, window_size_right, softcap, return_softmax, None,
    )
    return out, softmax_lse, S_dmask, rng_state


@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    batch_size, seqlen_q, num_heads, head_size = q.shape
    seqlen_k = k.shape[1]
    out = torch.empty_like(q)
    softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
    if return_softmax:
        if torch.cuda.is_available() and torch.version.hip:
            p = torch.empty((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout)
        else:
            p = torch.empty((batch_size, num_heads,
                             round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)),
                            dtype=q.dtype, device=q.device, layout=q.layout)
    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
    return out, softmax_lse, p, rng_state


if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward
else:
    _wrapped_flash_attn_forward = _flash_attn_forward


@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
def _flash_attn_varlen_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_softmax: bool = False,
    block_table: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    zero_tensors: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
        q, k, v, None,
        cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k, block_table, alibi_slopes,
        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, zero_tensors, causal,
        window_size_left, window_size_right, softcap, return_softmax, None,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    return out, softmax_lse, S_dmask, rng_state


@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward")
def _flash_attn_varlen_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_softmax: bool = False,
    block_table: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    zero_tensors: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    paged_kv = block_table is not None
    batch_size = cu_seqlens_q.numel() - 1
    total_q, num_heads, _ = q.shape
    out = torch.empty_like(q)
    softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
    if return_softmax:
        if torch.cuda.is_available() and torch.version.hip:
            p = torch.empty((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout)
        else:
            p = torch.empty((batch_size, num_heads, round_multiple(max_seqlen_q, 128), round_multiple(max_seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
    return out, softmax_lse, p, rng_state


if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
else:
    _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward


@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = flash_attn_gpu.bwd(
        dout, q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes,
        dropout_p, softmax_scale, causal, window_size_left, window_size_right,
        softcap, deterministic, None, rng_state,
    )
    return softmax_d


@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    if dq is None:
        dq = torch.empty_like(q)
    if dk is None:
        dk = torch.empty_like(k)
    if dv is None:
        dv = torch.empty_like(v)
    batch_size, seqlen_q, num_heads, _ = q.shape
    if torch.cuda.is_available() and torch.version.hip:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q), device=q.device, dtype=torch.float32)
    else:
        softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
    return softmax_d


if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward
else:
    _wrapped_flash_attn_backward = _flash_attn_backward


@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_varlen_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
    zero_tensors: bool = False,
) -> torch.Tensor:
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = flash_attn_gpu.varlen_bwd(
        dout, q, k, v, out, softmax_lse, dq, dk, dv,
        cu_seqlens_q, cu_seqlens_k, alibi_slopes, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, zero_tensors, causal,
        window_size_left, window_size_right, softcap, deterministic, None, rng_state,
    )
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return softmax_d


@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
def _flash_attn_varlen_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
    zero_tensors: bool = False,
) -> torch.Tensor:
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    batch_size = cu_seqlens_q.numel() - 1
    total_q, num_heads, _ = q.shape
    if dq is None:
        dq = torch.empty_like(q)
    if dk is None:
        dk = torch.empty_like(k)
    if dv is None:
        dv = torch.empty_like(v)
    if torch.cuda.is_available() and torch.version.hip:
        softmax_d = torch.empty((num_heads, total_q), device=q.device, dtype=torch.float32)
    else:
        softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
    return softmax_d


if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward
else:
    _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, qkv, dropout_p, softmax_scale, causal, window_size, softcap,
        alibi_slopes, deterministic, return_softmax, is_grad_enabled,
    ):
        is_grad = is_grad_enabled and qkv.requires_grad
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
        head_size_og = q.size(3)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q, k, v, dropout_p, softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        head_size_og = dout.size(3)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_backward(
            dout_padded, q, k, v, out, softmax_lse,
            dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
            ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            ctx.window_size[0], ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes, ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size,
        softcap, alibi_slopes, deterministic, return_softmax, is_grad_enabled,
    ):
        is_grad = is_grad_enabled and qkv.requires_grad
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
        head_size_og = q.size(2)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p, softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
            ctx.dropout_p = dropout_p
            ctx.max_seqlen = max_seqlen
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        head_size_og = dout.size(2)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_varlen_backward(
            dout_padded, q, k, v, out, softmax_lse,
            dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
            cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
            ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            ctx.window_size[0], ctx.window_size[1],
            ctx.softcap, ctx.alibi_slopes, ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, q, kv, dropout_p, softmax_scale, causal, window_size, softcap,
        alibi_slopes, deterministic, return_softmax, is_grad_enabled,
    ):
        is_grad = is_grad_enabled and any(
            x.requires_grad for x in [q, kv]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach()
        head_size_og = q.size(3)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q, k, v, dropout_p, softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype,
                          device=k.device)
        head_size_og = dout.size(3)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_backward(
            dout_padded, q, k, v, out, softmax_lse,
            dq, dkv[:, :, 0], dkv[:, :, 1],
            ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            ctx.window_size[0], ctx.window_size[1],
            ctx.softcap, ctx.alibi_slopes, ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, causal, window_size, softcap,
        alibi_slopes, deterministic, return_softmax, is_grad_enabled,
    ):
        is_grad = is_grad_enabled and any(
            x.requires_grad for x in [q, kv]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k, v = kv[:, 0].detach(), kv[:, 1].detach()
        head_size_og = q.size(2)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        if is_grad:
            ctx.save_for_backward(
                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
            )
            ctx.dropout_p = dropout_p
            ctx.max_seqlen_q = max_seqlen_q
            ctx.max_seqlen_k = max_seqlen_k
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        head_size_og = dout.size(2)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_varlen_backward(
            dout_padded, q, k, v, out, softmax_lse,
            dq, dkv[:, 0], dkv[:, 1],
            cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
            ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            ctx.window_size[0], ctx.window_size[1],
            ctx.softcap, ctx.alibi_slopes, ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, softcap,
        alibi_slopes, deterministic, return_softmax, is_grad_enabled,
    ):
        is_grad = is_grad_enabled and any(
            x.requires_grad for x in [q, k, v]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        head_size_og = q.size(3)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q, k, v, dropout_p, softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        head_size_og = dout.size(3)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_backward(
            dout_padded, q, k, v, out, softmax_lse,
            dq, dk, dv,
            ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            ctx.window_size[0], ctx.window_size[1],
            ctx.softcap, ctx.alibi_slopes, ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, causal, window_size, softcap,
        alibi_slopes, deterministic, return_softmax, block_table, is_grad_enabled,
    ):
        is_grad = is_grad_enabled and any(
            x.requires_grad for x in [q, k, v]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        head_size_og = q.size(2)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0, block_table=block_table, ) if is_grad: ctx.save_for_backward( q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state ) ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic out = out_padded[..., :head_size_og] return out if not return_softmax else (out, softmax_lse, S_dmask) @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) head_size_og = dout.size(2) dout_padded = dout if head_size_og % 8 != 0: dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) _wrapped_flash_attn_varlen_backward( dout_padded, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.window_size[0], ctx.window_size[1], ctx.softcap, ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # <=0.0 means deactivate alibi_slopes=None, deterministic=False, return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation of the gradients of Q, K, V. 
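Example (a minimal sketch, not the library's implementation): for one head of one sequence with seqlen_q == seqlen_k, as with packed QKV, the result can be cross-checked against a naive pure-Python reference. The reference follows the conventions documented below:
- the score is QK^T * softmax_scale plus the ALiBi bias -alibi_slope * |i - j|;
- keys outside [i - window_size[0], i + window_size[1]] are masked out;
- softmax_lse is the log of each row's softmax normalization factor.

The helper name reference_attention is hypothetical. Dropout, softcap and batching are left out on purpose.

```python
import math
from typing import List, Optional, Sequence, Tuple

Matrix = Sequence[Sequence[float]]


def reference_attention(
    q: Matrix,  # (seqlen, headdim) for a single head of a single sequence
    k: Matrix,  # (seqlen, headdim)
    v: Matrix,  # (seqlen, headdim)
    softmax_scale: Optional[float] = None,
    window_size: Tuple[int, int] = (-1, -1),  # -1 means unbounded on that side
    alibi_slope: float = 0.0,
) -> Tuple[List[List[float]], List[float]]:
    seqlen, headdim = len(q), len(q[0])
    if softmax_scale is None:
        softmax_scale = headdim ** -0.5  # default: 1 / sqrt(headdim)
    left, right = window_size
    out: List[List[float]] = []
    lse: List[float] = []
    for i in range(seqlen):
        scores = []
        for j in range(len(k)):
            # Keys outside [i - left, i + right] are masked with -inf.
            if (left >= 0 and j < i - left) or (right >= 0 and j > i + right):
                scores.append(-math.inf)
                continue
            dot = sum(a * b for a, b in zip(q[i], k[j]))
            # ALiBi: -alibi_slope * |i - j| is added to the scaled score.
            scores.append(dot * softmax_scale - alibi_slope * abs(i - j))
        # Finite because key i is never masked when seqlen_q == seqlen_k.
        row_max = max(scores)
        weights = [math.exp(s - row_max) for s in scores]
        total = sum(weights)
        lse.append(row_max + math.log(total))  # log of the normalization factor
        out.append([sum(w * v[j][d] for j, w in enumerate(weights)) / total
                    for d in range(headdim)])
    return out, lse
```

A reference like this is only practical for tiny inputs. Comparing it with the fused kernel's output (cast to fp32) is one way to check the masking and bias conventions.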
For multi-query and grouped-query attention (MQA/GQA), please see flash_attn_kvpacked_func and flash_attn_func. If window_size != (-1, -1), implements sliding window local attention. Query at position i will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. Arguments: qkv: (batch_size, seqlen, 3, nheads, headdim) dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to the attention score of query i and key j. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept).
""" return FlashAttnQKVPackedFunc.apply( qkv, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs, torch.is_grad_enabled(), ) def flash_attn_kvpacked_func( q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, deterministic=False, return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation If K, V are already stacked into 1 tensor, this function will be faster than calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation of the gradients of K, V. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 1 1 1 1 0 1 1 1 1 1 If seqlen_q = 5 and seqlen_k = 2, the causal mask is: 0 0 0 0 0 0 1 0 1 1 If a row of the mask is all zero, the output for that row will be zero. If window_size != (-1, -1), implements sliding window local attention. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Arguments: q: (batch_size, seqlen, nheads, headdim) kv: (batch_size, seqlen, 2, nheads_k, headdim) dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ return FlashAttnKVPackedFunc.apply( q, kv, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs, torch.is_grad_enabled(), ) def flash_attn_func( q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, deterministic=False, return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
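Example (a hypothetical pure-Python helper, not part of flash_attn): the grouping in the 6-head / 2-head example above amounts to query head h reading K/V head h // (nheads // nheads_k). In other words, each run of nheads // nheads_k consecutive query heads shares one K/V head. The sketch generalizes that single example, so treat it as an illustration of the documented rule rather than a statement about kernel internals.

```python
from typing import List


def kv_head_for_q_head(h: int, nheads: int, nheads_k: int) -> int:
    # Q heads are split into nheads_k contiguous groups of equal size;
    # every Q head in group g attends to K/V head g.
    if nheads % nheads_k != 0:
        raise ValueError("nheads must be divisible by nheads_k")
    if not 0 <= h < nheads:
        raise ValueError("query head index out of range")
    return h // (nheads // nheads_k)


def kv_head_map(nheads: int, nheads_k: int) -> List[int]:
    # K/V head index used by each query head 0 to nheads - 1.
    return [kv_head_for_q_head(h, nheads, nheads_k) for h in range(nheads)]


print(kv_head_map(6, 2))  # [0, 0, 0, 1, 1, 1]  (GQA, as in the docstring)
print(kv_head_map(4, 1))  # [0, 0, 0, 0]        (MQA)
print(kv_head_map(4, 4))  # [0, 1, 2, 3]        (regular multi-head attention)
```

MQA is the nheads_k = 1 case, and ordinary multi-head attention is nheads_k = nheads. Callers still pass K and V with only nheads_k heads. The map only describes which K/V head each query head reads.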
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 1 1 1 1 0 1 1 1 1 1 If seqlen_q = 5 and seqlen_k = 2, the causal mask is: 0 0 0 0 0 0 1 0 1 1 If a row of the mask is all zero, the output for that row will be zero. If window_size != (-1, -1), implements sliding window local attention. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Arguments: q: (batch_size, seqlen, nheads, headdim) k: (batch_size, seqlen, nheads_k, headdim) v: (batch_size, seqlen, nheads_k, headdim) dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). The output of softmax (possibly with different scaling).
It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ return FlashAttnFunc.apply( q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs, torch.is_grad_enabled(), ) def flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, deterministic=False, return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation of the gradients of Q, K, V. For multi-query and grouped-query attention (MQA/GQA), please see flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. If window_size != (-1, -1), implements sliding window local attention. Query at position i will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. Arguments: qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into qkv. max_seqlen: int. Maximum sequence length in the batch. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to the attention score of query i and key j. deterministic: bool.
Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ return FlashAttnVarlenQKVPackedFunc.apply( qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs, torch.is_grad_enabled(), ) def flash_attn_varlen_kvpacked_func( q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, deterministic=False, return_attn_probs=False, ): """dropout_p should be set to 0.0 during evaluation If K, V are already stacked into 1 tensor, this function will be faster than calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation of the gradients of K, V. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
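Example (a hypothetical pure-Python helper, not part of flash_attn): the causal and sliding-window rules described in this docstring can be reproduced on small shapes.
- Shift the diagonal by seqlen_k - seqlen_q (bottom-right alignment).
- Keep keys in [i + shift - window_size[0], i + shift + window_size[1]].

When causal=True and a window are both given, this sketch simply intersects the two masks. That behavior is an assumption, not something the docstring states.

```python
from typing import List, Tuple


def attention_mask(
    seqlen_q: int,
    seqlen_k: int,
    causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),  # -1 means unbounded on that side
) -> List[List[int]]:
    # Returns a seqlen_q x seqlen_k matrix with 1 = keep, 0 = masked out.
    left, right = window_size
    shift = seqlen_k - seqlen_q  # bottom-right alignment of the diagonal
    mask = []
    for i in range(seqlen_q):
        row = []
        for j in range(seqlen_k):
            keep = True
            if causal:
                keep = keep and j <= i + shift
            if left >= 0:
                keep = keep and j >= i + shift - left
            if right >= 0:
                keep = keep and j <= i + shift + right
            row.append(int(keep))
        mask.append(row)
    return mask


for row in attention_mask(2, 5, causal=True):
    print(*row)
# 1 1 1 1 0
# 1 1 1 1 1
```

Running it with seqlen_q = 5 and seqlen_k = 2 reproduces the second causal mask quoted in this docstring. That includes the three all-zero rows, whose output is zero.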
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 1 1 1 1 0 1 1 1 1 1 If seqlen_q = 5 and seqlen_k = 2, the causal mask is: 0 0 0 0 0 0 1 0 1 1 If a row of the mask is all zero, the output for that row will be zero. If window_size != (-1, -1), implements sliding window local attention. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into kv. max_seqlen_q: int. Maximum query sequence length in the batch. max_seqlen_k: int. Maximum key sequence length in the batch. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only.
The returned probabilities are not guaranteed to be correct (they might not have the right scaling). Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ return FlashAttnVarlenKVPackedFunc.apply( q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs, torch.is_grad_enabled(), ) def flash_attn_varlen_func( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated alibi_slopes=None, deterministic=False, return_attn_probs=False, block_table=None, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 1 1 1 1 0 1 1 1 1 1 If seqlen_q = 5 and seqlen_k = 2, the causal mask is: 0 0 0 0 0 0 1 0 1 1 If a row of the mask is all zero, the output for that row will be zero. If window_size != (-1, -1), implements sliding window local attention.
Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into k and v. max_seqlen_q: int. Maximum query sequence length in the batch. max_seqlen_k: int. Maximum key sequence length in the batch. dropout_p: float. Dropout probability. softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to return the attention probabilities. This option is for testing only. The returned probabilities are not guaranteed to be correct (they might not have the right scaling). Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). The output of softmax (possibly with different scaling). It also encodes the dropout pattern (negative means that location was dropped, nonnegative means it was kept). """ return FlashAttnVarlenFunc.apply( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs, block_table, torch.is_grad_enabled(), ) def flash_attn_with_kvcache( q, k_cache, v_cache, k=None, v=None, rotary_cos=None, rotary_sin=None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, cache_batch_idx: Optional[torch.Tensor] = None, cache_leftpad: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, alibi_slopes=None, num_splits=0, return_softmax_lse=False, ): """ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from k and v. This is useful for incremental decoding: you can pass in the cached keys/values from the previous step, and update them with the new keys/values from the current step, and do attention with the updated cache, all in 1 kernel. If you pass in k / v, you must make sure that the cache is large enough to hold the new values. For example, the KV cache could be pre-allocated with the max sequence length, and you can use cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. 
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 1 1 1 1 0 1 1 1 1 1 If seqlen_q = 5 and seqlen_k = 2, the causal mask is: 0 0 0 0 0 0 1 0 1 1 If a row of the mask is all zero, the output for that row will be zero. If window_size != (-1, -1), implements sliding window local attention. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Note: Does not support backward pass. Arguments: q: (batch_size, seqlen, nheads, headdim) k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache). page_block_size must be a multiple of 256. v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache). k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. rotary_cos [optional]: (seqlen_ro, rotary_dim / 2).
If not None, we apply rotary embedding to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the KV cache. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. If the indices are not distinct, and k and v are provided, the values updated in the cache might come from any of the duplicate indices. cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0. block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. softcap: float. Anything > 0 activates softcapping attention. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 (i.e. GPT-NeoX style). alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. num_splits: int. If > 1, split the key/value into this many chunks along the sequence. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic to automatically determine the number of splits. Don't change this unless you know what you are doing. return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. Return: out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax normalization factor). """ assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" q, k, v = [maybe_contiguous(x) for x in (q, k, v)] if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) cache_seqlens = maybe_contiguous(cache_seqlens) cache_batch_idx = maybe_contiguous(cache_batch_idx) block_table = maybe_contiguous(block_table) out, softmax_lse = flash_attn_gpu.fwd_kvcache( q, k_cache, v_cache, k, v, cache_seqlens, rotary_cos, rotary_sin, cache_batch_idx, cache_leftpad, block_table, alibi_slopes, None, softmax_scale, causal, window_size[0], window_size[1], softcap, rotary_interleaved, num_splits, ) return (out, softmax_lse) if return_softmax_lse else out ================================================ FILE: flash_attn/flash_attn_triton.py ================================================ """ *Experimental* implementation of FlashAttention in Triton. Tested with triton==2.0.0.dev20221202. Triton 2.0 has a new backend (MLIR), but it seems it doesn't yet work for head dimensions other than 64: https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 We'll update this implementation with the new Triton backend once this is fixed. We use the FlashAttention implementation from Phil Tillet as a starting point. https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py Changes: - Implement both causal and non-causal attention. - Implement both self-attention and cross-attention.
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. - Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. - Support attention bias. - Speed up the forward pass a bit, and only store the LSE instead of m and l. - Make the backward for d=128 much faster by reducing register spilling. - Optionally parallelize the backward pass across seqlen_k, to deal with the case of small batch size * nheads. Caution: - This is an *experimental* implementation. The forward pass should be quite robust but I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). - This implementation has only been tested on A100. - If you plan to use headdim other than 64 and 128, you should test for race conditions (due to the Triton compiler), as done in tests/test_flash_attn.py "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident that there are none left for other head dimensions. Differences between this Triton version and the CUDA version: - Triton version doesn't support dropout. - Triton forward is generally faster than CUDA forward, while Triton backward is generally slower than CUDA backward. Overall Triton forward + backward is slightly slower than CUDA forward + backward. - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). - Triton version supports attention bias, while CUDA version doesn't. """ import math import torch import triton import triton.language as tl # Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 # @triton.autotune( # configs=[ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), # # This config has a race condition when EVEN_M == False, disabling it for now. 
# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), # ], # key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] # ) @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], } ) @triton.jit def _fwd_kernel( Q, K, V, Bias, Out, Lse, TMP, # NOTE: TMP is a scratchpad buffer to work around a compiler bug softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): start_m = tl.program_id(0) off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads # off_b = tl.program_id(1) # off_h = tl.program_id(2) # off_hb = off_b * nheads + off_h # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) # Initialize pointers to Q, K, V # Adding parentheses around indexing might use int32 math instead of int64 math?
# https://github.com/openai/triton/issues/741 # I'm seeing a tiny bit of difference (5-7us) q_ptrs = ( Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) ) k_ptrs = ( K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) ) v_ptrs = ( V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) ) if BIAS_TYPE == "vector": b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n elif BIAS_TYPE == "matrix": b_ptrs = ( Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) ) # initialize pointer to m and l t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # load q: it will stay in SRAM throughout # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call # tl.load(q_ptrs), we get the wrong output! 
if EVEN_M & EVEN_N: if EVEN_HEADDIM: q = tl.load(q_ptrs) else: q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) else: if EVEN_HEADDIM: q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) else: q = tl.load( q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) # loop over k, v and update accumulator end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition if EVEN_HEADDIM: k = tl.load(k_ptrs + start_n * stride_kn) else: k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) else: if EVEN_HEADDIM: k = tl.load( k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0, ) else: k = tl.load( k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k, trans_b=True) # Trying to combine the two masks seems to make the result wrong if not EVEN_N: # Need to mask out, otherwise the softmax is wrong qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) if IS_CAUSAL: qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) if BIAS_TYPE != "none": if BIAS_TYPE == "vector": if EVEN_N: bias = tl.load(b_ptrs + start_n).to(tl.float32) else: bias = tl.load( b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0 ).to(tl.float32) bias = bias[None, :] elif BIAS_TYPE == "matrix": if EVEN_M & EVEN_N: bias = tl.load(b_ptrs + start_n).to(tl.float32) else: bias = tl.load( b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0, ).to(tl.float32) # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler # can
then fuse the mult and add into an fma instruction. But if we have bias we need to # multiply with softmax_scale here. qk = qk * softmax_scale + bias m_ij = tl.maximum(tl.max(qk, 1), lse_i) p = tl.exp(qk - m_ij[:, None]) else: m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) p = tl.exp(qk * softmax_scale - m_ij[:, None]) l_ij = tl.sum(p, 1) # scale acc_o acc_o_scale = tl.exp(m_i - m_ij) # # -- update output accumulator -- # BUG: have to store and immediately load tl.store(t_ptrs, acc_o_scale) acc_o_scale = tl.load(t_ptrs) acc_o = acc_o * acc_o_scale[:, None] # update acc_o if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition if EVEN_HEADDIM: v = tl.load(v_ptrs + start_n * stride_vn) else: v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) else: if EVEN_HEADDIM: v = tl.load( v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0, ) else: v = tl.load( v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) p = p.to(v.dtype) acc_o += tl.dot(p, v) # -- update statistics m_i = m_ij l_i_new = tl.exp(lse_i - m_ij) + l_ij lse_i = m_ij + tl.log(l_i_new) o_scale = tl.exp(m_i - lse_i) # BUG: have to store and immediately load tl.store(t_ptrs, o_scale) o_scale = tl.load(t_ptrs) acc_o = acc_o * o_scale[:, None] # rematerialize offsets to save registers start_m = tl.program_id(0) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) # write back lse lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m tl.store(lse_ptrs, lse_i) # initialize pointers to output offs_d = tl.arange(0, BLOCK_HEADDIM) out_ptrs = ( Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) ) if EVEN_M: if EVEN_HEADDIM: tl.store(out_ptrs, acc_o) else: tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) else: if EVEN_HEADDIM: tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) else:
tl.store( out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) ) @triton.jit def _bwd_preprocess_do_o_dot( Out, DO, Delta, stride_ob, stride_oh, stride_om, stride_dob, stride_doh, stride_dom, nheads, seqlen_q, seqlen_q_rounded, headdim, BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, ): start_m = tl.program_id(0) off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) # load o = tl.load( Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, ).to(tl.float32) do = tl.load( DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, ).to(tl.float32) delta = tl.sum(o * do, axis=1) # write-back tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) @triton.jit def _bwd_store_dk_dv( dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, ): # [2022-11-01] TD: Same bug. 
In the case of EVEN_N=True and EVEN_M=False, # if we just call tl.store(dv_ptrs), there's a race condition if EVEN_N & EVEN_M: if EVEN_HEADDIM: tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) else: tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) else: if EVEN_HEADDIM: tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) else: tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) @triton.jit def _bwd_kernel_one_col_block( start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD: tl.constexpr, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M # initialize row/col offsets offs_qm = begin_m + tl.arange(0, BLOCK_M) offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) # initialize pointers to value-like data q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) if BIAS_TYPE == "vector": b_ptrs = Bias + offs_n elif BIAS_TYPE == "matrix": b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) # initialize dv and dk dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dk = 
tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) # There seems to be some problem with Triton pipelining that makes results wrong for # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop # may have zero step, and pipelining with the bias matrix could screw it up. # So we just exit early. if begin_m >= seqlen_q: dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) _bwd_store_dk_dv( dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, ) return # k and v stay in SRAM throughout # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, # if we just call tl.load(k_ptrs), we get the wrong output! if EVEN_N & EVEN_M: if EVEN_HEADDIM: k = tl.load(k_ptrs) v = tl.load(v_ptrs) else: k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) else: if EVEN_HEADDIM: k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) else: k = tl.load( k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 ) v = tl.load( v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 ) # loop over rows num_block_m = tl.cdiv(seqlen_q, BLOCK_M) for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) offs_m_curr = start_m + offs_m # load q, k, v, do on-chip # Same bug as below. 
Otherwise gives wrong result for headdim=40, seqlen=(128, 117) if EVEN_M & EVEN_HEADDIM: q = tl.load(q_ptrs) else: if EVEN_HEADDIM: q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) else: q = tl.load( q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, ) # recompute p = softmax(qk, dim=-1).T qk = tl.dot(q, k, trans_b=True) # Trying to combine the two masks seems to make the result wrong if not EVEN_N: # Need to mask out, otherwise the softmax is wrong qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) if IS_CAUSAL: qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) if BIAS_TYPE != "none": tl.debug_barrier() # Race condition otherwise if BIAS_TYPE == "vector": if EVEN_N: bias = tl.load(b_ptrs).to(tl.float32) else: bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) bias = bias[None, :] elif BIAS_TYPE == "matrix": if EVEN_M & EVEN_N: bias = tl.load(b_ptrs).to(tl.float32) else: bias = tl.load( b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0, ).to(tl.float32) qk = qk * softmax_scale + bias # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. # Also wrong for headdim=64. if not (EVEN_M & EVEN_HEADDIM): tl.debug_barrier() lse_i = tl.load(LSE + offs_m_curr) if BIAS_TYPE == "none": p = tl.exp(qk * softmax_scale - lse_i[:, None]) else: p = tl.exp(qk - lse_i[:, None]) # compute dv # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, # the output is correct. if EVEN_M & EVEN_HEADDIM: do = tl.load(do_ptrs) else: # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
do = tl.load( do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, ) # if EVEN_M: # if EVEN_HEADDIM: # do = tl.load(do_ptrs) # else: # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) # else: # if EVEN_HEADDIM: # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) # else: # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) # & (offs_d[None, :] < headdim), other=0.0) dv += tl.dot(p.to(do.dtype), do, trans_a=True) # compute dp = dot(v, do) # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False if not (EVEN_M & EVEN_HEADDIM): tl.debug_barrier() dp = tl.dot(do, v, trans_b=True) # There's a race condition for headdim=48 if not EVEN_HEADDIM: tl.debug_barrier() # compute ds = p * (dp - delta[:, None]) # Putting the subtraction after the dp matmul (instead of before) is slightly faster Di = tl.load(D + offs_m_curr) # Converting ds to q.dtype here reduces register pressure and makes it much faster # for BLOCK_HEADDIM=128 ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) # compute dk = dot(ds.T, q) dk += tl.dot(ds, q, trans_a=True) # compute dq if not ( EVEN_M & EVEN_HEADDIM ): # Otherwise there's a race condition when BIAS_TYPE='matrix' tl.debug_barrier() if not ATOMIC_ADD: if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M dq = tl.load(dq_ptrs, eviction_policy="evict_last") dq += tl.dot(ds, k) tl.store(dq_ptrs, dq, eviction_policy="evict_last") else: if EVEN_HEADDIM: dq = tl.load( dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy="evict_last", ) dq += tl.dot(ds, k) tl.store( dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy="evict_last", ) else: dq = tl.load( dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0,
eviction_policy="evict_last", ) dq += tl.dot(ds, k) tl.store( dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), eviction_policy="evict_last", ) else: # If we're parallelizing across the seqlen_k dimension dq = tl.dot(ds, k) if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M tl.atomic_add(dq_ptrs, dq) else: if EVEN_HEADDIM: tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) else: tl.atomic_add( dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), ) # increment pointers dq_ptrs += BLOCK_M * stride_dqm q_ptrs += BLOCK_M * stride_qm do_ptrs += BLOCK_M * stride_dom if BIAS_TYPE == "matrix": b_ptrs += BLOCK_M * stride_bm # write-back dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) _bwd_store_dk_dv( dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, ) def init_to_zero(name): return lambda nargs: nargs[name].zero_() @triton.autotune( configs=[ triton.Config( {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ"), ), triton.Config( {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ"), ), # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now # # Kernel is buggy (gives wrong results) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64,
"SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), ], key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"], ) @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], } ) @triton.jit def _bwd_kernel( Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_dob, stride_doh, stride_dom, stride_dqb, stride_dqh, stride_dqm, stride_dkb, stride_dkh, stride_dkn, stride_dvb, stride_dvh, stride_dvn, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads # offset pointers for batch/head Q += off_b * stride_qb + off_h * stride_qh K += off_b * stride_kb + off_h * stride_kh V += off_b * stride_vb + off_h * stride_vh DO += off_b * stride_dob + off_h * stride_doh DQ += off_b * stride_dqb + off_h * stride_dqh DK += off_b * stride_dkb + off_h * stride_dkh DV += off_b * stride_dvb + off_h * stride_dvh if BIAS_TYPE != "none": Bias += off_b * stride_bb + off_h * stride_bh # pointer to row-wise quantities in value-like data D += off_hb * seqlen_q_rounded LSE += off_hb * seqlen_q_rounded if not SEQUENCE_PARALLEL: num_block_n = tl.cdiv(seqlen_k, BLOCK_N) for start_n in range(0, num_block_n): _bwd_kernel_one_col_block( start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, 
seqlen_k, headdim, ATOMIC_ADD=False, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) else: start_n = tl.program_id(0) _bwd_kernel_one_col_block( start_n, Q, K, V, Bias, DO, DQ, DK, DV, LSE, D, softmax_scale, stride_qm, stride_kn, stride_vn, stride_bm, stride_dom, stride_dqm, stride_dkn, stride_dvn, seqlen_q, seqlen_k, headdim, ATOMIC_ADD=True, BIAS_TYPE=BIAS_TYPE, IS_CAUSAL=IS_CAUSAL, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): # shape constraints batch, seqlen_q, nheads, d = q.shape _, seqlen_k, _, _ = k.shape assert k.shape == (batch, seqlen_k, nheads, d) assert v.shape == (batch, seqlen_k, nheads, d) assert d <= 128, "FlashAttention only supports head dimensions up to 128" assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" assert q.dtype in [torch.float16, torch.bfloat16], "Only fp16 and bf16 are supported" assert q.is_cuda and k.is_cuda and v.is_cuda softmax_scale = softmax_scale or 1.0 / math.sqrt(d) has_bias = bias is not None bias_type = "none" if has_bias: assert bias.dtype in [q.dtype, torch.float] assert bias.is_cuda assert bias.dim() == 4 if bias.stride(-1) != 1: bias = bias.contiguous() if bias.shape[2:] == (1, seqlen_k): bias_type = "vector" elif bias.shape[2:] == (seqlen_q, seqlen_k): bias_type = "matrix" else: raise RuntimeError( "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)" ) bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device,
dtype=torch.float32) o = torch.empty_like(q) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) BLOCK = 128 num_warps = 4 if d <= 64 else 8 grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) _fwd_kernel[grid]( q, k, v, bias, o, lse, tmp, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, o.stride(0), o.stride(2), o.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) # Can't use kwargs here because triton autotune expects key to be args, not kwargs # IS_CAUSAL=causal, BLOCK_HEADDIM=d, bias_type, causal, BLOCK_HEADDIM, BLOCK_M=BLOCK, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, ) return o, lse, softmax_scale # softmax_scale could have been updated def _flash_attn_backward( do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None ): # Make sure that the last dimension is contiguous if do.stride(-1) != 1: do = do.contiguous() batch, seqlen_q, nheads, d = q.shape _, seqlen_k, _, _ = k.shape # assert d in {16, 32, 64, 128} assert d <= 128 seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 assert lse.shape == (batch, nheads, seqlen_q_rounded) assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 softmax_scale = softmax_scale or 1.0 / math.sqrt(d) # dq_accum = torch.zeros_like(q, dtype=torch.float32) dq_accum = torch.empty_like(q, dtype=torch.float32) delta = torch.empty_like(lse) # delta = torch.zeros_like(lse) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) _bwd_preprocess_do_o_dot[grid]( o, do, delta, o.stride(0), o.stride(2), o.stride(1), do.stride(0), do.stride(2), do.stride(1), nheads, seqlen_q, seqlen_q_rounded, d, BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM, ) has_bias = bias 
is not None bias_type = "none" if has_bias: assert bias.dtype in [q.dtype, torch.float] assert bias.is_cuda assert bias.dim() == 4 assert bias.stride(-1) == 1 if bias.shape[2:] == (1, seqlen_k): bias_type = "vector" elif bias.shape[2:] == (seqlen_q, seqlen_k): bias_type = "matrix" else: raise RuntimeError( "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)" ) bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) # BLOCK_M = 128 # BLOCK_N = 64 # num_warps = 4 grid = lambda META: ( triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, batch * nheads, ) _bwd_kernel[grid]( q, k, v, bias, do, dq_accum, dk, dv, lse, delta, softmax_scale, q.stride(0), q.stride(2), q.stride(1), k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), *bias_strides, do.stride(0), do.stride(2), do.stride(1), dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), dk.stride(0), dk.stride(2), dk.stride(1), dv.stride(0), dv.stride(2), dv.stride(1), nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) # Can't use kwargs here because triton autotune expects key to be args, not kwargs # IS_CAUSAL=causal, BLOCK_HEADDIM=d, bias_type, causal, BLOCK_HEADDIM, # SEQUENCE_PARALLEL=False, # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, # num_warps=num_warps, # num_stages=1, ) dq.copy_(dq_accum) class FlashAttnQKVPackedFunc(torch.autograd.Function): @staticmethod def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): """ qkv: (batch, seqlen, 3, nheads, headdim) bias: optional, shape broadcastable to (batch, nheads, seqlen, seqlen). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) """ # Make sure that the last dimension is contiguous if qkv.stride(-1) != 1: qkv = qkv.contiguous() o, lse, ctx.softmax_scale = _flash_attn_forward( qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale, ) ctx.save_for_backward(qkv, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): qkv, o, lse, bias = ctx.saved_tensors assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet" # Triton's autotune causes the Tensor._version to change, and so PyTorch autograd # does a memcpy. To avoid this, we run in inference_mode, which doesn't track the version. with torch.inference_mode(): dqkv = torch.empty_like(qkv) _flash_attn_backward( do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale, ) return dqkv, None, None, None flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply class FlashAttnKVPackedFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): """ q: (batch, seqlen_q, nheads, headdim) kv: (batch, seqlen_k, 2, nheads, headdim) bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) """ # Make sure that the last dimension is contiguous q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] o, lse, ctx.softmax_scale = _flash_attn_forward( q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale ) ctx.save_for_backward(q, kv, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): q, kv, o, lse, bias = ctx.saved_tensors if len(ctx.needs_input_grad) >= 3: assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet" # Triton's autotune causes the Tensor._version to change, and so PyTorch autograd # does a memcpy. To avoid this, we run in inference_mode, which doesn't track the version. with torch.inference_mode(): dq = torch.empty_like(q) dkv = torch.empty_like(kv) _flash_attn_backward( do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale, ) return dq, dkv, None, None, None flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply class FlashAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): """ q: (batch_size, seqlen_q, nheads, headdim) k, v: (batch_size, seqlen_k, nheads, headdim) bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k). For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) """ # Make sure that the last dimension is contiguous q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] o, lse, ctx.softmax_scale = _flash_attn_forward( q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale ) ctx.save_for_backward(q, k, v, o, lse, bias) ctx.causal = causal return o @staticmethod def backward(ctx, do): q, k, v, o, lse, bias = ctx.saved_tensors assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet" # Triton's autotune causes the Tensor._version to change, and so PyTorch autograd # does a memcpy. To avoid this, we run in inference_mode, which doesn't track the version. with torch.inference_mode(): dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) _flash_attn_backward( do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale, ) return dq, dk, dv, None, None, None flash_attn_func = FlashAttnFunc.apply ================================================ FILE: flash_attn/flash_attn_triton_og.py ================================================ # [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py # for benchmarking.
# We fixed a few dtype casts to make it work for bf16 """ Fused Attention =============== This is a Triton implementation of the Flash Attention algorithm (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) """ import pytest import torch import triton import triton.language as tl @triton.jit def _fwd_kernel( Q, K, V, sm_scale, TMP, L, M, # NOTE: TMP is a scratchpad buffer to work around a compiler bug Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): start_m = tl.program_id(0) off_hz = tl.program_id(1) # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk # Initialize pointers to Q, K, V q_ptrs = Q + off_q k_ptrs = K + off_k v_ptrs = V + off_v # initialize pointer to m and l t_ptrs = TMP + off_hz * N_CTX + offs_m m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # load q: it will stay in SRAM throughout q = tl.load(q_ptrs) # loop over k, v and update accumulator for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(k_ptrs + start_n * stride_kn) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k, trans_b=True) qk *= sm_scale qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) # -- compute m_ij, p, l_ij m_ij =
tl.max(qk, 1) p = tl.exp(qk - m_ij[:, None]) l_ij = tl.sum(p, 1) # -- update m_i and l_i m_i_new = tl.maximum(m_i, m_ij) alpha = tl.exp(m_i - m_i_new) beta = tl.exp(m_ij - m_i_new) l_i_new = alpha * l_i + beta * l_ij # -- update output accumulator -- # scale p p_scale = beta / l_i_new p = p * p_scale[:, None] # scale acc acc_scale = l_i / l_i_new * alpha tl.store(t_ptrs, acc_scale) acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load acc = acc * acc_scale[:, None] # update acc v = tl.load(v_ptrs + start_n * stride_vk) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i l_i = l_i_new m_i = m_i_new # rematerialize offsets to save registers start_m = tl.program_id(0) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) # write back l and m l_ptrs = L + off_hz * N_CTX + offs_m m_ptrs = M + off_hz * N_CTX + offs_m tl.store(l_ptrs, l_i) tl.store(m_ptrs, m_i) # initialize pointers to output offs_n = tl.arange(0, BLOCK_DMODEL) off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on out_ptrs = Out + off_o tl.store(out_ptrs, acc) @triton.jit def _bwd_preprocess( Out, DO, L, NewDO, Delta, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_n = tl.arange(0, D_HEAD) # load o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) denom = tl.load(L + off_m).to(tl.float32) # compute do = do / denom[:, None] delta = tl.sum(o * do, axis=1) # write-back tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) tl.store(Delta + off_m, delta) @triton.jit def _bwd_kernel( Q, K, V, sm_scale, Out, DO, DQ, DK, DV, L, M, D, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, Z, H, N_CTX, num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): off_hz = tl.program_id(0) off_z 
= off_hz // H off_h = off_hz % H # offset pointers for batch/head Q += off_z * stride_qz + off_h * stride_qh K += off_z * stride_qz + off_h * stride_qh V += off_z * stride_qz + off_h * stride_qh DO += off_z * stride_qz + off_h * stride_qh DQ += off_z * stride_qz + off_h * stride_qh DK += off_z * stride_qz + off_h * stride_qh DV += off_z * stride_qz + off_h * stride_qh for start_n in range(0, num_block): lo = start_n * BLOCK_M # initialize row/col offsets offs_qm = lo + tl.arange(0, BLOCK_M) offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) offs_m = tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_DMODEL) # initialize pointers to value-like data q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) # pointer to row-wise quantities in value-like data D_ptrs = D + off_hz * N_CTX m_ptrs = M + off_hz * N_CTX # initialize dv and dk dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # k and v stay in SRAM throughout k = tl.load(k_ptrs) v = tl.load(v_ptrs) # loop over rows for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): offs_m_curr = start_m + offs_m # load q, k, v, do on-chip q = tl.load(q_ptrs) # recompute p = softmax(qk, dim=-1).T # NOTE: `do` is pre-divided by `l`; no normalization here qk = tl.dot(q, k, trans_b=True) qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) m = tl.load(m_ptrs + offs_m_curr) p = tl.exp(qk * sm_scale - m[:, None]) # compute dv do = tl.load(do_ptrs) dv += tl.dot(p.to(do.dtype), do, trans_a=True) # compute dp = dot(v, do) Di = tl.load(D_ptrs + offs_m_curr) dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] dp += tl.dot(do,
v, trans_b=True) # compute ds = p * (dp - delta[:, None]) ds = p * dp * sm_scale # compute dk = dot(ds.T, q) dk += tl.dot(ds.to(q.dtype), q, trans_a=True) # # compute dq dq = tl.load(dq_ptrs, eviction_policy="evict_last") dq += tl.dot(ds.to(k.dtype), k) tl.store(dq_ptrs, dq, eviction_policy="evict_last") # # increment pointers dq_ptrs += BLOCK_M * stride_qm q_ptrs += BLOCK_M * stride_qm do_ptrs += BLOCK_M * stride_qm # write-back dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) class _attention(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, sm_scale): BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} o = torch.empty_like(q) grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) tmp = torch.empty( (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 ) L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 _fwd_kernel[grid]( q, k, v, sm_scale, tmp, L, m, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), q.shape[0], q.shape[1], q.shape[2], BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, num_warps=num_warps, num_stages=1, ) ctx.save_for_backward(q, k, v, o, L, m) ctx.BLOCK = BLOCK ctx.grid = grid ctx.sm_scale = sm_scale ctx.BLOCK_DMODEL = Lk return o @staticmethod def backward(ctx, do): q, k, v, o, l, m = ctx.saved_tensors do = do.contiguous() dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.empty_like(k) dv = torch.empty_like(v) do_scaled = 
torch.empty_like(do) delta = torch.empty_like(l) _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)]( o, do, l, do_scaled, delta, BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, ) # NOTE: kernel currently buggy for other values of `num_warps` num_warps = 8 _bwd_kernel[(ctx.grid[1],)]( q, k, v, ctx.sm_scale, o, do_scaled, dq, dk, dv, l, m, delta, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), q.shape[0], q.shape[1], q.shape[2], ctx.grid[0], BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps, num_stages=1, ) return dq.to(q.dtype), dk, dv, None attention = _attention.apply ================================================ FILE: flash_attn/flash_blocksparse_attention.py ================================================ import math import hydra import torch import torch.nn as nn from einops import rearrange from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input from flash_attn.flash_blocksparse_attn_interface import ( convert_blockmask, flash_blocksparse_attn_func, ) class FlashBlocksparseAttention(nn.Module): """Implement the scaled dot product attention with softmax. Arguments --------- softmax_temp: The temperature to use for the softmax attention. 
(default: 1/sqrt(d_keys) where d_keys is computed at runtime) attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ def __init__( self, sparsity_config, softmax_temp=None, attention_dropout=0.0, max_seq_length=2048, device=None, dtype=None, ): super().__init__() self.sparsity_config = hydra.utils.instantiate(sparsity_config) self.softmax_temp = softmax_temp self.dropout_p = attention_dropout # initialize sparse layout and register as buffer max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256 layout = self.sparsity_config.make_layout(max_seq_length) self.register_buffer("layout", layout) blockmask_converted = convert_blockmask(self.layout, causal=False) self.register_buffer("blockmask_converted", blockmask_converted) # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}') def forward( self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None, max_s=None, need_weights=False, convert_mask=True, ): """Implements the multihead softmax attention. Arguments --------- qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None attn_mask: An implementation of BaseMask that encodes where each query can attend to key_padding_mask: An implementation of BaseMask that encodes how many queries each sequence in the batch consists of """ assert not need_weights assert attn_mask is None assert qkv.dtype == torch.float16 assert qkv.is_cuda if cu_seqlens is None: batch_size = qkv.shape[0] seqlen = qkv.shape[1] # Convert mask to take a subset seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256 assert seqlen_rounded // 16 <= self.layout.shape[0] and ( seqlen_rounded // 256 <= self.layout.shape[1] ) blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256] if key_padding_mask is None: qkv = rearrange(qkv, "b s ...
-> (b s) ...") max_s = seqlen cu_seqlens = torch.arange( 0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device ) output = flash_blocksparse_attn_func( qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0, max_s, softmax_scale=self.softmax_temp, causal=causal, ) output = rearrange(output, "(b s) ... -> b s ...", b=batch_size) else: key_padding_mask_bool = key_padding_mask.bool_matrix nheads = qkv.shape[-2] x = rearrange(qkv, "b s three h d -> b s (three h d)") x_unpad, indices, cu_seqlens, max_s, _ = unpad_input(x, key_padding_mask_bool) x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) output_unpad = flash_blocksparse_attn_func( x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0, max_s, softmax_scale=self.softmax_temp, causal=causal, ) output = rearrange( pad_input( rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen ), "b s (h d) -> b s h d", h=nheads, ) else: assert max_s is not None seqlen = max_s # Convert mask to take a subset seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256 assert seqlen_rounded // 16 <= self.layout.shape[0] and ( seqlen_rounded // 256 <= self.layout.shape[1] ) blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256] if convert_mask: output = flash_blocksparse_attn_func( qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0, max_s, softmax_scale=self.softmax_temp, causal=causal, ) else: output = flash_blocksparse_attn_func( qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0, max_s, softmax_scale=self.softmax_temp, causal=causal, convert_mask=False, ) return output, None class FlashBlocksparseMHA(nn.Module): def __init__( self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True, attention_dropout=0.0, causal=False, max_seq_length=2048, device=None, dtype=None, **kwargs, ) -> None: assert batch_first factory_kwargs = {"device": device, "dtype":
dtype} super().__init__() self.embed_dim = embed_dim self.causal = causal self.num_heads = num_heads assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" self.head_dim = self.embed_dim // num_heads assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64" self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) self.inner_attn = FlashBlocksparseAttention( sparsity_config, attention_dropout=attention_dropout, max_seq_length=max_seq_length, **factory_kwargs, ) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) def forward( self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False ): qkv = self.Wqkv(x) qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads) context, attn_weights = self.inner_attn( qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal ) return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights ================================================ FILE: flash_attn/flash_blocksparse_attn_interface.py ================================================ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py import flash_attn_cuda import torch import torch.nn as nn def convert_blockmask(blockmask, causal): """Convert from the 0-1 format to the format used by the CUDA code. 0 means the block is skipped; nonzero means the block is not skipped. Argument: blockmask: (row, col): a 0-1 tensor Return: blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row indices of the nonzero blocks, padded with -1 to reach length @row. The indices are multiplied by 4, with the smallest bit used to encode whether it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is the last nonzero in its row.
""" assert not causal # TD [2022-05-13]: The indexing and sorting is very tricky nrow, ncol = blockmask.shape # Sort does not support bool on CUDA blockmask = blockmask.to(dtype=torch.uint8) nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True) nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0) last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1] last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[ torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row ] first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0] first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[ torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row ] nonzero_idx = nonzero_sorted_rowidx * 4 nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2 nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1 nonzero_idx[nonzero_val == 0] = -1 return nonzero_idx.T.contiguous().to(dtype=torch.int32) def _flash_blocksparse_attn_forward( qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax ): context, softmax_lse, *rest = flash_attn_cuda.fwd_block( qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None ) # if context.isnan().any() or softmax_lse.isnan().any(): # breakpoint() S_dmask = rest[0] if return_softmax else None return context, softmax_lse, S_dmask def _flash_blocksparse_attn_backward( dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, ): dqkv, dp, softmax_d = flash_attn_cuda.bwd_block( dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask, dropout_p, softmax_scale, max_s, causal, None, ) # if dqkv.isnan().any() or softmax_d.isnan().any(): # breakpoint() return dqkv class FlashBlocksparseAttnFun(torch.autograd.Function): @staticmethod def forward(ctx, qkv, cu_seqlens, 
blockmask, dropout_p, max_s, softmax_scale, causal): # Save rng_state because the backward pass will regenerate the dropout mask rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward( qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=False, ) ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state) ctx.dropout_p = dropout_p ctx.max_s = max_s ctx.softmax_scale = softmax_scale ctx.causal = causal return context @staticmethod def backward(ctx, dout): qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors if rng_state is not None: cur_rng_state = torch.cuda.get_rng_state() torch.cuda.set_rng_state(rng_state) # S_dmask is None, temporarily use another tensor just to get it running dqkv = _flash_blocksparse_attn_backward( dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p, ctx.max_s, ctx.softmax_scale, ctx.causal, ) if rng_state is not None: torch.cuda.set_rng_state(cur_rng_state) return dqkv, None, None, None, None, None, None # We duplicate code to return both the output and the softmax for testing # Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class FlashBlocksparseAttnFunWithS(torch.autograd.Function): @staticmethod def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal): # Save rng_state because the backward pass will regenerate the dropout mask rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward( qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=True, ) ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state) ctx.dropout_p = dropout_p ctx.max_s = max_s ctx.softmax_scale = softmax_scale ctx.causal = causal return context, S_dmask, softmax_lse @staticmethod def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored): qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors if rng_state is not None: cur_rng_state = torch.cuda.get_rng_state() torch.cuda.set_rng_state(rng_state) dqkv = _flash_blocksparse_attn_backward( dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p, ctx.max_s, ctx.softmax_scale, ctx.causal, ) if rng_state is not None: torch.cuda.set_rng_state(cur_rng_state) return dqkv, None, None, None, None, None, None def flash_blocksparse_attn_func( qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None, causal=False, return_attn_probs=False, convert_mask=True, ): """dropout_p should be set to 0.0 during evaluation""" func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS if convert_mask: blockmask = convert_blockmask(blockmask, causal=causal) return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal) ================================================ FILE: flash_attn/layers/__init__.py ================================================ ================================================ FILE:
flash_attn/layers/patch_embed.py ================================================ # We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py # But we use nn.Linear instead of Conv2d and it's about 8x faster. from functools import partial import torch.nn as nn from einops import rearrange from torch import _assert from torch.nn.modules.utils import _pair try: from flash_attn.ops.fused_dense import FusedDense except ImportError: FusedDense = None class PatchEmbed(nn.Module): """2D Image to Patch Embedding""" def __init__( self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True, fused_bias_fc=False, ): super().__init__() img_size = _pair(img_size) patch_size = _pair(patch_size) self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten if fused_bias_fc and FusedDense is None: raise ImportError("fused_dense is not installed") linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): _, _, H, W = x.shape _assert( H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).", ) _assert( W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).", ) x = self.proj( rearrange( x, "b c (h p1) (w p2) -> b h w (c p1 p2)", p1=self.patch_size[0], p2=self.patch_size[1], ) ) if self.flatten: x = rearrange(x, "b h w c -> b (h w) c") x = self.norm(x) return x ================================================ FILE: flash_attn/layers/rotary.py ================================================ # Copyright (c) 2025, Tri Dao import math from functools import partial from typing 
import Optional, Tuple, Union import torch from torch import Tensor from einops import rearrange, repeat from flash_attn.ops.triton.rotary import apply_rotary def rotate_half(x, interleaved=False): if not interleaved: x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) def apply_rotary_emb_torch(x, cos, sin, interleaved=False): """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) """ ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") return torch.cat( [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], dim=-1, ) class ApplyRotaryEmb(torch.autograd.Function): @staticmethod def forward( ctx, x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, Tensor] = 0, cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None, ): out = apply_rotary( x, cos, sin, seqlen_offsets=seqlen_offsets, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=interleaved, inplace=inplace, ) if isinstance(seqlen_offsets, int): ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward ctx.seqlen_offsets = seqlen_offsets else: ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) ctx.seqlen_offsets = None ctx.interleaved = interleaved ctx.inplace = inplace ctx.max_seqlen = max_seqlen return out if not inplace else x @staticmethod def backward(ctx, do): seqlen_offsets = ctx.seqlen_offsets if seqlen_offsets is None: cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors else: cos, sin, cu_seqlens = ctx.saved_tensors dx = apply_rotary( do, cos, sin, seqlen_offsets=seqlen_offsets, cu_seqlens=cu_seqlens, 
max_seqlen=ctx.max_seqlen, interleaved=ctx.interleaved, inplace=ctx.inplace, conjugate=True, ) return dx, None, None, None, None, None, None, None def apply_rotary_emb( x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, Tensor] = 0, cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None, ): """ Arguments: x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim) cos, sin: (seqlen_rotary, rotary_dim / 2) interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). inplace: if True, apply rotary embedding in-place. seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. Most commonly used in inference when we have KV cache. cu_seqlens: (batch + 1,) or None max_seqlen: int Return: out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim) rotary_dim must be <= headdim Apply rotary embedding to the first rotary_dim of x. 
""" return ApplyRotaryEmb.apply( x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen ) # For backward compatibility apply_rotary_emb_func = apply_rotary_emb def _apply_rotary_emb_qkv( qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False, inplace=False, conjugate=False, seqlen_offsets: Union[int, Tensor] = 0, num_heads_q: Optional[int] = None, ): apply_rotary_fn = partial( apply_rotary, interleaved=interleaved, inplace=inplace, conjugate=conjugate, seqlen_offsets=seqlen_offsets ) if cos_k is None and sin_k is None and qkv.is_contiguous(): # Call 1 kernel instead of 2 kernels # We need qkv to be contiguous so that when we reshape to combine (3, nheads) # dimensions, we get the same tensor if qkv.dim() == 5: batch, seqlen, three, nheads, headdim = qkv.shape assert three == 3 # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) qk = apply_rotary_fn(qk, cos, sin) else: assert qkv.dim() == 4 assert num_heads_q is not None num_heads_k = (qkv.shape[2] - num_heads_q) // 2 assert qkv.shape[2] == num_heads_q + 2 * num_heads_k qk = qkv[:, :, :num_heads_q + num_heads_k] qk = apply_rotary_fn(qk, cos, sin) if not inplace: if qkv.dim() == 5: qkv = torch.cat([rearrange(qk, "b s (t h) d -> b s t h d", t=2), qkv[:, :, 2:]], dim=2) else: qkv = torch.cat([qk, qkv[:, :, num_heads_q + num_heads_k :]], dim=2) else: cos_k = cos if cos_k is None else cos_k sin_k = sin if sin_k is None else sin_k if qkv.dim() == 5: batch, seqlen, three, nheads, headdim = qkv.shape assert three == 3 q, k = qkv[:, :, 0], qkv[:, :, 1] else: assert qkv.dim() == 4 assert num_heads_q is not None num_heads_k = (qkv.shape[2] - num_heads_q) // 2 assert qkv.shape[2] == num_heads_q + 2 * num_heads_k q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k] q = apply_rotary_fn(q, cos, sin) k = apply_rotary_fn(k, cos_k, sin_k) if not inplace: if qkv.dim() == 5: qkv = torch.stack([q, k, qkv[:, :, 2]], dim=2) 
else: qkv = torch.cat([q, k, qkv[:, :, num_heads_q + num_heads_k:]], dim=2) return qkv class ApplyRotaryEmbQKV_(torch.autograd.Function): @staticmethod def forward( ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0, num_heads_q: Optional[int] = None, ): # apply_rotary_emb_qkv_inplace( qkv = _apply_rotary_emb_qkv( qkv, cos, sin, cos_k, sin_k, interleaved=interleaved, inplace=True, seqlen_offsets=seqlen_offsets, num_heads_q=num_heads_q, ) if isinstance(seqlen_offsets, int): ctx.save_for_backward(cos, sin, cos_k, sin_k) ctx.seqlen_offsets = seqlen_offsets else: ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) ctx.seqlen_offsets = None ctx.interleaved = interleaved ctx.num_heads_q = num_heads_q return qkv @staticmethod def backward(ctx, dqkv): seqlen_offsets = ctx.seqlen_offsets if seqlen_offsets is None: cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors else: cos, sin, cos_k, sin_k = ctx.saved_tensors dqkv = _apply_rotary_emb_qkv( dqkv, cos, sin, cos_k, sin_k, interleaved=ctx.interleaved, inplace=True, seqlen_offsets=seqlen_offsets, num_heads_q=ctx.num_heads_q, conjugate=True, ) return dqkv, None, None, None, None, None, None, None def apply_rotary_emb_qkv_( qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0, num_heads_q: Optional[int] = None, ): """ Arguments: qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim). If qkv has shape (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA), then num_heads_q must be provided. cos, sin: (seqlen, rotary_dim / 2) cos_k, sin_k: (seqlen, rotary_dim / 2), optional interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. 
Most commonly used in inference when we have KV cache. Return: qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) rotary_dim must be <= headdim Apply rotary embedding *inplace* to the first rotary_dim of Q and K. """ return ApplyRotaryEmbQKV_.apply( qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, num_heads_q ) class ApplyRotaryEmbKV_(torch.autograd.Function): @staticmethod def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): batch, seqlen, two, nheads, headdim = kv.shape assert two == 2 k = kv[:, :, 0] apply_rotary( k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True ) if isinstance(seqlen_offsets, int): ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward ctx.seqlen_offsets = seqlen_offsets else: ctx.save_for_backward(cos, sin, seqlen_offsets) ctx.seqlen_offsets = None ctx.interleaved = interleaved return kv @staticmethod def backward(ctx, dkv): seqlen_offsets = ctx.seqlen_offsets if seqlen_offsets is None: cos, sin, seqlen_offsets = ctx.saved_tensors else: cos, sin = ctx.saved_tensors apply_rotary( dkv[:, :, 0], cos, sin, seqlen_offsets=seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True, ) return dkv, None, None, None, None def apply_rotary_emb_kv_( kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0, ): """ Arguments: kv: (batch_size, seqlen, 2, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by this amount. Most commonly used in inference when we have KV cache. Return: kv: (batch_size, seqlen, 2, nheads, headdim) rotary_dim must be <= headdim Apply rotary embedding *inplace* to the first rotary_dim of K.
""" return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets) class RotaryEmbedding(torch.nn.Module): """ The rotary position embeddings from RoFormer_ (Su et. al). A crucial insight from the method is that the query and keys are transformed by rotation matrices which depend on the relative positions. Other implementations are available in the Rotary Transformer repo_ and in GPT-NeoX_, GPT-NeoX was an inspiration .. _RoFormer: https://arxiv.org/abs/2104.09864 .. _repo: https://github.com/ZhuiyiTechnology/roformer .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py """ def __init__( self, dim: int, base=10000.0, interleaved=False, scale_base=None, device=None, ): """ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). 
""" super().__init__() self.dim = dim self.base = float(base) # Generate and save the inverse frequency buffer (non trainable) inv_freq = self._compute_inv_freq(device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.interleaved = interleaved self.scale_base = scale_base scale = ( (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) if scale_base is not None else None ) self.register_buffer("scale", scale, persistent=False) self._seq_len_cached = 0 self._cos_cached = None self._sin_cached = None self._cos_k_cached = None self._sin_k_cached = None def _compute_inv_freq(self, device=None): return 1.0 / ( self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) ) def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): # Reset the tables if the sequence length has changed, # if we're on a new device (possibly due to tracing for instance), # or if we're switching from inference mode to training if ( seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype or (self.training and self._cos_cached.is_inference()) ): self._seq_len_cached = seqlen # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 # And the output of arange can be quite large, so bf16 would lose a lot of precision. t = torch.arange(seqlen, device=device, dtype=torch.float32) # We want fp32 here as well since inv_freq will be multiplied with t, and the output # will be large. Having it in bf16 will lose a lot of precision and cause the # cos & sin output to change significantly. 
# We want to recompute self.inv_freq if it was not loaded in fp32 if self.inv_freq.dtype != torch.float32: inv_freq = self._compute_inv_freq(device=device) else: inv_freq = self.inv_freq # Don't do einsum, it converts fp32 to bf16 under AMP # freqs = torch.einsum("i,j->ij", t, self.inv_freq) freqs = torch.outer(t, inv_freq) if self.scale is None: self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) else: power = ( torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 ) / self.scale_base scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") # We want the multiplication by scale to happen in fp32 self._cos_cached = (torch.cos(freqs) * scale).to(dtype) self._sin_cached = (torch.sin(freqs) * scale).to(dtype) self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) def forward( self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: Union[int, torch.Tensor] = 0, max_seqlen: Optional[int] = None, num_heads_q: Optional[int] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ qkv: (batch, seqlen, 3, nheads, headdim) or (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) if kv is None, else it's just q of shape (batch, seqlen, nheads, headdim). If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA), then num_heads_q must be provided. kv: (batch, seqlen, 2, nheads, headdim) seqlen_offset: (batch_size,) or int. Each sequence in qkv (and kv) is shifted by this amount. Most commonly used in inference when we have KV cache. If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one should pass in max_seqlen, which will update the cos / sin cache up to that length. Apply rotary embedding *inplace* to qkv and / or kv.
""" seqlen = qkv.shape[1] if max_seqlen is not None: self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) elif isinstance(seqlen_offset, int): self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) if kv is None: return apply_rotary_emb_qkv_( qkv, self._cos_cached, self._sin_cached, self._cos_k_cached if self.scale is not None else None, self._sin_k_cached if self.scale is not None else None, interleaved=self.interleaved, seqlen_offsets=seqlen_offset, num_heads_q=num_heads_q, ) else: q = qkv q = apply_rotary_emb_func( q, self._cos_cached, self._sin_cached, interleaved=self.interleaved, inplace=True, seqlen_offsets=seqlen_offset, ) kv = apply_rotary_emb_kv_( kv, self._cos_cached if self.scale is None else self._cos_k_cached, self._sin_cached if self.scale is None else self._sin_k_cached, interleaved=self.interleaved, seqlen_offsets=seqlen_offset, ) return q, kv ================================================ FILE: flash_attn/losses/__init__.py ================================================ ================================================ FILE: flash_attn/losses/cross_entropy.py ================================================ # Copyright (c) 2024, Tri Dao. import torch import torch.nn as nn from flash_attn.ops.triton.cross_entropy import cross_entropy_loss class CrossEntropyLoss(nn.Module): def __init__( self, ignore_index=-100, reduction="mean", label_smoothing=0.0, logit_scale=1.0, lse_square_scale=0.0, inplace_backward=False, process_group=None, return_z_loss=False, ): """ Arguments: ignore_index: int. If labels == ignore_index, the loss is set to 0.0. label_smoothing: float lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. This is also referred to as "z-loss". inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. This saves memory. 
            process_group: if not None, we're doing Tensor Parallel: each process is responsible for
                one part of the vocab. The loss will be aggregated across processes.
            return_z_loss: bool. If True, we return the component of the loss contributed by
                the lse_square_scale value. This value is only for logging and does not support
                backprop.
        """
        super().__init__()
        if reduction not in ["mean", "none", "sum"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.logit_scale = logit_scale
        self.lse_square_scale = lse_square_scale
        self.inplace_backward = inplace_backward
        self.process_group = process_group
        self.return_z_loss = return_z_loss

    def forward(self, input, target, precomputed_lse=None):
        """
        Arguments:
            input: (batch, vocab_size)
            target: (batch,)
        Returns:
            losses: (batch,) if reduction is 'none', else (1,), dtype float
            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
        """
        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
        loss, z_loss = cross_entropy_loss(
            input,
            target,
            precomputed_lse=precomputed_lse,
            label_smoothing=self.label_smoothing,
            logit_scale=self.logit_scale,
            lse_square_scale=self.lse_square_scale,
            ignore_index=self.ignore_index,
            inplace_backward=self.inplace_backward,
            process_group=self.process_group,
        )
        if self.reduction == "mean":
            loss = loss.sum() / (target != self.ignore_index).sum()
        elif self.reduction == "sum":
            loss = loss.sum()
        else:
            loss = loss

        if not self.return_z_loss:
            return loss

        if self.reduction == "mean":
            z_loss = z_loss.sum() / (target != self.ignore_index).sum()
        elif self.reduction == "sum":
            z_loss = z_loss.sum()
        else:
            z_loss = z_loss

        return loss, z_loss


================================================
FILE: flash_attn/models/__init__.py
================================================


================================================
FILE: flash_attn/models/baichuan.py
================================================
# Copyright (c) 2023, GGGGGGXY, Tri Dao.

import math
import json
import re

from pathlib import Path

from collections import OrderedDict

import torch
import torch.nn.functional as F

from einops import rearrange
from transformers import GPT2Config, AutoConfig, PretrainedConfig


def remap_state_dict_hf_baichuan(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^model.", "transformer.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.embed_tokens.",
            "transformer.embeddings.word_embeddings.",
            key,
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
        * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict[
            "transformer.embeddings.word_embeddings.weight"
        ]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since Baichuan shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict["lm_head.weight"] = F.pad( output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) ) # LayerNorm def key_mapping_ln(key): key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key) key = re.sub( r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key, ) key = re.sub( r"^transformer.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key, ) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) # MLP for l in range(config.n_layer): w1 = state_dict.pop(f"transformer.layers.{l}.mlp.gate_proj.weight") w3 = state_dict.pop(f"transformer.layers.{l}.mlp.up_proj.weight") # Our ordering is different state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat( [w3, w1], dim=0 ) def key_mapping_mlp(key): return re.sub( r"^transformer.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key, ) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # Attention def key_mapping_attn(key): key = re.sub( r"^transformer.layers.(\d+).self_attn.W_pack.", r"transformer.layers.\1.mixer.Wqkv.", key, ) key = re.sub( r"^transformer.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key, ) return key state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) for l in range(config.n_layer): # pop rotary_emb.inv_freq from state dict state_dict.pop(f"transformer.layers.{l}.self_attn.rotary_emb.inv_freq", None) return state_dict def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config: # HACK: the config doesn't have say whether it's rotary or alibi. # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi). # HACK: the config doesn't have say whether it uses norm head. # So we have to infer from the vocab size # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head). 
    use_rotary = baichuan_config.hidden_size < 5000
    return GPT2Config(
        vocab_size=baichuan_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=baichuan_config.hidden_size,
        n_layer=baichuan_config.num_hidden_layers,
        n_head=baichuan_config.num_attention_heads,
        n_inner=baichuan_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # baichuan doesn't have dropout, idk if it's because they only release the inference code
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=baichuan_config.rms_norm_eps,
        initializer_range=baichuan_config.initializer_range,
        bos_token_id=baichuan_config.bos_token_id,
        eos_token_id=baichuan_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        pad_token_id=baichuan_config.pad_token_id,  # Idk if this does anything
        rms_norm=True,
        rotary_emb_fraction=1.0 if use_rotary else 0.0,
        rotary_emb_interleaved=False,
        use_alibi=not use_rotary,
        use_flash_attn=not use_rotary,  # Alibi code path requires flash_attn
        tie_word_embeddings=False,
        norm_head=baichuan_config.vocab_size > 70000,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )


================================================
FILE: flash_attn/models/bert.py
================================================
# Copyright (c) 2022, Tri Dao.
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py

# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py

import logging
import re
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from typing import Any, Mapping

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig, PretrainedConfig
from transformers.models.bert.modeling_bert import (
    BaseModelOutputWithPoolingAndCrossAttentions,
    BertForPreTrainingOutput,
)

from flash_attn.bert_padding import (
    index_first_axis,
    index_first_axis_residual,
    pad_input,
    unpad_input,
)
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp
from flash_attn.utils.pretrained import state_dict_from_pretrained

try:
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDense = None

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
    layer_norm_fn = None


try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = None


logger = logging.getLogger(__name__)


def create_mixer_cls(config, cross_attn=False, return_residual=False):
    use_flash_attn = getattr(config, "use_flash_attn", False)
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    rotary_kwargs = {}
    if config.position_embedding_type == "rotary":
        rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
        rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
        rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
        rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
    mixer_cls = partial(
        MHA,
        num_heads=config.num_attention_heads,
        cross_attn=cross_attn,
        dropout=config.attention_probs_dropout_prob,
        causal=False,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=use_flash_attn,
        return_residual=return_residual,
        **rotary_kwargs,
    )
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, return_residual=False):
    inner_dim = config.intermediate_size
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
            "fused_mlp only supports approximate gelu"
        )
    if not fused_mlp:
        approximate = (
            "tanh"
            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
            else "none"
        )
        mlp_cls = partial(
            Mlp,
            hidden_features=inner_dim,
            activation=partial(F.gelu, approximate=approximate),
            return_residual=return_residual,
        )
    else:
        if FusedMLP is None:
            raise ImportError("fused_dense is not installed")
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        mlp_cls = partial(
            FusedMLP,
            hidden_features=inner_dim,
            checkpoint_lvl=mlp_checkpoint_lvl,
            return_residual=return_residual,
        )
    return mlp_cls


def create_block(config, layer_idx=None):
    last_layer_subset = getattr(config, "last_layer_subset", False)
    cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
    # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
    # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
    # one layer) so we just choose not to return residual in this case.
    return_residual = not cross_attn
    mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
    mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
    block = Block(
        config.hidden_size,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        prenorm=False,
        resid_dropout1=config.hidden_dropout_prob,
        resid_dropout2=config.hidden_dropout_prob,
        fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
        return_residual=return_residual,
    )
    return block


# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
def _init_weights(module, initializer_range=0.02):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.padding_idx is not None:
            nn.init.zeros_(module.weight[module.padding_idx])


class BertEncoder(nn.Module):
    def __init__(self, config: BertConfig):
        super().__init__()
        self.use_flash_attn = getattr(config, "use_flash_attn", False)
        self.layers = nn.ModuleList(
            [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
        """If subset_mask is not None, we only want output for the subset of the sequence.
        This means that we only compute the last layer output for these tokens.
        subset_mask: (batch, seqlen), dtype=torch.bool
        """
        if key_padding_mask is None or not self.use_flash_attn:
            mixer_kwargs = (
                {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
            )
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
            if subset_mask is not None:
                hidden_states = hidden_states[subset_mask]
        else:
            batch, seqlen = hidden_states.shape[:2]
            hidden_states, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input(
                hidden_states, key_padding_mask
            )
            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
            if subset_mask is None:
                for layer in self.layers:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
            else:
                for layer in self.layers[:-1]:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                if key_padding_mask is not None:
                    subset_idx = torch.nonzero(
                        subset_mask[key_padding_mask], as_tuple=False
                    ).flatten()
                    subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
                    subset_cu_seqlens = F.pad(
                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
                    )
                else:
                    subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                    subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                    subset_cu_seqlens = F.pad(
                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
                    )
                hidden_states_subset, hidden_states = index_first_axis_residual(
                    hidden_states, subset_idx
                )
                # It's ok to set max_seqlen_q to be much larger
                mixer_kwargs = {
                    "x_kv": hidden_states,
                    "cu_seqlens": subset_cu_seqlens,
                    "max_seqlen": max_seqlen_in_batch,
                    "cu_seqlens_k": cu_seqlens,
                    "max_seqlen_k": max_seqlen_in_batch,
                }
                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
        return hidden_states


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0] if pool else hidden_states
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        approximate = (
            "tanh"
            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
            else "none"
        )
        self.transform_act_fn = nn.GELU(approximate=approximate)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        if not self.fused_dropout_add_ln:
            hidden_states = self.layer_norm(hidden_states)
        else:
            hidden_states = layer_norm_fn(
                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
            )
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense

        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
        """
        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            model_name: either:
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `model.chkpt` a TensorFlow checkpoint
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        load_return = model.load_state_dict(
            remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
        )
        logger.info(load_return)
        return model


class BertModel(BertPreTrainedModel):
    def __init__(self, config: BertConfig, add_pooling_layer=True):
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        if config.vocab_size % self.pad_vocab_size_multiple != 0:
            config.vocab_size += self.pad_vocab_size_multiple - (
                config.vocab_size % self.pad_vocab_size_multiple
            )
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")
        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]

        self.embeddings = BertEmbeddings(
            config.hidden_size,
            config.vocab_size,
            config.max_position_embeddings,
            config.type_vocab_size,
            padding_idx=config.pad_token_id,
        )
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        masked_tokens_mask=None,
    ):
        """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
        we only want the output for the masked tokens. This means that we only compute the last
        layer output for these tokens.
        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
        """
        hidden_states = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids
        )
        # TD [2022-12:18]: Don't need to force residual in fp32
        # BERT puts embedding LayerNorm before embedding dropout.
        if not self.fused_dropout_add_ln:
            hidden_states = self.emb_ln(hidden_states)
        else:
            hidden_states = layer_norm_fn(
                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
            )
        hidden_states = self.emb_drop(hidden_states)

        if masked_tokens_mask is not None:
            batch_size, seqlen = input_ids.shape[:2]
            # We also need the first column for the CLS token
            first_col_mask = torch.zeros(
                batch_size, seqlen, dtype=torch.bool, device=input_ids.device
            )
            first_col_mask[:, 0] = True
            subset_mask = masked_tokens_mask | first_col_mask
        else:
            subset_mask = None

        sequence_output = self.encoder(
            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
        )

        if masked_tokens_mask is None:
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        else:
            # TD [2022-03-01]: the indexing here is very tricky.
            if attention_mask is not None:
                subset_idx = subset_mask[attention_mask]
                pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
            else:
                pool_input = sequence_output[first_col_mask[subset_mask]]
                sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
            pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )


class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config: BertConfig):
        super().__init__(config)
        # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
        # (around 15%) to the classifier heads.
        self.dense_seq_output = getattr(config, "dense_seq_output", False)
        # If last_layer_subset, we only need to compute the last layer for a subset of tokens
        # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
        self.last_layer_subset = getattr(config, "last_layer_subset", False)
        if self.last_layer_subset:
            assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
        use_xentropy = getattr(config, "use_xentropy", False)
        if use_xentropy and CrossEntropyLoss is None:
            raise ImportError("xentropy_cuda is not installed")
        loss_cls = (
            nn.CrossEntropyLoss
            if not use_xentropy
            else partial(CrossEntropyLoss, inplace_backward=True)
        )

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        self.mlm_loss = loss_cls(ignore_index=0)
        self.nsp_loss = loss_cls(ignore_index=-1)

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        next_sentence_label=None,
    ):
        """
        If labels are provided, they must be 0 for masked out tokens (as specified in the attention
        mask).
        Outputs:
            if `labels` and `next_sentence_label` are not `None`:
                Outputs the total_loss which is the sum of the masked language modeling loss and the next
                sentence classification loss.
            if `labels` or `next_sentence_label` is `None`:
                Outputs a tuple comprising
                - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
                - the next sentence classification logits of shape [batch_size, 2].
""" masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None outputs = self.bert( input_ids, position_ids=position_ids, token_type_ids=token_type_ids, attention_mask=attention_mask.bool() if attention_mask is not None else None, masked_tokens_mask=masked_tokens_mask, ) sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output if self.dense_seq_output and labels is not None: masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() if not self.last_layer_subset: sequence_output = index_first_axis( rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx ) prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) total_loss = None if labels is not None and next_sentence_label is not None: if ( self.dense_seq_output and labels is not None ): # prediction_scores are already flattened masked_lm_loss = self.mlm_loss( prediction_scores, labels.flatten()[masked_token_idx] ) else: masked_lm_loss = self.mlm_loss( rearrange(prediction_scores, "... v -> (...) v"), rearrange(labels, "... -> (...)"), ) next_sentence_loss = self.nsp_loss( rearrange(seq_relationship_score, "... t -> (...) t"), rearrange(next_sentence_label, "... -> (...)"), ) total_loss = masked_lm_loss.float() + next_sentence_loss.float() return BertForPreTrainingOutput( loss=total_loss, prediction_logits=prediction_scores, seq_relationship_logits=seq_relationship_score, ) def remap_state_dict(state_dict, config: PretrainedConfig): """ Map the state_dict of a Huggingface BERT model to be flash_attn compatible. 
""" # LayerNorm def key_mapping_ln_gamma_beta(key): key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key) key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key) return key state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()) # Layers def key_mapping_layers(key): return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key) state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) # LayerNorm def key_mapping_ln(key): key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key) key = re.sub( r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)", r"bert.encoder.layers.\1.norm1.\2", key, ) key = re.sub( r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)", r"bert.encoder.layers.\1.norm2.\2", key, ) key = re.sub( r"^cls.predictions.transform.LayerNorm.(weight|bias)", r"cls.predictions.transform.layer_norm.\1", key, ) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) # MLP def key_mapping_mlp(key): key = re.sub( r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)", r"bert.encoder.layers.\1.mlp.fc1.\2", key, ) key = re.sub( r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)", r"bert.encoder.layers.\1.mlp.fc2.\2", key, ) return key state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # Attention last_layer_subset = getattr(config, "last_layer_subset", False) for d in range(config.num_hidden_layers): Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight") Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight") Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight") bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias") bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias") bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias") if not (last_layer_subset and d == 
config.num_hidden_layers - 1): state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat( [Wq, Wk, Wv], dim=0 ) state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0) else: state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0) state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0) def key_mapping_attn(key): return re.sub( r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)", r"bert.encoder.layers.\1.mixer.out_proj.\2", key, ) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) def key_mapping_decoder_bias(key): return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key) state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items()) # Word embedding pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) if pad_vocab_size_multiple > 1: word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"] state_dict["bert.embeddings.word_embeddings.weight"] = F.pad( word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0]) ) decoder_weight = state_dict["cls.predictions.decoder.weight"] state_dict["cls.predictions.decoder.weight"] = F.pad( decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0]) ) # If the vocab was padded, we want to set the decoder bias for those padded indices to be # strongly negative (i.e. the decoder shouldn't predict those indices). # TD [2022-05-09]: I don't think it affects the MLPerf training. 
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        state_dict["cls.predictions.decoder.bias"] = F.pad(
            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
        )

    return state_dict


def inv_remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BERT compatible.

    This function is meant to be the inverse of remap_state_dict.
    """
    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        decoder_weight = state_dict["cls.predictions.decoder.weight"]
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        # unpad embeddings
        state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
            : config.orig_vocab_size, :
        ]
        state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
        state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]

    for d in range(config.num_hidden_layers):
        last_layer_subset = getattr(config, "last_layer_subset", False)
        if not last_layer_subset or d != (config.num_hidden_layers - 1):
            Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
            Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
                : Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
                Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
                2 * Wqkv_weights.shape[0] // 3 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
                : Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
            ]
state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[ 2 * Wqkv_biases.shape[0] // 3 : ] else: Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight") Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight") Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias") Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias") state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[ : Wkv_weights.shape[0] // 2, : ] state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[ Wkv_weights.shape[0] // 2 :, : ] state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[ : Wkv_biases.shape[0] // 2 ] state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[ Wkv_biases.shape[0] // 2 : ] def inv_key_mapping_ln(key): key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key) key = re.sub( r"bert.encoder.layers.(\d+).norm1.(weight|bias)", r"bert.encoder.layers.\1.attention.output.LayerNorm.\2", key, ) key = re.sub( r"bert.encoder.layers.(\d+).norm2.(weight|bias)", r"bert.encoder.layers.\1.output.LayerNorm.\2", key, ) key = re.sub( r"cls.predictions.transform.layer_norm.(weight|bias)", r"cls.predictions.transform.LayerNorm.\1", key, ) return key def inv_key_mapping_ln_gamma_beta(key): key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key) key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key) return key def inv_key_mapping_layers(key): return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key) def inv_key_mapping_mlp(key): key = re.sub( r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)", r"bert.encoder.layer.\1.intermediate.dense.\2", key, ) key = re.sub( r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)", r"bert.encoder.layer.\1.output.dense.\2", key, ) 
return key def inv_key_mapping_attn(key): return re.sub( r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)", r"bert.encoder.layer.\1.attention.output.dense.\2", key, ) def inv_key_mapping_decoder_bias(key): return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key) state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items()) state_dict = OrderedDict( (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items() ) state_dict = OrderedDict( (inv_key_mapping_layers(key), value) for key, value in state_dict.items() ) state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items()) state_dict = OrderedDict( (inv_key_mapping_attn(key), value) for key, value in state_dict.items() ) state_dict = OrderedDict( (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items() ) return state_dict ================================================ FILE: flash_attn/models/bigcode.py ================================================ import math import re from collections import OrderedDict import torch import torch.nn.functional as F from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig): """ Map the state_dict of a Huggingface BigCode model to be flash_attn compatible. """ # Word embedding and position embedding def key_mapping_pos_emb(key): return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key) state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) word_embeddings = state_dict.pop("transformer.wte.weight") # It's possible that vocab_size is padded to be a multiple of 8, for example. 
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) ) state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] # LayerNorm def key_mapping_ln(key): key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) key = re.sub( r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key, ) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) def key_mapping_mlp(key): key = re.sub( r"^transformer.h.(\d+).mlp.c_fc.weight", r"transformer.layers.\1.mlp.fc1.weight", key, ) key = re.sub( r"^transformer.h.(\d+).mlp.c_proj.weight", r"transformer.layers.\1.mlp.fc2.weight", key, ) key = re.sub( r"^transformer.h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key, ) key = re.sub( r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key, ) return key state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # TODO: add support for multi-head attention assert config.multi_query, "Only multi-query attention is supported" # Attention for d in range(config.num_hidden_layers): embed_dim = config.n_embd head_dim = embed_dim // config.n_head c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight") # with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim) # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112 # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183 # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, 
hidden_dim) q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0) # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim) k = torch.tile(k, (config.n_head, 1)) v = torch.tile(v, (config.n_head, 1)) state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0) # same deal with the bias c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias") # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim) q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0) # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim) k = torch.tile(k, (config.n_head,)) v = torch.tile(v, (config.n_head,)) state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0) def key_mapping_attn(key): key = re.sub( r"^transformer.h.(\d+).attn.c_proj.weight", r"transformer.layers.\1.mixer.out_proj.weight", key, ) key = re.sub( r"^transformer.h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key, ) return key state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) return state_dict def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig): """ Map the state_dict of a flash_attn model to be Huggingface BigCode compatible. This function is meant to be the inverse of remap_state_dict_hf_bigcode. 
""" # Word embedding and position embeddings def inv_key_mapping_pos_emb(key): return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key) state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items()) word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") word_embeddings = word_embeddings[: config.vocab_size] state_dict["transformer.wte.weight"] = word_embeddings state_dict["lm_head.weight"] = word_embeddings # LayerNorm def inv_key_mapping_ln(key): key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) key = re.sub( r"^transformer.layers.(\d+).norm(1|2).(weight|bias)", r"transformer.h.\1.ln_\2.\3", key, ) return key state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items()) # MLPs def inv_key_mapping_mlp(key): key = re.sub( r"^transformer.layers.(\d+).mlp.fc1.weight", r"transformer.h.\1.mlp.c_fc.weight", key, ) key = re.sub( r"^transformer.layers.(\d+).mlp.fc2.weight", r"transformer.h.\1.mlp.c_proj.weight", key, ) key = re.sub( r"^transformer.layers.(\d+).mlp.fc1.bias", r"transformer.h.\1.mlp.c_fc.bias", key, ) key = re.sub( r"^transformer.layers.(\d+).mlp.fc2.bias", r"transformer.h.\1.mlp.c_proj.bias", key, ) return key state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items()) # Attention for d in range(config.num_hidden_layers): embed_dim = config.n_embd head_dim = embed_dim // config.n_head Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight") q, k, v = torch.split( Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0 ) c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0) state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight # Same deal with the bias Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias") q, k, v = torch.split( Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0 )
c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0) state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias def inv_key_mapping_attn(key): key = re.sub( r"^transformer.layers.(\d+).mixer.out_proj.weight", r"transformer.h.\1.attn.c_proj.weight", key, ) key = re.sub( r"^transformer.layers.(\d+).mixer.out_proj.bias", r"transformer.h.\1.attn.c_proj.bias", key, ) return key state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items()) return state_dict def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config: return GPT2Config( activation_function=bigcode_config.activation_function, attn_pdrop=bigcode_config.attn_pdrop, bos_token_id=bigcode_config.bos_token_id, embd_pdrop=bigcode_config.embd_pdrop, eos_token_id=bigcode_config.eos_token_id, initializer_range=bigcode_config.initializer_range, layer_norm_epsilon=bigcode_config.layer_norm_epsilon, max_batch_size=bigcode_config.max_batch_size, max_sequence_length=bigcode_config.max_sequence_length, model_type=bigcode_config.model_type, multi_query=bigcode_config.multi_query, n_embd=bigcode_config.n_embd, n_head=bigcode_config.n_head, n_inner=bigcode_config.n_inner, n_layer=bigcode_config.n_layer, n_positions=bigcode_config.n_positions, resid_pdrop=bigcode_config.resid_pdrop, scale_attn_weights=bigcode_config.scale_attn_weights, summary_activation=bigcode_config.summary_activation, summary_first_dropout=bigcode_config.summary_first_dropout, summary_proj_to_labels=bigcode_config.summary_proj_to_labels, summary_type=bigcode_config.summary_type, summary_use_proj=bigcode_config.summary_use_proj, use_cache=bigcode_config.use_cache, vocab_size=bigcode_config.vocab_size, ) ================================================ FILE: flash_attn/models/btlm.py ================================================ # Copyright (c) 2023, Tri Dao. 
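The BigCode converters that end just before the btlm.py header above expand the single shared K/V head of multi-query attention into `n_head` copies (`torch.tile`) when loading, and keep only the first `head_dim` rows of each copy when exporting. They also pad the vocabulary along dim 0, so undoing that padding means slicing rows. A minimal round-trip sketch of this row bookkeeping on toy shapes follows. It uses NumPy (`np.split`, `np.tile`, `np.pad`) in place of the torch calls, and the helper names `expand_mqa`, `collapse_mqa`, `pad_vocab` and the sizes are hypothetical, not part of flash_attn.

```python
import math

import numpy as np


def expand_mqa(c_attn, n_head, head_dim):
    """HF -> flash_attn direction: split Q/K/V rows, then tile the single K/V head n_head times."""
    embed_dim = n_head * head_dim
    q, k, v = np.split(c_attn, [embed_dim, embed_dim + head_dim], axis=0)
    k = np.tile(k, (n_head, 1))
    v = np.tile(v, (n_head, 1))
    return np.concatenate([q, k, v], axis=0)


def collapse_mqa(wqkv, n_head, head_dim):
    """flash_attn -> HF direction: keep only the first copy of the tiled K/V rows."""
    embed_dim = n_head * head_dim
    q, k, v = np.split(wqkv, [embed_dim, 2 * embed_dim], axis=0)
    return np.concatenate([q, k[:head_dim], v[:head_dim]], axis=0)


def pad_vocab(emb, multiple):
    """Pad the vocab (row) dimension up to a multiple, like F.pad(w, (0, 0, 0, extra))."""
    vocab = emb.shape[0]
    padded = math.ceil(vocab / multiple) * multiple
    return np.pad(emb, ((0, padded - vocab), (0, 0)))


# Toy sizes (hypothetical): 4 query heads of size 2, so embed_dim = hidden = 8.
n_head, head_dim = 4, 2
embed_dim = n_head * head_dim
rng = np.random.default_rng(0)

# Multi-query c_attn weight rows: all Q heads, then one K head, then one V head.
c_attn = rng.standard_normal((embed_dim + 2 * head_dim, embed_dim))
wqkv = expand_mqa(c_attn, n_head, head_dim)
print(wqkv.shape)  # (24, 8): 3 * n_head * head_dim rows
assert np.array_equal(collapse_mqa(wqkv, n_head, head_dim), c_attn)

# Vocab padding is added on dim 0, so it is undone by slicing rows, not columns.
emb = rng.standard_normal((10, 4))
padded = pad_vocab(emb, 8)
print(padded.shape)  # (16, 4)
assert np.array_equal(padded[:10], emb)
```

Because every K/V copy is identical after tiling, taking the first `head_dim` rows loses nothing, which is why the export direction is an exact inverse.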
import math import json import re from pathlib import Path from collections import OrderedDict import torch import torch.nn.functional as F from einops import rearrange from transformers import GPT2Config, AutoConfig, PretrainedConfig def remap_state_dict_hf_btlm(state_dict, config): # Word embedding and position embedding def key_mapping_pos_emb(key): return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key) if "transformer.wpe.weight" in state_dict: state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) word_embeddings = state_dict.pop("transformer.wte.weight") # It's possible that vocab_size is padded to be a multiple of 8, for example. pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) ) state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] # LayerNorm def key_mapping_ln(key): key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) key = re.sub(r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) # MLP for d in range(config.num_hidden_layers): W1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.weight") W3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.weight") state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = torch.cat([W1.t(), W3.t()], dim=0) b1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.bias") b3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.bias") state_dict[f"transformer.layers.{d}.mlp.fc1.bias"] = torch.cat([b1, b3], dim=0) W2 = state_dict.pop(f"transformer.h.{d}.mlp.c_proj.weight") state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t() def 
key_mapping_mlp(key): key = re.sub(r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key) return key state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # Attention for d in range(config.num_hidden_layers): Wqkv = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight") state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t() Wout = state_dict.pop(f"transformer.h.{d}.attn.c_proj.weight") state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t() state_dict.pop(f"transformer.relative_pe.slopes") # We don't store the Alibi slopes def key_mapping_attn(key): key = re.sub(r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key) key = re.sub( r"^transformer.h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key ) return key state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) return state_dict def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config: return GPT2Config( vocab_size=btlm_config.vocab_size, n_positions=0 if btlm_config.position_embedding_type == "alibi" else btlm_config.n_positions, n_embd=btlm_config.hidden_size, n_layer=btlm_config.num_hidden_layers, n_head=btlm_config.num_attention_heads, n_inner=btlm_config.n_inner, activation_function=btlm_config.activation_function, resid_pdrop=btlm_config.resid_pdrop, embd_pdrop=btlm_config.embd_pdrop, attn_pdrop=btlm_config.attn_pdrop, layer_norm_epsilon=btlm_config.layer_norm_epsilon, initializer_range=btlm_config.initializer_range, bos_token_id=btlm_config.bos_token_id, eos_token_id=btlm_config.eos_token_id, # These are new arguments not in the original GPT2Config use_alibi=btlm_config.position_embedding_type == "alibi", use_flash_attn=btlm_config.position_embedding_type == "alibi", # Alibi code path requires flash_attn mup_width_scale=btlm_config.mup_width_scale, mup_embeddings_multiplier=btlm_config.mup_embeddings_scale, 
mup_output_multiplier=btlm_config.mup_output_alpha, mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d, mlp_multiple_of=1, ) ================================================ FILE: flash_attn/models/falcon.py ================================================ # Copyright (c) 2023, Tri Dao. import math import re from collections import OrderedDict import torch import torch.nn.functional as F from einops import rearrange from transformers import FalconConfig, GPT2Config def remap_state_dict_hf_falcon(state_dict, config): def key_mapping_layers(key): return re.sub(r"^transformer.h.", "transformer.layers.", key) state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) # Word embedding def key_mapping_emb(key): return re.sub( r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key ) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") # It's possible that vocab_size is padded to be a multiple of 8, for example. pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) ) if getattr(config, "tie_word_embeddings"): state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] else: output_embeddings = state_dict.pop("lm_head.weight") # It's possible that vocab_size is padded to be a multiple of 8, for example. 
state_dict["lm_head.weight"] = F.pad( output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) ) output_embeddings_bias = state_dict.pop("lm_head.bias") state_dict["lm_head.bias"] = F.pad( output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0]) ) # LayerNorm def key_mapping_ln(key): key = re.sub( r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key ) key = re.sub( r"^transformer.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key, ) key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key) key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) # MLP def key_mapping_mlp(key): key = re.sub( r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key ) key = re.sub( r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key ) return key state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) def key_mapping_attn(key): key = re.sub( r"^transformer.layers.(\d+).self_attention.query_key_value.", r"transformer.layers.\1.mixer.Wqkv.", key, ) key = re.sub( r"^transformer.layers.(\d+).self_attention.dense.", r"transformer.layers.\1.mixer.out_proj.", key, ) return key state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) n_head = config.n_head n_head_kv = getattr(config, "n_head_kv", 1) headdim = config.hidden_size // n_head for l in range(config.n_layer): # The weights are stored in a different layout compared to our implementation Wqkv = rearrange( state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"), "(group ratio headdim) ... -> group ratio headdim ...", ratio=n_head // n_head_kv + 2, headdim=headdim, ) Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... 
-> (group ratio headdim) ...") Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...") Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...") state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) return state_dict def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config: # The 40b config uses "n_head_kv" instead of "num_kv_heads" n_head_kv = getattr( falcon_config, "n_head_kv", 1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head, ) # HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config. # So we have to infer it from the number of heads in the key/value block parallel_block_tied_norm = n_head_kv == 1 return GPT2Config( vocab_size=falcon_config.vocab_size, n_positions=0, # No absolute position embedding n_embd=falcon_config.hidden_size, n_layer=falcon_config.n_layer, n_head=falcon_config.n_head, n_inner=falcon_config.hidden_size * 4, activation_function="gelu", resid_pdrop=falcon_config.hidden_dropout, embd_pdrop=0.0, # There doesn't seem to be any embedding dropout attn_pdrop=falcon_config.attention_dropout, layer_norm_epsilon=falcon_config.layer_norm_epsilon, initializer_range=falcon_config.initializer_range, bos_token_id=falcon_config.bos_token_id, eos_token_id=falcon_config.eos_token_id, # These are new arguments not in the original GPT2Config parallel_block=falcon_config.parallel_attn, n_head_kv=n_head_kv, parallel_block_tied_norm=parallel_block_tied_norm, rotary_emb_fraction=1.0, rotary_emb_interleaved=False, tie_word_embeddings=True, qkv_proj_bias=falcon_config.bias, out_proj_bias=falcon_config.bias, mlp_fc1_bias=falcon_config.bias, mlp_fc2_bias=falcon_config.bias, lm_head_bias=False, ) ================================================ FILE: flash_attn/models/gpt.py ================================================ # Copyright (c) 2024, Tri Dao. 
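The Falcon converter (flash_attn/models/falcon.py, which ends just before the gpt.py header above) assumes the fused `query_key_value` weight is stored group by group: for each of the `n_head_kv` key/value groups, that group's `n_head // n_head_kv` query heads come first, then one key head, then one value head, each `headdim` rows tall. The sketch below reproduces that reordering into `[all Q; all K; all V]` with plain NumPy reshapes standing in for `einops.rearrange`. The function name and toy sizes are hypothetical; labeling each row by its index makes the permutation visible.

```python
import numpy as np


def falcon_qkv_to_flash(wqkv, n_head, n_head_kv, headdim):
    """Reorder Falcon's per-group [q..., k, v] rows into [all Q; all K; all V]."""
    ratio = n_head // n_head_kv + 2  # query heads per group, plus one K and one V head
    rest = wqkv.shape[1:]  # () for a bias, (hidden,) for a weight
    # Same as einops "(group ratio headdim) ... -> group ratio headdim ..."
    w = wqkv.reshape(n_head_kv, ratio, headdim, *rest)
    wq = w[:, :-2].reshape(-1, *rest)   # every query head, group by group
    wk = w[:, [-2]].reshape(-1, *rest)  # one key head per group
    wv = w[:, [-1]].reshape(-1, *rest)  # one value head per group
    return np.concatenate([wq, wk, wv], axis=0)


# Toy example (hypothetical sizes): 4 query heads, 2 KV groups, headdim 1.
# Falcon row order: q0 q1 k0 v0 | q2 q3 k1 v1, labeled 0..7.
w = np.arange(8).reshape(8, 1)
out = falcon_qkv_to_flash(w, n_head=4, n_head_kv=2, headdim=1)
print(out.ravel().tolist())  # [0, 1, 4, 5, 2, 6, 3, 7]
```

With a single KV group (multi-query, `n_head_kv == 1`) the layout is already `[Q; K; V]`, so the function returns the rows unchanged.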
import logging import math import re from collections import OrderedDict, namedtuple from collections.abc import Sequence from functools import partial from typing import Dict, List import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from transformers import GPT2Config from flash_attn.models.bigcode import remap_state_dict_hf_bigcode from flash_attn.models.falcon import remap_state_dict_hf_falcon from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox from flash_attn.models.gptj import remap_state_dict_hf_gptj from flash_attn.models.llama import remap_state_dict_hf_llama from flash_attn.models.opt import remap_state_dict_hf_opt from flash_attn.modules.block import Block, ParallelBlock from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings from flash_attn.modules.mha import MHA, ParallelMHA from flash_attn.modules.mlp import ( FusedMLP, GatedMlp, Mlp, ParallelFusedMLP, ParallelGatedMlp, ParallelMLP, ) from flash_attn.ops.activations import sqrelu_fwd from flash_attn.utils.distributed import ( all_gather, all_gather_raw, get_dim_for_local_rank, sync_shared_params, ) from flash_attn.utils.generation import GenerationMixin from flash_attn.utils.pretrained import state_dict_from_pretrained try: from flash_attn.ops.fused_dense import ColumnParallelLinear except ImportError: ColumnParallelLinear = None try: from flash_attn.ops.triton.mlp import FusedDenseSqreluDense except ImportError: FusedDenseSqreluDense = None try: from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm except ImportError: layer_norm_fn, RMSNorm = None, None logger = logging.getLogger(__name__) def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0 
softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power)) softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0) if config.scale_attn_by_inverse_layer_idx: assert layer_idx is not None softmax_scale /= float(layer_idx + 1) dwconv = getattr(config, "attn_dwconv", False) if dwconv: assert process_group is None, "TensorParallel MHA does not support dwconv yet" qkv_proj_bias = getattr(config, "qkv_proj_bias", True) out_proj_bias = getattr(config, "out_proj_bias", True) rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim) rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0) rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None) rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False) use_alibi = getattr(config, "use_alibi", False) window_size = getattr(config, "window_size", (-1, -1)) use_flash_attn = getattr(config, "use_flash_attn", False) fused_bias_fc = getattr(config, "fused_bias_fc", False) if not fused_bias_fc: assert process_group is None, "TensorParallel MHA requires fused_bias_fc" mha_cls = MHA if process_group is None else ParallelMHA serial_kwargs = ( {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {} ) parallel_kwargs = ( { "process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True), } if process_group is not None else {} ) num_heads_kv = getattr(config, "n_head_kv", None) mixer_cls = partial( mha_cls, num_heads=config.num_attention_heads, num_heads_kv=num_heads_kv, qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias, dropout=config.attn_pdrop, softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx, rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base, rotary_emb_scale_base=rotary_emb_scale_base, rotary_emb_interleaved=rotary_emb_interleaved, use_alibi=use_alibi, window_size=window_size, use_flash_attn=use_flash_attn, **serial_kwargs, **parallel_kwargs, 
**factory_kwargs, ) return mixer_cls def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True) mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True) fused_mlp = getattr(config, "fused_mlp", False) if fused_mlp: assert config.activation_function in [ "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu", ] fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False) if fused_dense_sqrelu_dense: assert config.activation_function == "sqrelu", ( "fused_dense_sqrelu_dense only " "supports approximate activation_function sqrelu" ) assert not (fused_dense_sqrelu_dense and fused_mlp) if not fused_mlp and not fused_dense_sqrelu_dense: assert config.activation_function in [ "gelu", "gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh", "relu", "sqrelu", "glu", "swiglu", "geglu", ] if config.activation_function in ["glu", "swiglu", "geglu"]: activation = ( F.sigmoid if config.activation_function == "glu" else (F.silu if config.activation_function == "swiglu" else F.gelu) ) mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp parallel_kwargs = ( { "process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True), } if process_group is not None else {} ) mlp_multiple_of = getattr(config, "mlp_multiple_of", 128) mlp_cls = partial( mlp_cls, hidden_features=config.n_inner, activation=activation, bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, multiple_of=mlp_multiple_of, **parallel_kwargs, **factory_kwargs, ) else: if config.activation_function == "relu": activation = partial(F.relu, inplace=True) elif config.activation_function == "sqrelu": activation = sqrelu_fwd else: approximate = ( "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"] else "none" ) activation = partial(F.gelu, approximate=approximate) mlp_cls = Mlp 
if process_group is None else ParallelMLP parallel_kwargs = ( { "process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True), } if process_group is not None else {} ) mlp_cls = partial( mlp_cls, hidden_features=config.n_inner, activation=activation, bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **parallel_kwargs, **factory_kwargs, ) else: mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0) # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer if isinstance(mlp_checkpoint_lvl, Sequence): assert layer_idx is not None mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] if fused_mlp: if FusedMLP is None: raise ImportError("fused_dense is not installed") activation = ( "gelu_approx" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"] else config.activation_function ) mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP parallel_kwargs = ( { "process_group": process_group, "sequence_parallel": getattr(config, "sequence_parallel", True), } if process_group is not None else {} ) mlp_cls = partial( mlp_cls, hidden_features=config.n_inner, activation=activation, checkpoint_lvl=mlp_checkpoint_lvl, bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **parallel_kwargs, **factory_kwargs, ) elif fused_dense_sqrelu_dense: if process_group is not None: assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense" assert FusedDenseSqreluDense is not None mlp_cls = partial( FusedDenseSqreluDense, hidden_features=config.n_inner, checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs, ) else: raise RuntimeError("MLP type not supported") return mlp_cls def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} sequence_parallel = getattr(config, "sequence_parallel", True) mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs) mlp_cls = 
create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs) use_rms_norm = getattr(config, "rms_norm", False) norm_cls = partial( nn.LayerNorm if not use_rms_norm else RMSNorm, eps=config.layer_norm_epsilon, **factory_kwargs, ) # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable residual_in_fp32 = getattr(config, "residual_in_fp32", False) resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop prenorm = getattr(config, "prenorm", True) parallel_block = getattr(config, "parallel_block", False) if not parallel_block: block = Block( config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop, fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), residual_in_fp32=residual_in_fp32, sequence_parallel=sequence_parallel and process_group is not None, mark_shared_params=process_group is not None, ) else: assert prenorm block = ParallelBlock( config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop, tied_norm=getattr(config, "parallel_block_tied_norm", False), fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), residual_in_fp32=residual_in_fp32, sequence_parallel=sequence_parallel and process_group is not None, mark_shared_params=process_group is not None, ) block.layer_idx = layer_idx return block class GPTPreTrainedModel(nn.Module): """An abstract class to handle weight initialization and a simple interface for downloading and loading pretrained models. """ def __init__(self, config, *inputs, **kwargs): super().__init__() if not isinstance(config, GPT2Config): raise ValueError( "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. 
" "To create a model from a Google pretrained model use " "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ ) ) self.config = config @classmethod def from_pretrained( cls, model_name, config, *args, strict=True, device=None, dtype=None, world_size=1, rank=0, **kwargs, ): """ Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. """ # Instantiate model. model = cls(config, *args, device=device, dtype=dtype, **kwargs) # Load state_dict in cpu because we already initialized the model in GPU, and we don't # want extra stuff taking up more GPU memory state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype) if model_name.startswith("gpt2"): state_dict = remap_state_dict_hf_gpt2(state_dict, config) elif model_name.startswith("facebook/opt"): state_dict = remap_state_dict_hf_opt(state_dict, config) elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith( "togethercomputer/GPT-JT-" ): state_dict = remap_state_dict_hf_gptj(state_dict, config) elif ( model_name.startswith("EleutherAI/gpt-neox-") or model_name.startswith("EleutherAI/pythia-") or model_name.startswith("togethercomputer/RedPajama-INCITE-") ): state_dict = remap_state_dict_hf_gpt_neox(state_dict, config) elif model_name.startswith("tiiuae/falcon-"): state_dict = remap_state_dict_hf_falcon(state_dict, config) elif model_name.startswith("meta-llama/Llama-"): state_dict = remap_state_dict_hf_llama(state_dict, config) elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"): state_dict = remap_state_dict_hf_bigcode(state_dict, config) else: raise NotImplementedError(f"Model {model_name} not supported") if world_size > 1: state_dict = shard_state_dict_tp(state_dict, config, world_size, rank) load_return = model.load_state_dict(state_dict, strict=strict) logger.info(load_return) return model # 
https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 def _init_weights( module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True ): mup_init_scale = math.sqrt(mup_width_scale) if isinstance(module, nn.Linear): nn.init.normal_(module.weight, std=initializer_range * mup_init_scale) optim_cfg = getattr(module.weight, "_optim", {}) optim_cfg.update({"lr_multiplier": mup_width_scale}) setattr(module.weight, "_optim", optim_cfg) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=initializer_range) if rescale_prenorm_residual: # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
        #   > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(
                    p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer)
                )


class GPTModel(GPTPreTrainedModel):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        super().__init__(config)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.process_group = process_group
        self.sequence_parallel = getattr(config, "sequence_parallel", True)
        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_approx",
            "gelu_pytorch_tanh",
            "relu",
            "sqrelu",
            "glu",
            "swiglu",
            "geglu",
        ]
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
        # These 2 options are for OPT-350m
        self.prenorm = getattr(config, "prenorm", True)
        use_rms_norm = getattr(config, "rms_norm", False)
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        # For GPT-J, GPT-NeoX
        self.parallel_block = getattr(config, "parallel_block", False)

        if process_group is None:
            self.embeddings = GPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                word_embed_proj_dim=word_embed_proj_dim,
                **factory_kwargs,
            )
        else:
            self.embeddings = ParallelGPT2Embeddings(
                config.hidden_size,
                vocab_size,
                config.max_position_embeddings,
                process_group=process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        # We change the order of dropout, residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
        # nn.Dropout probabilities are changed.
        # This is for performance reason: we can fuse dropout + add + layer_norm.
        self.layers = nn.ModuleList(
            [
                create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs)
                for i in range(config.num_hidden_layers)
            ]
        )
        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
        if rotary_emb_fraction > 0.0:  # Tie all the RotaryEmbedding modules to share the same cos/sin cache
            for layer in self.layers[1:]:
                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb

        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln:
            if layer_norm_fn is None:
                raise ImportError("Triton is not installed")
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
            self.ln_f = norm_cls(
                config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs
            )
        if process_group is not None:
            for p in self.ln_f.parameters():
                # Mark the norm parameters as "shared_params" so that we sync their values at init.
                p._shared_params = True
                # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                if self.sequence_parallel:
                    p._sequence_parallel = True

        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=getattr(config, "mup_width_scale", 1.0),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, position_ids=None, inference_params=None):
        # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
        # dimensions so that we can split on it easily, in case of small batch size.
        # Only the attention layers need to know the seqlen.
        embedding_kwargs = (
            {"combine_batch_seqlen_dim": True}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
        if self.embeddings_multiplier != 1.0:
            hidden_states = hidden_states * self.embeddings_multiplier
        if self.parallel_block:
            hidden_states2 = None
        residual = None
        mixer_kwargs = (
            {"seqlen": input_ids.shape[1]}
            if self.process_group is not None and self.sequence_parallel
            else {}
        )
        if inference_params is not None:
            mixer_kwargs["inference_params"] = inference_params
        for layer in self.layers:
            if self.prenorm:
                if not self.parallel_block:
                    hidden_states, residual = layer(
                        hidden_states, residual, mixer_kwargs=mixer_kwargs
                    )
                else:
                    hidden_states, hidden_states2, residual = layer(
                        hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
                    )
            else:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_f(hidden_states)
                if not self.parallel_block:
                    residual = (dropped + residual) if residual is not None else dropped
                else:
                    dropped2 = self.drop_f(hidden_states2)
                    residual = (
                        (residual + dropped + dropped2)
                        if residual is not None
                        else dropped + dropped2
                    )
                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = layer_norm_fn(
                    hidden_states,
                    self.ln_f.weight,
                    self.ln_f.bias,
                    residual=residual,
                    x1=None if not self.parallel_block else hidden_states2,
                    eps=self.ln_f.eps,
                    dropout_p=self.drop_f.p if self.training else 0.0,
                    prenorm=False,
                    is_rms_norm=isinstance(self.ln_f, RMSNorm)
                )
        return hidden_states


class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
    def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config)
        self.process_group = process_group
        self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
        lm_head_bias = getattr(config, "lm_head_bias", False)
        pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        vocab_size = (
            math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        # This option is for OPT-350m
        word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None)
        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
        if word_embed_proj_dim is not None:
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
        else:
            self.project_out = None
        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
        self.output_scale = mup_output_multiplier * mup_width_scale
        if process_group is None:
            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
        else:
            if ColumnParallelLinear is None:
                raise ImportError("fused_dense_lib is not installed")
            self.lm_head = ColumnParallelLinear(
                embed_dim,
                vocab_size,
                process_group,
                bias=lm_head_bias,
                sequence_parallel=getattr(config, "sequence_parallel", True),
                **factory_kwargs,
            )
        self.norm_head = getattr(config, "norm_head", False)
        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                mup_width_scale=mup_width_scale,
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.tie_word_embeddings:
            self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
        if self.process_group is not None:
            sync_shared_params(self, self.process_group)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.transformer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        input_ids: (batch, seqlen) int tensor
        inference_params: for generation. Adapted from Megatron-LM (and Apex)
        https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        assert (
            input_ids.ndim == 2
        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
        b, slen = input_ids.shape
        hidden_states = self.transformer(
            input_ids, position_ids=position_ids, inference_params=inference_params
        )
        if inference_params is not None:
            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        if self.output_scale != 1.0:
            hidden_states = hidden_states * self.output_scale
        if not self.norm_head:
            lm_logits = self.lm_head(hidden_states)
        else:
            lm_head_weight = F.normalize(self.lm_head.weight)
            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
        # During inference, we want the full logit for sampling
        if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
            lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
            lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    def load_state_dict(self, state_dict, strict=True):
        # Remapping from our checkpoints that used a different ordering of layers in the block
        # Previous: Attn / MLP -> Dropout -> Add -> LN
        # Current: Dropout -> Add -> LN -> Attn / MLP
        if "transformer.ln_0.weight" in state_dict:
            n_layers = len(self.transformer.layers)
            ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight")
            ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias")
            state_dict["transformer.ln_f.weight"] = ln_weight
            state_dict["transformer.ln_f.bias"] = ln_bias
            for l in reversed(range(n_layers)):
                ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight")
                ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias")
                state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight
                state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias
                if l > 0:
                    ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight")
                    ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias")
                    state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight
                    state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias
            ln_weight = state_dict.pop("transformer.ln_0.weight")
            ln_bias = state_dict.pop("transformer.ln_0.bias")
            state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight
            state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias
        return super().load_state_dict(state_dict, strict=strict)


def shard_state_dict_tp(state_dict, config, world_size, rank):
    """Convert the state_dict of
    a standard GPT model to the state_dict of a GPT model with tensor parallel.

    This function modifies state_dict in place.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", n_head)

    embed_dim = config.hidden_size
    head_dim = embed_dim // n_head

    def shard_first_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        if key in state_dict:
            x = state_dict[key]
            dim_each_rank = [
                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
                for local_rank in range(world_size)
            ]
            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
            state_dict[key] = x[..., beg:end]

    def shard_gatedmlp_fc1_dim(state_dict, key):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]

            beg_n_head = sum(n_head_each_rank[:rank])
            end_n_head = sum(n_head_each_rank[: rank + 1])

            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])

            if n_head_kv == n_head:
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                x = rearrange(
                    state_dict[key],
                    "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                    nheadqkv=n_head + 2 * n_head_kv,
                )
                state_dict[key] = rearrange(
                    torch.cat(
                        [
                            x[beg_n_head:end_n_head],
                            x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
                            x[
                                n_head + n_head_kv + beg_n_head_kv : n_head + n_head_kv + end_n_head_kv
                            ],
                        ],
                        dim=0,
                    ),
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight")
    if "lm_head.weight" in state_dict:
        shard_first_dim(state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight")
    for i in range(config.num_hidden_layers):
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        shard_last_dim(
            state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
        )
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
        if config.activation_function in ["glu", "swiglu", "geglu"]:
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        else:
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
            shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight")
        if rank != 0:
            state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None)
    return state_dict


def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
    """Convert the list of sharded state_dict of a GPT model with tensor parallel to
    the state_dict of a standard GPT model.

    This function is meant to be the "reverse" of shard_state_dict_tp.

    Precondition:
        - state_dicts should be ordered in the same way as the shards were created.
    """
    world_size = len(state_dicts)
    keys = state_dicts[0].keys()
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    assert vocab_size % world_size == 0
    assert config.hidden_size % world_size == 0
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    assert inner_dim % world_size == 0
    assert config.hidden_size % config.n_head == 0
    headdim = config.hidden_size // config.n_head

    # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
    # vocab_size // world_size coordinates are nonzero.
    def combine_word_embeddings(state_dicts, state_dict, key):
        dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
        state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_dim(state_dicts, state_dict, key, dim=-1):
        if key in state_dict:
            state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)

    def combine_qkv_headdim(state_dicts, state_dict, key):
        n_head = config.n_head
        n_head_kv = getattr(config, "n_head_kv", n_head)
        if key in state_dict:
            if n_head_kv == n_head:
                xs = [
                    rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts
                ]
                state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...")
            else:
                n_head_each_rank = [
                    get_dim_for_local_rank(n_head, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                n_head_kv_each_rank = [
                    get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                    for local_rank in range(world_size)
                ]
                xs = [
                    rearrange(
                        s[key],
                        "(nheadqkv headdim) ... -> nheadqkv headdim ...",
                        nheadqkv=rank_n_head + 2 * rank_n_head_kv,
                        headdim=headdim,
                    )
                    for s, rank_n_head, rank_n_head_kv in zip(
                        state_dicts, n_head_each_rank, n_head_kv_each_rank
                    )
                ]
                wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
                wk = torch.cat(
                    [
                        x[
                            n_head_each_rank[rank] : n_head_each_rank[rank]
                            + n_head_kv_each_rank[rank]
                        ]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wv = torch.cat(
                    [
                        x[n_head_each_rank[rank] + n_head_kv_each_rank[rank] :]
                        for rank, x in enumerate(xs)
                    ],
                    dim=0,
                )
                wqkv = torch.cat(
                    [wq, wk, wv],
                    dim=0,
                )
                state_dict[key] = rearrange(
                    wqkv,
                    "nheadqkv headdim ... -> (nheadqkv headdim) ...",
                )

    def combine_gated_mlp(state_dicts, state_dict, key):
        if key in state_dict:
            xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts]
            state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...")

    state_dict = state_dicts[0].copy()  # don't modify state_dict[0] inplace
    combine_word_embeddings(
        state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight"
    )
    if "lm_head.weight" in state_dict:
        combine_word_embeddings(state_dicts, state_dict, "lm_head.weight")
    if "transformer.embeddings.position_embeddings.weight" in state_dict:
        combine_dim(
            state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1
        )
    mlp_combine_fn = (
        combine_gated_mlp
        if config.activation_function in ["glu", "swiglu", "geglu"]
        else partial(combine_dim, dim=0)
    )
    for i in range(config.num_hidden_layers):
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
        combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1)
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.weight")
        # shard_state_dict_tp splits a gated-MLP fc1 bias with shard_gatedmlp_fc1_dim, so it
        # has to be recombined like the weight; plain dim-0 concatenation only fits non-gated MLPs.
        mlp_combine_fn(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias")
        combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1)
    return state_dict


def remap_state_dict_hf_gpt2(state_dict, config):
    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for d in range(config.num_hidden_layers):
        W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t()
        W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()

    def key_mapping_mlp(key):
        key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key)
        key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        state_dict.pop(f"h.{d}.attn.bias", None)  # We don't store this bias
        Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
        Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight")
        state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()

    def key_mapping_attn(key):
        key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
        key = re.sub(
            r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def remap_state_dict_megatron(state_dict, config):
    def key_mapping_transformer(key):
        key = re.sub(r"^language_model.encoder.", "transformer.", key)
        key = re.sub(r"^language_model.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)",
            r"transformer.layers.\1.norm2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)",
            r"transformer.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq",
            r"transformer.layers.\1.mixer.rotary_emb.inv_freq",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)",
            r"transformer.layers.\1.mixer.Wqkv.\2",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)",
            r"transformer.layers.\1.mixer.out_proj.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
    for d in range(config.num_hidden_layers):
        Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    return state_dict

================================================
FILE: flash_attn/models/gpt_neox.py
================================================
# Copyright (c) 2023, Tri Dao.

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, GPTNeoXConfig


def remap_state_dict_hf_gpt_neox(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^gpt_neox.", "transformer.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings", False):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("embed_out.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attention.bias")
        state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
        # while we store Wqkv as ((3 nheads headdim), hidden_dim)
        headdim = config.hidden_size // config.num_attention_heads
        Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
    assert gpt_neox_config.rotary_emb_base == 10000
    return GPT2Config(
        vocab_size=gpt_neox_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gpt_neox_config.hidden_size,
        n_layer=gpt_neox_config.num_hidden_layers,
        n_head=gpt_neox_config.num_attention_heads,
        n_inner=gpt_neox_config.intermediate_size,
        activation_function=gpt_neox_config.hidden_act,
        resid_pdrop=0.0,  # No dropout
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
        initializer_range=gpt_neox_config.initializer_range,
        bos_token_id=gpt_neox_config.bos_token_id,
        eos_token_id=gpt_neox_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=gpt_neox_config.use_parallel_residual,
        parallel_block_tied_norm=False,
        rotary_emb_fraction=gpt_neox_config.rotary_pct,
        tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
    )

================================================
FILE: flash_attn/models/gptj.py
================================================
# Copyright (c) 2023, Tri Dao.
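
# Illustrative sketch, not part of the original module (`_qkv_row_permutation` is a
# hypothetical helper that nothing else calls): the GPT-NeoX and Megatron remaps above use
# rearrange(Wqkv, "(nheads three headdim) ... -> (three nheads headdim) ..."), which amounts
# to reordering the rows of Wqkv with the permutation computed here. GPT-J keeps separate
# q/k/v projections, so torch.cat([Wq, Wk, Wv], dim=0) in remap_state_dict_hf_gptj already
# yields the (three nheads headdim) row layout and needs no such permutation.
def _qkv_row_permutation(nheads, headdim):
    """Return perm such that Wqkv_flash = Wqkv_interleaved[perm] along the row dimension."""
    perm = []
    for three in range(3):  # 0 = q, 1 = k, 2 = v: outermost axis of the target layout
        for head in range(nheads):
            for d in range(headdim):
                # Row index of the same (head, q/k/v, d) entry in the interleaved source layout
                perm.append((head * 3 + three) * headdim + d)
    return perm


# Example: 2 heads with headdim 1 store rows as [q0, k0, v0, q1, k1, v1];
# _qkv_row_permutation(2, 1) gives [0, 3, 1, 4, 2, 5], i.e. [q0, q1, k0, k1, v0, v1].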

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig


def remap_state_dict_hf_gptj(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^transformer.h.", "transformer.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attn.bias")
        state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
    headdim = gptj_config.n_embd // gptj_config.n_head
    return GPT2Config(
        vocab_size=gptj_config.vocab_size,
        n_positions=0,  # No absolute position embedding
        n_embd=gptj_config.n_embd,
        n_layer=gptj_config.n_layer,
        n_head=gptj_config.n_head,
        n_inner=gptj_config.n_inner,
        activation_function=gptj_config.activation_function,
        resid_pdrop=gptj_config.resid_pdrop,
        embd_pdrop=gptj_config.embd_pdrop,
        attn_pdrop=gptj_config.attn_pdrop,
        layer_norm_epsilon=gptj_config.layer_norm_epsilon,
        initializer_range=gptj_config.initializer_range,
        bos_token_id=gptj_config.bos_token_id,
        eos_token_id=gptj_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=True,
        parallel_block=True,
        parallel_block_tied_norm=True,
        rotary_emb_fraction=gptj_config.rotary_dim / headdim,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        lm_head_bias=True,
    )

================================================
FILE: flash_attn/models/llama.py
================================================
# Copyright (c) 2023, Tri Dao.

import json
import math
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Union

import torch
import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor
from transformers import GPT2Config, LlamaConfig

from einops import rearrange


def remap_state_dict_meta_llama(
    state_dict: Dict[str, torch.Tensor], config: GPT2Config
) -> Dict[str, torch.Tensor]:
    """Convert the state_dict in Meta format to standard GPT format.

    This function modifies state_dict in place.
    """

    def key_mapping_layers(key):
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) vocab_size = ( math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple ) state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) ) if getattr(config, "tie_word_embeddings"): state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] else: output_embeddings = state_dict.pop("output.weight") # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings # differently. vocab_size = ( math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple ) # It's possible that vocab_size is padded to be a multiple of 8, for example. state_dict["lm_head.weight"] = F.pad( output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) ) # LayerNorm def key_mapping_ln(key): key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key) key = re.sub( r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key, ) key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) # MLP for l in range(config.n_layer): w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight") w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight") # Our ordering is different state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0) def key_mapping_mlp(key): return re.sub( r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key, ) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # Attention for l in range(config.n_layer): Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight") Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight") Wv = 
state_dict.pop(f"transformer.layers.{l}.attention.wv.weight") state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) # We don't store these state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None) def key_mapping_attn(key): return re.sub( r"^transformer.layers.(\d+).attention.wo.", r"transformer.layers.\1.mixer.out_proj.", key, ) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) state_dict.pop("transformer.rope.freqs", None) return state_dict def remap_state_dict_hf_llama( state_dict: Dict[str, torch.Tensor], config: GPT2Config ) -> Dict[str, torch.Tensor]: """Convert the state_dict in Hugging Face format to standard GPT format. This function modifies state_dict in place. """ # Embedding def key_mapping_emb(key): return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") # It's possible that vocab_size is padded to be a multiple of 8, for example. pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) vocab_size = ( math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple ) state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) ) # LM head if getattr(config, "tie_word_embeddings"): state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] else: output_embeddings = state_dict.pop("lm_head.weight") # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings # differently. vocab_size = ( math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple ) # It's possible that vocab_size is padded to be a multiple of 8, for example. 
state_dict["lm_head.weight"] = F.pad( output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) ) # MLP for l in range(config.n_layer): # Fusing weights this way based on difference in the following: # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220 # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115 w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight") w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight") state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0) def key_mapping_mlp(key): return re.sub( r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key, ) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # LayerNorm def key_mapping_ln(key): key = re.sub(r"^model.norm.", r"transformer.ln_f.", key) key = re.sub( r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key, ) key = re.sub( r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key, ) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) def inv_permute(w): # Inverse of permute implemented in: # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114 return rearrange( w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2 ) # Attention for l in range(config.n_layer): Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight") Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight") Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight") state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat( [inv_permute(Wq), inv_permute(Wk), Wv], dim=0 ) # We don't store these 
state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None) def key_mapping_attn(key): return re.sub( r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key, ) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) return state_dict def inv_remap_state_dict_hf_llama( state_dict: Dict[str, torch.Tensor], config: GPT2Config ) -> Dict[str, torch.Tensor]: """Convert the state_dict in standard GPT format to Hugging Face format. This function is meant to be the inverse of remap_state_dict_hf_llama, up to the vocab padding of the embedding and lm_head. That is, if the original vocab size isn't a multiple of pad_vocab_size_multiple, then inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict. This function modifies state_dict in place. """ # Embedding def key_mapping_emb(key): return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) word_embeddings = state_dict.pop("model.embed_tokens.weight") pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) vocab_size = ( math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple ) state_dict["model.embed_tokens.weight"] = F.pad( word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) ) # LM head if getattr(config, "tie_word_embeddings"): state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"] else: output_embeddings = state_dict.pop("lm_head.weight") vocab_size = ( math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple ) state_dict["lm_head.weight"] = F.pad( output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) ) # MLP for l in range(config.n_layer): w3, w1 = torch.chunk( state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0 ) state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1
state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3 def key_mapping_mlp(key): return re.sub( r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key, ) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # LayerNorm def key_mapping_ln(key): key = re.sub(r"^transformer.ln_f.", r"model.norm.", key) key = re.sub( r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key, ) key = re.sub( r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key, ) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) def permute(w): return rearrange( w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2 ) n_head = config.n_head n_head_kv = getattr(config, "n_head_kv", n_head) embed_dim = config.hidden_size head_dim = embed_dim // n_head q_dim = n_head * head_dim k_dim = v_dim = n_head_kv * head_dim # Attention for l in range(config.n_layer): Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight") Wq = Wqkv[:q_dim] Wk = Wqkv[q_dim : q_dim + k_dim] Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim] state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq) state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk) state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None) def key_mapping_attn(key): return re.sub( r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key, ) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) return state_dict def config_from_meta_checkpoint( checkpoint_path: Union[str, os.PathLike], model_name: str ) -> LlamaConfig: """Load a LlamaConfig from a checkpoint path.""" with open(Path(checkpoint_path) / model_name / "params.json") as f: params = json.load(f) config = LlamaConfig( hidden_size=params["dim"], 
intermediate_size=None, num_attention_heads=params["n_heads"], num_hidden_layers=params["n_layers"], rms_norm_eps=params["norm_eps"], num_key_value_heads=params.get("n_kv_heads", None), ) multiple_of = params.get("multiple_of", 1) ffn_dim_multiplier = params.get("ffn_dim_multiplier", None) # Compute the hidden dimension of the MLP # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224 intermediate_size = 4 * config.hidden_size # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199 intermediate_size = int(2 * intermediate_size / 3) # custom dim factor multiplier if ffn_dim_multiplier is not None: intermediate_size = int(ffn_dim_multiplier * intermediate_size) intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of) config.intermediate_size = intermediate_size if "rope_theta" in params: config.rotary_emb_base = params["rope_theta"] config.vocab_size = 32000 # some CodeLLaMa have vocab_size 32000, some 32016 # Sadly it's not specified in the `params.json` file :( tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model" if tokenizer.is_file(): config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size() return config def config_from_hf_checkpoint( checkpoint_path: Union[str, os.PathLike], model_name: str ) -> LlamaConfig: return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json") def config_from_checkpoint( checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta" ) -> LlamaConfig: if checkpoint_format == "meta": return config_from_meta_checkpoint(checkpoint_path, model_name) else: return config_from_hf_checkpoint(checkpoint_path, model_name) def state_dicts_from_checkpoint( checkpoint_path: Union[str, os.PathLike], model_name: str ) -> List[dict]: # Need to sort, otherwise we mess up the ordering and the weights are wrong return [ 
torch.load(path, map_location="cpu") for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth")) ] def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config: return GPT2Config( vocab_size=llama_config.vocab_size, n_positions=0, # No absolute position embedding n_embd=llama_config.hidden_size, n_layer=llama_config.num_hidden_layers, n_head=llama_config.num_attention_heads, n_inner=llama_config.intermediate_size, activation_function="swiglu", # Hardcode since HF calls it 'silu' # Llama doesn't have dropout, idk if it's because they only release the inference code resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, layer_norm_epsilon=llama_config.rms_norm_eps, initializer_range=llama_config.initializer_range, bos_token_id=llama_config.bos_token_id, eos_token_id=llama_config.eos_token_id, # These are new arguments not in the original GPT2Config pad_token_id=llama_config.pad_token_id, # Idk if this does anything rms_norm=True, rotary_emb_fraction=1.0, rotary_emb_interleaved=True, tie_word_embeddings=False, qkv_proj_bias=False, out_proj_bias=False, mlp_fc1_bias=False, mlp_fc2_bias=False, rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0), n_head_kv=llama_config.num_key_value_heads, ) ================================================ FILE: flash_attn/models/opt.py ================================================ # Copyright (c) 2023, Tri Dao. 
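The `inv_permute` and `permute` helpers in llama.py above convert each head's Q/K projection rows between Hugging Face's rotate-half layout (row `d` is rotated together with row `d + head_dim // 2`) and the interleaved layout selected by `rotary_emb_interleaved=True` (rows `2d` and `2d + 1` are rotated together). A minimal stdlib sketch, assuming we only track row indices rather than weight tensors, applies the same einops patterns; `inv_permute_rows` and `permute_rows` are hypothetical helpers, not part of flash_attn:

```python
from typing import List


def inv_permute_rows(n_head: int, head_dim: int) -> List[int]:
    # Mirrors rearrange(w, "(h two d) n -> (h d two) n"): the new row at (h, d, two)
    # takes the old row at (h, two, d). Returns the old row index for each new row.
    half = head_dim // 2
    order = []
    for h in range(n_head):
        for d in range(half):
            for two in range(2):
                order.append(h * head_dim + two * half + d)
    return order


def permute_rows(n_head: int, head_dim: int) -> List[int]:
    # Mirrors rearrange(w, "(h d two) n -> (h two d) n"), the inverse reordering.
    half = head_dim // 2
    order = []
    for h in range(n_head):
        for two in range(2):
            for d in range(half):
                order.append(h * head_dim + d * 2 + two)
    return order


rows = list(range(8))  # one head with head_dim = 8; labels stand in for weight rows
hf_to_interleaved = [rows[i] for i in inv_permute_rows(1, 8)]
print(hf_to_interleaved)  # [0, 4, 1, 5, 2, 6, 3, 7]
back = [hf_to_interleaved[i] for i in permute_rows(1, 8)]
print(back)  # [0, 1, 2, 3, 4, 5, 6, 7]
```

The pattern repeats per head with an offset of `head_dim` rows. Only `Wq` and `Wk` go through `inv_permute` in `remap_state_dict_hf_llama`; `Wv` is left as-is because rotary embeddings are not applied to values.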
import math import re from collections import OrderedDict import torch import torch.nn.functional as F from transformers import GPT2Config, OPTConfig def remap_state_dict_hf_opt(state_dict, config): def key_mapping_model(key): key = re.sub(r"^model.decoder.", "transformer.", key) # The OPT-350m model uses '^decoder' instead of '^model.decoder' key = re.sub(r"^decoder.", "transformer.", key) return key state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items()) # Word embedding and position embedding def key_mapping_emb(key): key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key) # The OPT-350m model has project_in and project_out key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key) key = re.sub(r"^transformer.project_out.", "project_out.", key) key = re.sub( r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key ) return key state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) # OPT uses the first 2 indices of pos_emb for padding tokens pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight") state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:] word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight") # It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) ) state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] # LayerNorm def key_mapping_ln(key): key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key) # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm' key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key) key = re.sub( r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key ) key = re.sub( r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key ) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) # MLP def key_mapping_mlp(key): return re.sub( r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key ) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # Attention for l in range(config.n_layer): Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight") Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight") Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight") bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias") bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias") bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias") state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0) state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0) def key_mapping_attn(key): return re.sub( r"^transformer.layers.(\d+).self_attn.out_proj.", r"transformer.layers.\1.mixer.out_proj.", key, ) state_dict = 
OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) return state_dict def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config: assert opt_config.layerdrop == 0.0 assert opt_config.layer_norm_elementwise_affine word_embed_proj_dim = ( None if opt_config.word_embed_proj_dim == opt_config.hidden_size else opt_config.word_embed_proj_dim ) return GPT2Config( vocab_size=opt_config.vocab_size, n_positions=opt_config.max_position_embeddings, n_embd=opt_config.hidden_size, n_layer=opt_config.num_hidden_layers, n_head=opt_config.num_attention_heads, n_inner=opt_config.ffn_dim, activation_function=opt_config.activation_function, resid_pdrop=opt_config.dropout, # HF's implementation of OPT doesn't seem to have embedding dropout embd_pdrop=opt_config.dropout, attn_pdrop=opt_config.attention_dropout, initializer_range=opt_config.init_std, bos_token_id=opt_config.bos_token_id, eos_token_id=opt_config.eos_token_id, # These are new arguments not in the original GPT2Config prenorm=opt_config.do_layer_norm_before, word_embed_proj_dim=word_embed_proj_dim, ) ================================================ FILE: flash_attn/models/vit.py ================================================ # Copyright (c) 2022, Tri Dao. 
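`remap_state_dict_hf_opt` in opt.py above renames checkpoint keys by chaining anchored `re.sub` calls, stage by stage. A minimal stdlib sketch of a subset of those rules, applied in the same order (the rule list `_RULES` and the function `remap_opt_key` are hypothetical names), shows how individual keys flow through the chain. Unlike the original patterns, the dots here are escaped; an unescaped `.` matches any character, but for these keys both forms behave identically:

```python
import re

# A subset of the renaming rules in remap_state_dict_hf_opt, in the same order.
_RULES = [
    (r"^model\.decoder\.", "transformer."),
    (r"^decoder\.", "transformer."),  # OPT-350m checkpoints omit the "model." prefix
    (r"^transformer\.embed_tokens\.", "transformer.embeddings.word_embeddings."),
    (r"^transformer\.embed_positions\.", "transformer.embeddings.position_embeddings."),
    (r"^transformer\.final_layer_norm\.", "transformer.ln_f."),
    (r"^transformer\.layers\.(\d+)\.self_attn_layer_norm\.", r"transformer.layers.\1.norm1."),
    (r"^transformer\.layers\.(\d+)\.final_layer_norm\.", r"transformer.layers.\1.norm2."),
    (r"^transformer\.layers\.(\d+)\.fc(1|2)\.", r"transformer.layers.\1.mlp.fc\2."),
    (
        r"^transformer\.layers\.(\d+)\.self_attn\.out_proj\.",
        r"transformer.layers.\1.mixer.out_proj.",
    ),
]


def remap_opt_key(key: str) -> str:
    # Each rule is anchored with "^", so it only rewrites the key's prefix.
    for pattern, repl in _RULES:
        key = re.sub(pattern, repl, key)
    return key


for k in [
    "model.decoder.embed_tokens.weight",
    "decoder.final_layer_norm.weight",
    "model.decoder.layers.0.self_attn_layer_norm.bias",
    "model.decoder.layers.3.fc1.weight",
]:
    print(k, "->", remap_opt_key(k))
```

Keys for `q_proj`, `k_proj` and `v_proj` pass through these rules with only the prefix changed; in the original function they are then popped and concatenated along dim 0 into `mixer.Wqkv`.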
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py import math import re from collections import OrderedDict from copy import deepcopy from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from timm.models.helpers import named_apply from torch.nn.init import trunc_normal_ from torchvision.ops import StochasticDepth from flash_attn.layers.patch_embed import PatchEmbed from flash_attn.modules.block import Block from flash_attn.modules.mha import MHA from flash_attn.modules.mlp import FusedMLP, Mlp try: from flash_attn.ops.triton.layer_norm import layer_norm_fn except ImportError: layer_norm_fn = None def create_mixer_cls( num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, cross_attn=False ): mixer_cls = partial( MHA, num_heads=num_heads, cross_attn=cross_attn, qkv_proj_bias=qkv_bias, dropout=attn_drop, fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn, ) return mixer_cls def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp): inner_dim = int(embed_dim * mlp_ratio) if not fused_mlp: mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer()) else: mlp_cls = partial(FusedMLP, hidden_features=inner_dim) return mlp_cls def create_block( embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None, last_layer_subset=False, ): mixer_cls = create_mixer_cls( num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc, cross_attn=(last_layer_subset and layer_idx == n_layer - 1), ) mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp) # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed block = Block( embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer, prenorm=True, resid_dropout1=drop_rate, 
resid_dropout2=drop_rate, drop_path1=drop_path1, drop_path2=drop_path2, fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=True, ) return block class VisionTransformer(nn.Module): """Vision Transformer A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool="token", embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, init_values=None, class_token=True, no_embed_class=False, pre_norm=False, fc_norm=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, weight_init="", embed_layer=PatchEmbed, norm_layer=None, act_layer=None, use_flash_attn=False, fused_bias_fc=False, fused_mlp=False, fused_dropout_add_ln=False, ): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels num_classes (int): number of classes for classification head global_pool (str): type of global pooling for final sequence (default: 'token') embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (float): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True init_values (float): layer-scale init values class_token (bool): use class token fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate weight_init (str): weight init scheme embed_layer (nn.Module): patch embedding layer norm_layer (nn.Module): normalization layer act_layer (nn.Module): MLP activation layer """ super().__init__() assert global_pool == "token", "Only support pooling with CLS token" assert class_token assert init_values is None, "LayerScale is not supported yet" assert weight_init == "" assert
fc_norm is None # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk assert not pre_norm use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU self.num_classes = num_classes self.global_pool = global_pool self.num_features = ( self.embed_dim ) = embed_dim # num_features for consistency with other models self.num_prefix_tokens = 1 if class_token else 0 self.no_embed_class = no_embed_class patch_embed_extra_kwargs = ( {"fused_bias_fc": fused_bias_fc} if embed_layer is PatchEmbed else {} ) self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) **patch_embed_extra_kwargs, ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, depth) ] # stochastic depth decay rule # We change the order of dropout, residual and layer norm: # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and # the main branch (output of MLP). The model definition is unchanged, but the mapping of the # nn.Dropout probabilities is changed. # This is for performance reasons: we can fuse dropout + add + layer_norm.
self.blocks = nn.ModuleList( [ create_block( embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path1=dpr[i - 1] if i > 0 else 0.0, drop_path2=dpr[i], norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn, fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp, fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth, last_layer_subset=(global_pool == "token"), ) for i in range(depth) ] ) self.dropout = nn.Dropout(p=drop_rate) self.drop_path = StochasticDepth(p=dpr[-1], mode="row") self.norm = norm_layer(embed_dim) self.fused_dropout_add_ln = fused_dropout_add_ln if self.fused_dropout_add_ln and layer_norm_fn is None: raise ImportError("Triton is not installed") # Classifier Head self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.init_weights(weight_init) def init_weights(self, mode=""): assert mode == "" trunc_normal_(self.pos_embed, std=0.02) if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) named_apply(init_weights_vit_timm, self) def _init_weights(self, m): # this fn left here for compat with downstream users init_weights_vit_timm(m) @torch.jit.ignore def no_weight_decay(self): return {"pos_embed", "cls_token"} def _pos_embed(self, x): if self.no_embed_class: # deit-3, updated JAX (big vision) # position embedding does not overlap with class token, add then concat x = x + self.pos_embed if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) else: # original timm, JAX, and deit vit impl # pos_embed has entry for class token, concat then add if self.cls_token is not None: x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.pos_embed return x def forward_features(self, x, all_tokens=True): """ If all_tokens==False and self.global_pool == 'token', we only return the features for the cls token. 
""" x = self.patch_embed(x) hidden_states = self._pos_embed(x) residual = None if self.global_pool != "token" or all_tokens: # if True: for block in self.blocks: hidden_states, residual = block(hidden_states, residual) else: for block in self.blocks[:-1]: hidden_states, residual = block(hidden_states, residual) # For the last layer, we only want the 1st token of the output. So we do cross-attention # where the query is the 1st token and the key/value is the whole sequence. hidden_states, residual = self.blocks[-1]( hidden_states, residual, mixer_subset=slice(0, 1) ) if not self.fused_dropout_add_ln: residual = self.drop_path(self.dropout(hidden_states)) + residual hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) else: if self.drop_path.p == 0 or not self.training: rowscale = None else: rowscale = self.drop_path( torch.ones( hidden_states.shape[:-1], device=hidden_states.device, dtype=hidden_states.dtype, ) ) # Set prenorm=False here since we don't need to return the residual hidden_states = layer_norm_fn( hidden_states, self.norm.weight, self.norm.bias, residual=residual, eps=self.norm.eps, dropout_p=self.dropout.p if self.training else 0.0, rowscale=rowscale, prenorm=False, ) return hidden_states def forward_head(self, x, pre_logits: bool = False): if self.global_pool: x = x[:, self.num_prefix_tokens :].mean(dim=1) if self.global_pool == "avg" else x[:, 0] return x if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x, all_tokens=False) x = self.forward_head(x) return x def load_state_dict(self, state_dict, strict=True): patch_embed_weight = state_dict["patch_embed.proj.weight"] if patch_embed_weight.dim() == 4: # convert from Conv2d to Linear state_dict["patch_embed.proj.weight"] = rearrange( patch_embed_weight, "o c h w -> o (c h w)" ) def key_mapping_attn(key): key = re.sub(r"^blocks.(\d+).attn.qkv.", r"blocks.\1.mixer.Wqkv.", key) key = re.sub(r"^blocks.(\d+).attn.proj.", r"blocks.\1.mixer.out_proj.", key) return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) n_layer = len(self.blocks) # Convert from Wqkv to Wq and Wkv for cross attention (last layer) if ( self.blocks[-1].mixer.cross_attn and f"blocks.{n_layer - 1}.mixer.Wqkv.weight" in state_dict ): Wqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.weight") bqkv = state_dict.pop(f"blocks.{n_layer - 1}.mixer.Wqkv.bias") state_dict[f"blocks.{n_layer - 1}.mixer.Wq.weight"] = Wqkv[: self.embed_dim] state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.weight"] = Wqkv[self.embed_dim :] state_dict[f"blocks.{n_layer - 1}.mixer.Wq.bias"] = bqkv[: self.embed_dim] state_dict[f"blocks.{n_layer - 1}.mixer.Wkv.bias"] = bqkv[self.embed_dim :] return super().load_state_dict(state_dict, strict=strict) def init_weights_vit_timm(module: nn.Module, name: str = ""): """ViT weight initialization, original timm impl (for reproducibility)""" if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, "init_weights"): module.init_weights() def vit_base_patch16_224(pretrained=False, **kwargs): """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ assert not pretrained model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = VisionTransformer(**model_kwargs) return model ================================================ FILE: flash_attn/modules/__init__.py ================================================ ================================================ FILE: flash_attn/modules/block.py ================================================ # Copyright (c) 2024, Tri Dao. 
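Because `Block` applies dropout, drop path and the residual add at the start of the next sub-layer, `VisionTransformer.__init__` in vit.py above shifts the stochastic-depth schedule by one: block `i` uses `dpr[i - 1]` for `drop_path1`, which acts on the previous block's MLP output, and `dpr[i]` for `drop_path2`, which acts on its own attention output. The final `self.drop_path` uses `dpr[-1]` for the last block's MLP output. A minimal stdlib sketch of that bookkeeping, where `linspace` and `drop_path_schedule` are hypothetical helpers standing in for `torch.linspace` and the list comprehension in the original:

```python
def linspace(start: float, end: float, steps: int):
    # Evenly spaced values including both endpoints, in the spirit of torch.linspace.
    if steps == 1:
        return [start]
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


def drop_path_schedule(drop_path_rate: float, depth: int):
    dpr = linspace(0.0, drop_path_rate, depth)  # stochastic depth decay rule
    # (drop_path1, drop_path2) for each block, as passed to create_block in vit.py
    per_block = [(dpr[i - 1] if i > 0 else 0.0, dpr[i]) for i in range(depth)]
    final = dpr[-1]  # rate of the StochasticDepth applied after the last block
    return per_block, final


per_block, final = drop_path_schedule(0.3, 4)
for i, (p1, p2) in enumerate(per_block):
    print(f"block {i}: drop_path1={p1:.2f} drop_path2={p2:.2f}")
print(f"final drop_path={final:.2f}")
```

Each `dpr[i]` is used twice, once as `drop_path2` of block `i` and once as `drop_path1` of block `i + 1` (or by the final `drop_path`), so both the attention and MLP branches of block `i` end up with rate `dpr[i]`, as in the usual per-block ViT schedule.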
from functools import partial from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torchvision.ops import StochasticDepth from flash_attn.modules.mha import MHA from flash_attn.modules.mlp import Mlp try: from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm except ImportError: layer_norm_fn, RMSNorm = None, None class Block(nn.Module): def __init__( self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0.0, resid_dropout2=0.0, drop_path1=0.0, drop_path2=0.0, fused_dropout_add_ln=False, return_residual=False, residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False, ): """ For prenorm=True, this Block has a slightly different structure compared to a regular prenorm Transformer block. The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. [Ref: https://arxiv.org/abs/2002.04745] Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both the hidden_states (output of the MLP) and the residual. This is for performance reasons, as we can fuse the dropout, add and LayerNorm. The residual needs to be provided (except for the very first block). For prenorm=False, this Block has the same structure as a regular postnorm Transformer block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. This is for performance reasons: for a post-norm architecture, returning the input allows us to fuse the backward of nn.Linear with the residual connection.
""" super().__init__() self.prenorm = prenorm self.fused_dropout_add_ln = fused_dropout_add_ln self.return_residual = return_residual self.residual_in_fp32 = residual_in_fp32 if self.residual_in_fp32: assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True" if mixer_cls is None: mixer_cls = partial(MHA, num_heads=dim // 64) if mlp_cls is None: mlp_cls = partial(Mlp, hidden_features=4 * dim) self.mixer = mixer_cls(dim) self.dropout1 = dropout_cls(resid_dropout1) self.drop_path1 = StochasticDepth(drop_path1, mode="row") self.norm1 = norm_cls(dim) self.mlp = mlp_cls(dim) if not isinstance(self.mlp, nn.Identity): self.dropout2 = dropout_cls(resid_dropout2) self.drop_path2 = StochasticDepth(drop_path2, mode="row") self.norm2 = norm_cls(dim) if self.fused_dropout_add_ln: assert layer_norm_fn is not None, "Triton is not installed" assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance( self.dropout1, nn.Dropout ) # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, # then the input to each worker in the tensor parallel group will be different. # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. # For now this is not an issue because we always use sequence_parallel=True during training # and only use sequence_parallel=False during inference. # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. if sequence_parallel: for p in self.norm1.parameters(): p._sequence_parallel = True if hasattr(self, "norm2"): for p in self.norm2.parameters(): p._sequence_parallel = True # Mark the norm parameters as "shared_params" so that we sync their values at init. 
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: if postnorm, residual=None. If prenorm, hidden_states = Attn/MLP(LN(residual)).
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
                residual = (dropped + residual) if residual is not None else dropped
                hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=residual,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
            if not isinstance(self.mlp, nn.Identity):
                if not self.fused_dropout_add_ln:
                    dropped = self.drop_path2(self.dropout2(hidden_states))
                    residual = (dropped + residual) if residual is not None else dropped
                    hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                    if self.residual_in_fp32:
                        residual = residual.to(torch.float32)
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = layer_norm_fn(
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=residual,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                        is_rms_norm=isinstance(self.norm2, RMSNorm)
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
        else:
            assert residual is None
            mixer_out = self.mixer(
                hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
            )
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = layer_norm_fn(
                    mixer_out,
                    self.norm1.weight,
                    self.norm1.bias,
                    residual=hidden_states,
                    eps=self.norm1.eps,
                    dropout_p=self.dropout1.p if self.training else 0.0,
                    rowscale=rowscale1,
                    prenorm=False,
                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = layer_norm_fn(
                        mlp_out,
                        self.norm2.weight,
                        self.norm2.bias,
                        residual=hidden_states,
                        eps=self.norm2.eps,
                        dropout_p=self.dropout2.p if self.training else 0.0,
                        rowscale=rowscale2,
                        prenorm=False,
                        is_rms_norm=isinstance(self.norm2, RMSNorm)
                    )
            return hidden_states


class ParallelBlock(nn.Module):
    """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA / MLP -> Dropout -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Dropout -> Add -> LN -> MHA / MLP, returning the hidden states
        (the outputs of the MHA and of the MLP) and the residual.
        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
        The residual needs to be provided (except for the very first block).
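        A minimal pure-PyTorch sketch of the parallel residual, assuming tied_norm=True, dropout
        p = 0, and hypothetical toy nn.Linear branches in place of MHA and Mlp. Carrying
        (hidden_states1, hidden_states2, residual) between blocks and adding both branch outputs
        at the start of the next block reproduces the GPT-J style update
        x = x + Attn(LN(x)) + MLP(LN(x)).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, n_layers = 16, 3
# Hypothetical toy branches; one LayerNorm per layer feeds both branches (tied_norm=True).
mixers = [nn.Linear(dim, dim) for _ in range(n_layers)]
mlps = [nn.Linear(dim, dim) for _ in range(n_layers)]
norms = [nn.LayerNorm(dim) for _ in range(n_layers)]
final_norm = nn.LayerNorm(dim)


def gptj_style(x):
    # x = x + Attn(LN(x)) + MLP(LN(x))
    for i in range(n_layers):
        h = norms[i](x)
        x = x + mixers[i](h) + mlps[i](h)
    return final_norm(x)


def parallel_block_order(x):
    # Add -> LN -> (mixer, MLP), passing (hidden_states1, hidden_states2, residual) between blocks
    hidden_states1, hidden_states2, residual = x, None, None
    for i in range(n_layers):
        if residual is None:  # very first block: only the embedding output, a single "dropout"
            residual = hidden_states1
        else:
            residual = residual + hidden_states1 + hidden_states2
        h = norms[i](residual)
        hidden_states1, hidden_states2 = mixers[i](h), mlps[i](h)
    return final_norm(residual + hidden_states1 + hidden_states2)
```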
        """
        super().__init__()
        self.tied_norm = tied_norm
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.residual_in_fp32 = residual_in_fp32
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        self.dropout2 = dropout_cls(resid_dropout2)
        if not self.tied_norm:
            self.norm2 = norm_cls(dim)

        if self.fused_dropout_add_ln:
            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
        # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
        # For now this is not an issue because we always use sequence_parallel=True during training
        # and only use sequence_parallel=False during inference.

        # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states1: the output of the previous attention (mixer) or embedding layer.
            hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
            residual: the residual stream from the previous block (None for the very first block).
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, *rest, residual = layer_norm_fn(
                hidden_states1,
                self.norm1.weight,
                self.norm1.bias,
                residual=residual,
                x1=hidden_states2,
                weight1=weight2,
                bias1=bias2,
                eps=self.norm1.eps,
                dropout_p=self.dropout1.p if self.training else 0.0,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm1, RMSNorm)
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
            else:
                hidden_states2, = rest
        if mixer_kwargs is None:
            mixer_kwargs = {}
        hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
        hidden_states2 = self.mlp(hidden_states2)
        return hidden_states1, hidden_states2, residual

================================================
FILE: flash_attn/modules/embedding.py
================================================
# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from flash_attn.utils.distributed import all_reduce, reduce_scatter


class GPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        padding_idx=None,
        word_embed_proj_dim=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
            then project up to embed_dim
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if word_embed_proj_dim is None:
            self.word_embeddings = nn.Embedding(
                vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = None
        else:
            self.word_embeddings = nn.Embedding(
                vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = nn.Linear(
                word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
            )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.project_in is not None:
            embeddings = self.project_in(embeddings)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        return embeddings


class BertEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        If type_vocab_size <= 0, there's no token type embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings


class VocabParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
            # Mask of ids outside this rank's vocab shard (True means the id needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelGPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)

================================================
FILE: flash_attn/modules/mha.py
================================================
# Copyright (c) 2023, Tri Dao.

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

from flash_attn.utils.distributed import get_dim_for_local_rank

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    def get_slopes_power_of_2(nheads):
        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
        ratio = start
        return [start * ratio**i for i in range(nheads)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
        )


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed
                      at runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
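        A minimal sketch (plain PyTorch, runs on CPU) of how the unpadded inputs can be built.
        The per-sequence tensors are concatenated along the token dimension, and cu_seqlens holds
        the prefix sums of their lengths, so sequence i occupies rows
        cu_seqlens[i]:cu_seqlens[i + 1]. make_cu_seqlens and pack_qkv are hypothetical helpers,
        not part of this module.

```python
import torch
import torch.nn.functional as F


def make_cu_seqlens(seqlens):
    # seqlens: (batch,) lengths -> cu_seqlens: (batch + 1,) int32 = [0, l0, l0 + l1, ...]
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    # max_seqlen must be a plain Python int (forward asserts isinstance(max_seqlen, int))
    return cu_seqlens, int(seqlens.max())


def pack_qkv(qkv_per_seq):
    # qkv_per_seq: list of (seqlen_i, 3, H, D) tensors -> packed (total, 3, H, D)
    seqlens = torch.tensor([t.shape[0] for t in qkv_per_seq])
    cu_seqlens, max_seqlen = make_cu_seqlens(seqlens)
    return torch.cat(qkv_per_seq, dim=0), cu_seqlens, max_seqlen
```

        With fp16/bf16 tensors on a GPU, the packed tensor would then be passed as
        forward(packed, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen), and the output keeps the
        same (total, H, D) row layout.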
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed
                      at runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed
                      at runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed
                      at runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
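        A minimal sketch of the MQA / GQA shapes implied by qkv_dim and kv_dim below. It assumes
        the packed Wqkv output holds all query heads first, then the key heads, then the value
        heads; the forward pass that does this split is outside this excerpt. split_qkv and
        expand_kv_heads are hypothetical helpers. The head expansion follows the same grouping as
        the einops pattern "... hkv d -> ... (hkv g) d" used in CrossAttention.

```python
import torch


def split_qkv(qkv, num_heads, num_heads_kv, head_dim):
    # Assumed layout: [q heads | k heads | v heads] along the last dimension.
    *lead, width = qkv.shape
    assert width == head_dim * (num_heads + 2 * num_heads_kv)
    q = qkv[..., : num_heads * head_dim].reshape(*lead, num_heads, head_dim)
    kv = qkv[..., num_heads * head_dim :].reshape(*lead, 2, num_heads_kv, head_dim)
    return q, kv


def expand_kv_heads(kv, num_heads):
    # Query head h reads KV head h // g, where g = num_heads // num_heads_kv.
    g = num_heads // kv.shape[-2]
    return kv.repeat_interleave(g, dim=-2)
```

        For example, embed_dim=64 with num_heads=8 and num_heads_kv=2 gives head_dim=8 and
        qkv_dim = 8 * (8 + 2 * 2) = 96, compared with 192 for full multi-head attention.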
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = nn.Linear(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = nn.Linear(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = nn.Linear(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply
        attention.
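        For reference, a minimal unfused sketch of those three steps in plain PyTorch. It is an
        illustration under stated assumptions, not how flash_attn_with_kvcache works internally.
        The assumptions are:
        - one integer cache length shared by the whole batch (cache_seqlens may instead be a
          per-sample tensor);
        - non-interleaved (GPT-NeoX style, "rotate half") rotary with base-10000 frequencies;
        - no ALiBi;
        - softmax scale 1/sqrt(head_dim).
        All function names are hypothetical.

```python
import torch


def rotary_cos_sin(max_seqlen, rotary_dim, base=10000.0):
    # Standard RoPE tables of shape (max_seqlen, rotary_dim // 2)
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    freqs = torch.outer(torch.arange(max_seqlen, dtype=torch.float32), inv_freq)
    return freqs.cos(), freqs.sin()


def apply_rotary(x, cos, sin, offset):
    # x: (batch, seqlen, nheads, head_dim). Rotates the first 2 * cos.shape[-1] channels of each
    # head (first half against second half); token t uses position offset + t.
    half = cos.shape[-1]
    c = cos[offset : offset + x.shape[1], None, :]  # (seqlen, 1, half), broadcasts over heads
    s = sin[offset : offset + x.shape[1], None, :]
    x1, x2 = x[..., :half], x[..., half : 2 * half]
    rotated = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
    return torch.cat([rotated, x[..., 2 * half :]], dim=-1)


def rotary_update_kvcache_attention(q, k_new, v_new, k_cache, v_cache, cache_seqlen, cos, sin):
    # Step 1: rotate q and the new keys at positions cache_seqlen, cache_seqlen + 1, ...
    q = apply_rotary(q, cos, sin, cache_seqlen)
    k_new = apply_rotary(k_new, cos, sin, cache_seqlen)
    # Step 2: write the new keys / values into the cache in place
    end = cache_seqlen + k_new.shape[1]
    k_cache[:, cache_seqlen:end] = k_new
    v_cache[:, cache_seqlen:end] = v_new
    # Step 3: attend over the filled prefix (no causal mask is needed when seqlen_q == 1)
    k, v = k_cache[:, :end], v_cache[:, :end]
    scores = torch.einsum("bthd,bshd->bhts", q, k) * q.shape[-1] ** -0.5
    return torch.einsum("bhts,bshd->bthd", torch.softmax(scores, dim=-1), v)
```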
q: (batch_size, seqlen_q, nheads, head_dim) kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) """ assert inference_params is not None and inference_params.seqlen_offset > 0 assert self.use_flash_attn if self.rotary_emb_dim > 0: assert self.rotary_emb.scale is None, "This code path does not support xPos" self.rotary_emb._update_cos_sin_cache( inference_params.max_seqlen, device=q.device, dtype=q.dtype ) rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached else: rotary_cos, rotary_sin = None, None batch = q.shape[0] kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) context = flash_attn_with_kvcache( q, kv_cache[:, :, 0], kv_cache[:, :, 1], kv[:, :, 0], kv[:, :, 1], rotary_cos=rotary_cos, rotary_sin=rotary_sin, cache_seqlens=cache_seqlens, softmax_scale=self.inner_cross_attn.softmax_scale, causal=self.inner_cross_attn.causal, rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, alibi_slopes=alibi_slopes, ) return context def _update_kvcache_attention(self, q, kv, inference_params): """Write kv to inference_params, then do attention""" if ( inference_params.seqlen_offset == 0 or flash_attn_with_kvcache is None or not self.use_flash_attn ): # TODO: this only uses seqlen_offset and not lengths_per_sample. 
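The non-fused branch of `_update_kvcache_attention` first writes the new keys/values into the preallocated cache and then attends over the filled prefix. The module-level `_update_kv_cache` is not shown in this excerpt, so the following is only a simplified sketch of the assumed semantics, using nested Python lists in place of the `(batch, max_seqlen, 2, nheads_kv, head_dim)` tensor; `update_kv_cache` is a hypothetical name:

```python
from typing import List, Optional


def update_kv_cache(cache: List[List[Optional[str]]], kv: List[List[str]], seqlen_offset: int):
    """Hypothetical stand-in for the module-level _update_kv_cache.

    cache: [max_batch][max_seqlen] slots; kv: [batch][seqlen] new entries.
    Writes the new entries at positions seqlen_offset .. seqlen_offset + seqlen - 1
    and returns the filled prefix that attention should read.
    """
    batch, seqlen = len(kv), len(kv[0])
    end = seqlen_offset + seqlen
    assert batch <= len(cache) and end <= len(cache[0]), "cache is too small"
    for b in range(batch):
        cache[b][seqlen_offset:end] = kv[b]  # same-length slice assignment, list size unchanged
    return [row[:end] for row in cache[:batch]]


max_batch, max_seqlen = 2, 6
cache = [[None] * max_seqlen for _ in range(max_batch)]
# Prefill: seqlen_offset == 0, the whole prompt is written at once.
prefix = update_kv_cache(cache, [["a0", "a1", "a2"], ["b0", "b1", "b2"]], seqlen_offset=0)
# Decode step: one new token per sequence, appended right after the prompt.
prefix = update_kv_cache(cache, [["a3"], ["b3"]], seqlen_offset=3)
print(prefix)  # [['a0', 'a1', 'a2', 'a3'], ['b0', 'b1', 'b2', 'b3']]
```

As the TODO notes, this path uses a single `seqlen_offset` for the whole batch; per-sequence offsets (`lengths_per_sample`) are only honored by the `flash_attn_with_kvcache` branch via `cache_seqlens`.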
kv = self._update_kv_cache(kv, inference_params) return self.inner_cross_attn(q, kv) else: batch = q.shape[0] kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) return flash_attn_with_kvcache( q, kv_cache[:, :, 0], kv_cache[:, :, 1], kv[:, :, 0], kv[:, :, 1], cache_seqlens=cache_seqlens, softmax_scale=self.inner_cross_attn.softmax_scale, causal=self.inner_cross_attn.causal, alibi_slopes=alibi_slopes, ) def forward( self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None, mixer_subset=None, inference_params=None, **kwargs, ): """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total is the sum of the sequence lengths in the batch. x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into x. Only applicable when using FlashAttention. max_seqlen: int. Maximum sequence length in the batch. key_padding_mask: boolean mask, True means to keep, False means to mask out. (batch, seqlen). Only applicable when not using FlashAttention. mixer_subset: for cross-attention only. If not None, will take a subset of x before applying the query projection. Useful for e.g., ViT where we only care about the CLS token in the last layer. inference_params: for generation.
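`cu_seqlens` describes a packed (unpadded) batch: entry `i` is the row where sequence `i` starts in the `(total, hidden_dim)` input, and the last entry equals `total`. A small pure-Python sketch of building it and slicing the packed buffer; the helper names are hypothetical, and in practice `cu_seqlens` would be a `torch.int32` tensor on the GPU:

```python
from itertools import accumulate
from typing import List, Sequence


def make_cu_seqlens(seqlens: Sequence[int]) -> List[int]:
    """Prefix sums with a leading 0, so len(result) == batch_size + 1."""
    return [0] + list(accumulate(seqlens))


def unpack(packed: Sequence, cu_seqlens: Sequence[int]) -> List[list]:
    """Sequence i occupies rows cu_seqlens[i] : cu_seqlens[i + 1] of the packed buffer."""
    return [list(packed[cu_seqlens[i]:cu_seqlens[i + 1]]) for i in range(len(cu_seqlens) - 1)]


seqlens = [3, 1, 4]
cu_seqlens = make_cu_seqlens(seqlens)  # [0, 3, 4, 8]
max_seqlen = max(seqlens)              # 4
packed = list("abcdefgh")              # total = sum(seqlens) = 8 rows, no padding
print(unpack(packed, cu_seqlens))      # [['a', 'b', 'c'], ['d'], ['e', 'f', 'g', 'h']]
```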
Adapted from Megatron-LM (and Apex) https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 """ if cu_seqlens is not None: assert max_seqlen is not None assert key_padding_mask is None assert self.use_flash_attn assert not self.dwconv assert self.rotary_emb_dim == 0 if key_padding_mask is not None: assert cu_seqlens is None assert max_seqlen is None assert not self.use_flash_attn if inference_params is not None: assert key_padding_mask is None assert cu_seqlens is None and max_seqlen is None assert not self.dwconv kwargs = ( {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs} if self.use_flash_attn else {"key_padding_mask": key_padding_mask, **kwargs} ) seqlen_offset = ( 0 if inference_params is None else ( inference_params.lengths_per_sample if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) ) rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None batch, seqlen = x.shape[:2] if not self.cross_attn and self.num_heads_kv == self.num_heads: assert x_kv is None and mixer_subset is None qkv = self.Wqkv(x) if self.dwconv: qkv = rearrange( self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" ).contiguous() qkv = rearrange(qkv, "... (three h d) -> ... 
three h d", three=3, d=self.head_dim) if ( inference_params is None or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): if self.rotary_emb_dim > 0: qkv = self.rotary_emb( qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen ) if inference_params is None: if not self.checkpointing: context = self.inner_attn(qkv, **kwargs) else: context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) else: context = self._update_kvcache_attention( qkv[:, :, 0], qkv[:, :, 1:], inference_params ) else: context = self._apply_rotary_update_kvcache_attention( qkv[:, :, 0], qkv[:, :, 1:], inference_params ) else: if self.cross_attn: q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) kv = self.Wkv(x_kv if x_kv is not None else x) else: assert self.num_heads_kv != self.num_heads qkv = self.Wqkv(x) q = qkv[..., : self.num_heads * self.head_dim] kv = qkv[..., self.num_heads * self.head_dim :] q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) kv = rearrange(kv, "... (two hkv d) -> ... 
two hkv d", two=2, d=self.head_dim) if self.dwconv: q = rearrange( self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d" ).contiguous() kv = rearrange( self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" ).contiguous() if ( inference_params is None or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): if self.rotary_emb_dim > 0: q, kv = self.rotary_emb( q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen ) if inference_params is None: if not self.checkpointing: context = self.inner_cross_attn(q, kv, **kwargs) else: context = torch.utils.checkpoint.checkpoint( self.inner_cross_attn, q, kv, **kwargs ) else: context = self._update_kvcache_attention(q, kv, inference_params) else: context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) out = self.out_proj(rearrange(context, "... h d -> ... (h d)")) return out if not self.return_residual else (out, x) class ParallelMHA(nn.Module): """Multi-head self-attention and cross-attention""" def __init__( self, embed_dim, num_heads, process_group, num_heads_kv=None, qkv_proj_bias=True, out_proj_bias=True, dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None, rotary_emb_interleaved=False, use_alibi=False, window_size=(-1, -1), use_flash_attn=False, checkpointing=False, sequence_parallel=True, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.embed_dim = embed_dim self.causal = causal self.layer_idx = layer_idx self.rotary_emb_dim = rotary_emb_dim self.use_flash_attn = use_flash_attn self.checkpointing = checkpointing self.process_group = process_group self.world_size = process_group.size() self.local_rank = torch.distributed.get_rank(process_group) self.num_heads = num_heads assert self.embed_dim % self.num_heads == 0, "embed_dim must be 
divisible by num_heads" self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads assert ( self.num_heads % self.num_heads_kv == 0 ), "num_heads must be divisible by num_heads_kv" self.num_heads_per_rank = get_dim_for_local_rank( self.num_heads, self.world_size, self.local_rank ) self.num_heads_kv_per_rank = get_dim_for_local_rank( self.num_heads_kv, self.world_size, self.local_rank ) self.head_dim = self.embed_dim // num_heads qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) if use_alibi: assert use_flash_attn, "ALiBi code path requires flash_attn" num_heads_local = math.ceil(self.num_heads / self.world_size) alibi_slopes = torch.tensor( get_alibi_slopes(num_heads)[ self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local ], device=device, ) else: alibi_slopes = None if window_size != (-1, -1): assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" if self.rotary_emb_dim > 0: assert RotaryEmbedding is not None, "rotary_emb is not installed" self.rotary_emb = RotaryEmbedding( self.rotary_emb_dim, base=rotary_emb_base, scale_base=rotary_emb_scale_base, interleaved=rotary_emb_interleaved, device=device, ) if ColumnParallelLinear is None or RowParallelLinear is None: raise ImportError("fused_dense is not installed") self.Wqkv = ColumnParallelLinear( embed_dim, qkv_dim, process_group, bias=qkv_proj_bias, sequence_parallel=sequence_parallel, multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2), **factory_kwargs, ) inner_attn_cls = ( partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) if use_flash_attn else SelfAttention ) inner_cross_attn_cls = ( partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size) if use_flash_attn else CrossAttention ) self.inner_attn = inner_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) self.inner_cross_attn = inner_cross_attn_cls( causal=causal, 
softmax_scale=softmax_scale, attention_dropout=dropout ) self.out_proj = RowParallelLinear( embed_dim, embed_dim, process_group, bias=out_proj_bias, sequence_parallel=sequence_parallel, multiple_of=self.head_dim, **factory_kwargs, ) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): dtype = self.out_proj.weight.dtype if dtype is None else dtype device = self.out_proj.weight.device return torch.empty( batch_size, max_seqlen, 2, self.num_heads_kv_per_rank, self.head_dim, dtype=dtype, device=device, ) def _update_kv_cache(self, kv, inference_params): """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" return _update_kv_cache(kv, inference_params, self.layer_idx) def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): """ Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention. q: (batch_size, seqlen_q, nheads, head_dim) kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) """ assert inference_params is not None and inference_params.seqlen_offset > 0 assert self.use_flash_attn if self.rotary_emb_dim > 0: assert self.rotary_emb.scale is None, "This code path does not support xPos" self.rotary_emb._update_cos_sin_cache( inference_params.max_seqlen, device=q.device, dtype=q.dtype ) rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached else: rotary_cos, rotary_sin = None, None batch = q.shape[0] kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) context = flash_attn_with_kvcache( q, kv_cache[:, :, 0], kv_cache[:, :, 1], kv[:, :, 0], kv[:, :, 1], rotary_cos=rotary_cos, rotary_sin=rotary_sin,
cache_seqlens=cache_seqlens, softmax_scale=self.inner_cross_attn.softmax_scale, causal=self.inner_cross_attn.causal, rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, alibi_slopes=alibi_slopes, ) return context def _update_kvcache_attention(self, q, kv, inference_params): """Write kv to inference_params, then do attention""" if inference_params.seqlen_offset == 0 or not self.use_flash_attn: # TODO: this only uses seqlen_offset and not lengths_per_sample. kv = self._update_kv_cache(kv, inference_params) return self.inner_cross_attn(q, kv) else: batch = q.shape[0] kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) context = flash_attn_with_kvcache( q, kv_cache[:, :, 0], kv_cache[:, :, 1], kv[:, :, 0], kv[:, :, 1], cache_seqlens=cache_seqlens, softmax_scale=self.inner_cross_attn.softmax_scale, causal=self.inner_cross_attn.causal, alibi_slopes=alibi_slopes, ) return context def forward(self, x, seqlen=None, inference_params=None, **kwargs): """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we split x during sequence parallel, we split the batch * seqlen dimension (in case batch is small). """ qkv = self.Wqkv(x) if seqlen is not None: qkv = rearrange(qkv, "(b s) ... 
-> b s ...", s=seqlen) seqlen_offset = ( 0 if inference_params is None else ( inference_params.lengths_per_sample if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) ) rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None if self.num_heads_kv == self.num_heads: qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) if ( inference_params is None or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): if self.rotary_emb_dim > 0: qkv = self.rotary_emb( qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen ) if inference_params is None: if not self.checkpointing: context = self.inner_attn(qkv, **kwargs) else: context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) else: context = self._update_kvcache_attention( qkv[:, :, 0], qkv[:, :, 1:], inference_params ) else: context = self._apply_rotary_update_kvcache_attention( qkv[:, :, 0], qkv[:, :, 1:], inference_params ) else: q = rearrange( qkv[..., : self.num_heads_per_rank * self.head_dim], "... (h d) -> ... h d", d=self.head_dim, ) kv = rearrange( qkv[..., self.num_heads_per_rank * self.head_dim :], "... (two hkv d) -> ... 
two hkv d", two=2, d=self.head_dim, ) if ( inference_params is None or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): if self.rotary_emb_dim > 0: q, kv = self.rotary_emb( q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen ) if inference_params is None: if not self.checkpointing: context = self.inner_cross_attn(q, kv, **kwargs) else: context = torch.utils.checkpoint.checkpoint( self.inner_cross_attn, q, kv, **kwargs ) else: context = self._update_kvcache_attention(q, kv, inference_params) else: context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) context = rearrange(context, "b s h d -> b s (h d)") if seqlen is not None: context = rearrange(context, "b s d -> (b s) d") out = self.out_proj(context) return out ================================================ FILE: flash_attn/modules/mlp.py ================================================ # Copyright (c) 2023, Tri Dao. import torch import torch.nn as nn import torch.nn.functional as F from torch.distributed import ProcessGroup try: from flash_attn.ops.activations import swiglu except ImportError: swiglu = None try: from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear except ImportError: ColumnParallelLinear, RowParallelLinear = None, None try: from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP except ImportError: FusedMLP, ParallelFusedMLP = None, None class Mlp(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, activation=F.gelu, bias1=True, bias2=True, return_residual=False, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() out_features = out_features if out_features is not None else in_features hidden_features = hidden_features if hidden_features is not None else in_features * 4 self.return_residual = return_residual self.fc1 = nn.Linear(in_features, hidden_features, 
bias=bias1, **factory_kwargs) self.activation = activation self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) def forward(self, x): y = self.fc1(x) y = self.activation(y) y = self.fc2(y) return y if not self.return_residual else (y, x) class ParallelMLP(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, activation=F.gelu, process_group: ProcessGroup = None, sequence_parallel=True, bias1=True, bias2=True, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() assert ColumnParallelLinear is not None, "Need to install fused_dense" assert RowParallelLinear is not None, "Need to install fused_dense" out_features = out_features if out_features is not None else in_features hidden_features = hidden_features if hidden_features is not None else in_features * 4 self.fc1 = ColumnParallelLinear( in_features, hidden_features, process_group, bias=bias1, sequence_parallel=sequence_parallel, **factory_kwargs, ) self.activation = activation self.fc2 = RowParallelLinear( hidden_features, out_features, process_group, bias=bias2, sequence_parallel=sequence_parallel, **factory_kwargs, ) def forward(self, x): y = self.fc1(x) y = self.activation(y) y = self.fc2(y) return y class GatedMlp(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid, bias1=True, bias2=True, multiple_of=128, return_residual=False, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() out_features = out_features if out_features is not None else in_features hidden_features = ( hidden_features if hidden_features is not None else int(8 * in_features / 3) ) hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of self.return_residual = return_residual self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs) self.activation = activation self.fc2 = 
nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) def forward(self, x): y = self.fc1(x) if self.activation == F.sigmoid: # Special case for GLU y = F.glu(y, dim=-1) elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU y, gate = y.chunk(2, dim=-1) y = swiglu(gate, y) else: y, gate = y.chunk(2, dim=-1) y = y * self.activation(gate) y = self.fc2(y) return y if not self.return_residual else (y, x) class ParallelGatedMlp(nn.Module): """Parallel GatedMlp""" def __init__( self, in_features, process_group, hidden_features=None, out_features=None, activation=F.sigmoid, bias1=True, bias2=True, multiple_of=128, sequence_parallel=True, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() out_features = out_features if out_features is not None else in_features hidden_features = ( hidden_features if hidden_features is not None else int(8 * in_features / 3) ) hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of if ColumnParallelLinear is None or RowParallelLinear is None: raise ImportError("fused_dense is not installed") self.fc1 = ColumnParallelLinear( in_features, 2 * hidden_features, process_group, bias=bias1, sequence_parallel=sequence_parallel, **factory_kwargs, ) self.activation = activation self.fc2 = RowParallelLinear( hidden_features, out_features, process_group, bias=bias2, sequence_parallel=sequence_parallel, **factory_kwargs, ) def forward(self, x): y = self.fc1(x) if self.activation == F.sigmoid: # Special case for GLU y = F.glu(y, dim=-1) else: y, gate = y.chunk(2, dim=-1) y = y * self.activation(gate) y = self.fc2(y) return y ================================================ FILE: flash_attn/ops/__init__.py ================================================ ================================================ FILE: flash_attn/ops/activations.py ================================================ # Copied from 
https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py import math import torch import torch.nn as nn import torch.nn.functional as F # 1/sqrt(2*pi)-> 0.3989423 # 1/sqrt(2) -> 0.70710678 # sqrt(2/pi) -> 0.79788456 # this function is tanh approximation of gelu # actual gelu is: # x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) @torch.jit.script def bias_gelu(y, bias): x = bias + y return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype) # gradient of tanh approximation of gelu # gradient of actual gelu is: # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) @torch.jit.script def bias_gelu_back(g, y, bias): """Assume that y has shape (B, D) and bias has shape (D)""" x = bias + y tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( 1 + tanh_out ) grad_y = ff * g return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype) class GeLUFunction(torch.autograd.Function): @staticmethod # bias is an optional argument def forward(ctx, input, bias): ctx.save_for_backward(input, bias) return bias_gelu(input, bias) @staticmethod def backward(ctx, grad_output): input, bias = ctx.saved_tensors # bias_gelu_back already returns (grad_input, grad_bias), one gradient per forward input grad_input, grad_bias = bias_gelu_back(grad_output, input, bias) return grad_input, grad_bias bias_gelu_impl = GeLUFunction.apply # this function is tanh approximation of gelu # actual gelu is: # x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) @torch.jit.script def gelu_fwd(x): return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype) # gradient of tanh approximation of gelu # gradient of actual gelu is: # 0.5 * (1.
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) @torch.jit.script def gelu_bwd(g, x): tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( 1 + tanh_out ) return (ff * g).to(dtype=x.dtype) class FastGeLUFunction(torch.autograd.Function): @staticmethod # bias is an optional argument def forward(ctx, input): ctx.save_for_backward(input) return gelu_fwd(input) @staticmethod def backward(ctx, grad_output): (input,) = ctx.saved_tensors tmp = gelu_bwd(grad_output, input) return tmp fast_gelu_impl = FastGeLUFunction.apply @torch.jit.script def relu_bwd(g, x): return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype) @torch.jit.script def sqrelu_fwd(x): r = F.relu(x) return (r * r).to(dtype=x.dtype) @torch.jit.script def sqrelu_bwd(g, x): return (2.0 * g * F.relu(x)).to(dtype=x.dtype) swiglu_fwd_codestring = """ template <typename T> T swiglu_fwd(T x, T y) { return float(x) * float(y) / (1.0f + ::exp(-float(x))); } """ swiglu_bwd_codestring = """ template <typename T> void swiglu_bwd(T x, T y, T g, T& dx, T& dy) { float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); dy = float(x) * x_sigmoid * float(g); } """ swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring) swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2) class SwiGLUFunction(torch.autograd.Function): @staticmethod def forward(ctx, x, y): ctx.save_for_backward(x, y) return swiglu_fwd(x, y) @staticmethod def backward(ctx, dout): x, y = ctx.saved_tensors return swiglu_bwd(x, y, dout) swiglu = SwiGLUFunction.apply ================================================ FILE: flash_attn/ops/fused_dense.py ================================================ # Copyright (c) 2023, Tri Dao.
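The tanh-GeLU backward (`gelu_bwd`, `bias_gelu_back`) and the SwiGLU jiterator kernels in `activations.py` hard-code hand-derived derivatives. One way to sanity-check those formulas is a central finite-difference comparison. The sketch below re-implements the same expressions on plain Python floats instead of torch tensors (an assumption made to keep it dependency-free), and also compares the tanh approximation with the exact erf-based GeLU quoted in the comments:

```python
import math

K = 0.79788456  # sqrt(2/pi), same constant as in activations.py


def gelu_tanh(x: float) -> float:
    return x * 0.5 * (1.0 + math.tanh(K * x * (1 + 0.044715 * x * x)))


def gelu_erf(x: float) -> float:
    """Exact GeLU: x * Phi(x)."""
    return x * 0.5 * (1.0 + math.erf(x * 0.70710678))


def gelu_tanh_grad(x: float) -> float:
    """Same expression as gelu_bwd with g = 1."""
    t = math.tanh(K * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - t * t) * (K + 0.1070322243 * x * x)) + 0.5 * (1 + t)


def swiglu(x: float, y: float) -> float:
    """Same expression as swiglu_fwd: silu(x) * y."""
    return x * y / (1.0 + math.exp(-x))


def swiglu_grads(x: float, y: float, g: float = 1.0):
    """Same expressions as swiglu_bwd; returns (dx, dy)."""
    s = 1.0 / (1.0 + math.exp(-x))
    dx = s * (1 + x * (1.0 - s)) * g * y
    dy = x * s * g
    return dx, dy


def central_diff(f, x: float, h: float = 1e-5) -> float:
    return (f(x + h) - f(x - h)) / (2 * h)


for x in [-3.0, -0.5, 0.0, 0.7, 2.5]:
    dx, dy = swiglu_grads(x, 1.3)
    print(
        x,
        abs(gelu_tanh_grad(x) - central_diff(gelu_tanh, x)),
        abs(dx - central_diff(lambda t: swiglu(t, 1.3), x)),
        abs(dy - central_diff(lambda t: swiglu(x, t), 1.3)),
    )
```

All three printed differences should be tiny (on the order of finite-difference noise), which is the kind of check worth rerunning whenever one of these hand-written backward kernels is edited.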
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py # We make it work with pytorch amp and with bfloat16. # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py from functools import partial from typing import Optional # import fused_dense_cuda # from apex import fused_dense_lib as fused_dense_cuda import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch.distributed import ProcessGroup from flash_attn.utils.torch import custom_fwd, custom_bwd from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd from flash_attn.utils.distributed import ( all_gather_raw, all_reduce, all_reduce_raw, reduce_scatter, reduce_scatter_raw, ) class FusedDenseFunc(torch.autograd.Function): @staticmethod @custom_fwd def forward( ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True ): """ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel with sequence parallelism: we do an all_gather_raw of x before doing the matmul. 
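With `sequence_parallel=True`, each rank holds only a slice of the tokens, so `FusedDenseFunc` all-gathers `x` along the token dimension before the matmul, and a column-parallel weight then yields every token's output for that rank's slice of output features. A pure-Python simulation of that data flow, using lists in place of tensors and a made-up 2-rank setup (all names here are illustrative, not flash_attn functions):

```python
from typing import List

Matrix = List[List[float]]


def linear(x: Matrix, w: Matrix) -> Matrix:
    """y = x @ w.T, with w laid out as (out_features, in_features) like F.linear (no bias)."""
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in w] for row in x]


def all_gather_rows(shards: List[Matrix]) -> Matrix:
    """Simulated all_gather along the token dimension: concatenate every rank's rows."""
    return [row for shard in shards for row in shard]


world_size = 2
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]   # (tokens=4, in=2)
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]  # (out=4, in=2)

# Sequence parallel: each rank holds a contiguous slice of tokens.
x_shards = [x[r * 2:(r + 1) * 2] for r in range(world_size)]
# Column parallel: each rank holds a slice of output features (rows of w).
w_shards = [w[r * 2:(r + 1) * 2] for r in range(world_size)]

total_x = all_gather_rows(x_shards)  # after the gather every rank sees all 4 tokens
local_out = [linear(total_x, w_shards[r]) for r in range(world_size)]

# Concatenating the per-rank outputs along features recovers the dense result.
stitched = [sum((local_out[r][t] for r in range(world_size)), []) for t in range(len(x))]
print(stitched == linear(x, w))  # True
```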
""" ctx.compute_weight_gradient = weight.requires_grad ctx.return_residual = return_residual ctx.process_group = process_group ctx.sequence_parallel = sequence_parallel if torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) x = x.contiguous() if process_group is not None and sequence_parallel: # We want to kick off the all_gather early, before weight dtype conversion total_x, handle_x = all_gather_raw(x, process_group, async_op=True) else: total_x = x if torch.is_autocast_enabled(): weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None weight = weight.contiguous() if process_group is not None and sequence_parallel: handle_x.wait() batch_shape, n = total_x.shape[:-1], total_x.shape[-1] batch_dim = batch_shape.numel() # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 if min(batch_dim, n, *weight.shape) > 65535 * 32: raise RuntimeError("fused_dense only supports matrix dims <= 2M") output = F.linear(total_x, weight, bias) if ctx.compute_weight_gradient: ctx.save_for_backward(x, weight) else: ctx.save_for_backward(weight) return output if not return_residual else (output, x) @staticmethod @custom_bwd def backward(ctx, grad_output, *args): grad_output = grad_output.contiguous() if ctx.return_residual: (grad_input,) = args grad_input = grad_input.contiguous() process_group = ctx.process_group sequence_parallel = ctx.sequence_parallel if ctx.compute_weight_gradient: x, weight = ctx.saved_tensors if process_group is not None and sequence_parallel: total_x, handle_x = all_gather_raw(x, process_group, async_op=True) else: total_x = x else: (weight,) = ctx.saved_tensors total_x = None batch_shape = grad_output.shape[:-1] batch_dim = batch_shape.numel() grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) if ctx.needs_input_grad[0]: if not ctx.return_residual: grad_input = 
F.linear(grad_output, weight.t()) else: grad_input = torch.addmm( grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight ) grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) if process_group is not None: reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) else: grad_input = None if ctx.needs_input_grad[1]: assert ctx.compute_weight_gradient if process_group is not None and sequence_parallel: handle_x.wait() grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] ) else: grad_weight = None grad_bias = grad_output if ctx.needs_input_grad[2] else None if process_group is not None and ctx.needs_input_grad[0]: handle_grad_input.wait() return grad_input, grad_weight, grad_bias, None, None, None def fused_dense_func( x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group: Optional[ProcessGroup] = None, sequence_parallel: bool = True, ): dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( x.dtype == torch.float32 and torch.is_autocast_enabled() ) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: return FusedDenseFunc.apply( x, weight, bias, return_residual, process_group, sequence_parallel ) else: assert process_group is None out = F.linear(x, weight, bias) return out if not return_residual else (out, x) class FusedDense(nn.Linear): def __init__( self, in_features: int, out_features: int, bias: bool = True, return_residual: bool = False, device=None, dtype=None, ) -> None: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) self.return_residual = return_residual def forward(self, x, process_group=None): """ If process_group is not None, we're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before 
doing the matmul. """ return fused_dense_func( x, self.weight, self.bias, return_residual=self.return_residual, process_group=process_group, ) class ColumnParallelLinear(nn.Linear): def __init__( self, in_features: int, out_features: int, process_group: ProcessGroup, bias: bool = True, sequence_parallel=True, multiple_of=1, device=None, dtype=None, ) -> None: world_size = torch.distributed.get_world_size(process_group) if out_features % multiple_of: raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") multiple = out_features // multiple_of # We want to split @multiple across world_size, but it could be an uneven split div = multiple // world_size mod = multiple % world_size # The first @mod ranks get @div + 1 copies, the rest get @div copies local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) super().__init__( in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype ) self.process_group = process_group self.sequence_parallel = sequence_parallel def forward(self, x): # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. # If not, then the input is already gathered. 
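The `div`/`mod` logic in `ColumnParallelLinear.__init__` (repeated in `RowParallelLinear` for `in_features`) splits the features into chunks of `multiple_of`, giving the first `mod` ranks one extra chunk. `ParallelMHA` passes `multiple_of=head_dim * (num_heads // num_heads_kv + 2)` for `Wqkv`, so each rank's width is a whole number of KV groups (the query heads of a group plus one key and one value head). A pure-Python sketch of that arithmetic with hypothetical sizes; `local_out_features` and `split_sizes` are illustrative helpers:

```python
from typing import List


def local_out_features(out_features: int, multiple_of: int, world_size: int, rank: int) -> int:
    """Per-rank output width, following the div/mod rule in ColumnParallelLinear.__init__."""
    if out_features % multiple_of:
        raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
    multiple = out_features // multiple_of
    div, mod = divmod(multiple, world_size)
    # The first `mod` ranks get one extra chunk of `multiple_of` features.
    return (div + int(rank < mod)) * multiple_of


def split_sizes(out_features: int, multiple_of: int, world_size: int) -> List[int]:
    return [local_out_features(out_features, multiple_of, world_size, r) for r in range(world_size)]


# Wqkv in ParallelMHA with hypothetical sizes: head_dim=128, num_heads=32, num_heads_kv=8.
head_dim, num_heads, num_heads_kv = 128, 32, 8
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)      # 6144
multiple_of = head_dim * (num_heads // num_heads_kv + 2)  # 768 = 128 * (4 query heads + k + v)
print(split_sizes(qkv_dim, multiple_of, world_size=3))    # [2304, 2304, 1536]
```

With 8 KV groups on 3 ranks the split is uneven (3, 3, 2 groups), which is exactly the case the "first `mod` ranks get `div + 1`" comment describes.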
return fused_dense_func( x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel, ) class RowParallelLinear(nn.Linear): def __init__( self, in_features: int, out_features: int, process_group: ProcessGroup, bias: bool = True, sequence_parallel=True, multiple_of=1, device=None, dtype=None, ) -> None: world_size = torch.distributed.get_world_size(process_group) rank = torch.distributed.get_rank(process_group) if in_features % multiple_of: raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") multiple = in_features // multiple_of # We want to split @multiple across world_size, but it could be an uneven split div = multiple // world_size mod = multiple % world_size # The first @mod ranks get @div + 1 copies, the rest get @div copies local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) # Only rank 0 will have bias super().__init__( local_multiple * multiple_of, out_features, bias=bias and rank == 0, device=device, dtype=dtype, ) self.process_group = process_group self.sequence_parallel = sequence_parallel def forward(self, x): """ We're doing Tensor Parallel with sequence parallelism: we do the matmul and then a reduce_scatter of the result. """ out = fused_dense_func(x, self.weight, self.bias) reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce return reduce_fn(out, self.process_group) class FusedMLPFunc(torch.autograd.Function): @staticmethod @custom_fwd def forward( ctx, x, weight1, bias1, weight2, bias2, activation="gelu_approx", save_pre_act=True, return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True, ): """ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before doing the matmul. If sequence_parallel=False, then the input is already gathered. 
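`RowParallelLinear` splits the input features across ranks, so each rank produces a partial output that is summed by `all_reduce` (or by `reduce_scatter` when `sequence_parallel=True`, which additionally leaves each rank with only its slice of tokens). That is also why only rank 0 keeps a bias: it must be added exactly once. A pure-Python simulation with a made-up 2-rank split, using lists in place of tensors (all names are illustrative):

```python
from typing import List, Optional

Matrix = List[List[float]]


def linear(x: Matrix, w: Matrix, b: Optional[List[float]] = None) -> Matrix:
    """y = x @ w.T (+ b), with w laid out as (out_features, in_features) like F.linear."""
    y = [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in w] for row in x]
    if b is not None:
        y = [[yi + bi for yi, bi in zip(row, b)] for row in y]
    return y


def all_reduce_sum(parts: List[Matrix]) -> Matrix:
    """Simulated all_reduce: elementwise sum of every rank's partial output."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*parts)]


world_size, shard = 2, 2
x = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]                        # (tokens=2, in=4)
w = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]  # (out=3, in=4)
b = [0.5, -1.0, 2.0]


def partial_outputs(bias_on_every_rank: bool) -> List[Matrix]:
    parts = []
    for r in range(world_size):
        cols = slice(r * shard, (r + 1) * shard)             # this rank's input features
        x_r = [row[cols] for row in x]
        w_r = [w_row[cols] for w_row in w]
        b_r = b if (r == 0 or bias_on_every_rank) else None  # bias lives on rank 0 only
        parts.append(linear(x_r, w_r, b_r))
    return parts


partials = partial_outputs(bias_on_every_rank=False)
print(all_reduce_sum(partials) == linear(x, w, b))      # True
bad_partials = partial_outputs(bias_on_every_rank=True)  # bias would be counted world_size times
print(all_reduce_sum(bad_partials) == linear(x, w, b))  # False
```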
checkpoint_lvl: 0: no recomputation in the bwd 1: recompute gelu_out / relu_out in the bwd 2: recompute pre_act and gelu_out / relu_out in the bwd """ assert -1 <= heuristic <= 4 assert activation in ["gelu_approx", "relu", "sqrelu"] if activation == "sqrelu": assert heuristic == -1 if not save_pre_act: checkpoint_lvl = 2 assert checkpoint_lvl in [0, 1, 2] ctx.return_residual = return_residual ctx.process_group = process_group ctx.sequence_parallel = sequence_parallel ctx.checkpoint_lvl = checkpoint_lvl ctx.activation = activation ctx.heuristic = heuristic if torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) x = x.contiguous() if process_group is not None and sequence_parallel: # We want to kick off the all_gather early, before weight dtype conversion total_x, handle_x = all_gather_raw(x, process_group, async_op=True) else: total_x = x if torch.is_autocast_enabled(): dtype = torch.get_autocast_gpu_dtype() weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]] bias1 = bias1.to(dtype=dtype) if bias1 is not None else None bias2 = bias2.to(dtype=dtype) if bias2 is not None else None weight1 = weight1.contiguous() bias1 = bias1.contiguous() if bias1 is not None else None weight2 = weight2.contiguous() bias2 = bias2.contiguous() if bias2 is not None else None if process_group is not None and sequence_parallel: handle_x.wait() batch_shape, n = total_x.shape[:-1], total_x.shape[-1] batch_dim = batch_shape.numel() # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: raise RuntimeError("fused_dense only supports matrix dims <= 2M") if heuristic == -1: pre_act = F.linear(total_x, weight1, bias1) activation_fn = ( partial(F.gelu, approximate="tanh") if activation == "gelu_approx" else (sqrelu_fwd if activation == "sqrelu" else F.relu) ) with torch.jit.fuser("fuser2"): output1 = activation_fn(pre_act) # 
This is before adding bias1 # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1) # with torch.jit.fuser('fuser2'): # output1 = bias_gelu(pre_act, bias1) else: is_gelu = activation == "gelu_approx" output1, *rest = fused_dense_cuda.linear_act_forward( total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic ) if save_pre_act: pre_act = rest[0] output2 = F.linear(output1, weight2, bias2) if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): # For RELU the pre_act is very small (just a bit-mask) so we just save it ctx.save_for_backward(x, weight1, weight2, pre_act, output1) elif checkpoint_lvl == 1: ctx.save_for_backward(x, weight1, weight2, pre_act) elif checkpoint_lvl == 2: ctx.save_for_backward(x, weight1, weight2, bias1) output2 = output2.reshape(*batch_shape, output2.shape[-1]) return output2 if not return_residual else (output2, x) @staticmethod @custom_bwd def backward(ctx, grad_output, *args): grad_output = grad_output.contiguous() checkpoint_lvl = ctx.checkpoint_lvl activation = ctx.activation activation_fn = ( partial(F.gelu, approximate="tanh") if activation == "gelu_approx" else (sqrelu_fwd if activation == "sqrelu" else F.relu) ) if ctx.return_residual: (grad_input,) = args grad_input = grad_input.contiguous() process_group = ctx.process_group sequence_parallel = ctx.sequence_parallel x, weight1, weight2, *rest = ctx.saved_tensors if process_group is None or not sequence_parallel: total_x = x batch_shape = grad_output.shape[:-1] batch_dim = batch_shape.numel() if checkpoint_lvl in [0, 1]: if process_group is not None and sequence_parallel: total_x, handle_x = all_gather_raw(x, process_group, async_op=True) if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): pre_act, output1 = rest elif checkpoint_lvl == 1: (pre_act,) = rest with torch.jit.fuser("fuser2"): output1 = activation_fn(pre_act) elif checkpoint_lvl == 2: (bias1,) = rest if process_group is not None and sequence_parallel: 
total_x, _ = all_gather_raw(x, process_group) if ctx.heuristic == -1: pre_act = F.linear(total_x, weight1, bias1) with torch.jit.fuser("fuser2"): output1 = activation_fn(pre_act) else: output1, pre_act = fused_dense_cuda.linear_act_forward( total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, activation == "gelu_approx", True, ctx.heuristic, ) grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) output1 = output1.reshape(batch_dim, output1.shape[-1]) pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1]) if ctx.needs_input_grad[3]: grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad( output1, grad_output, ctx.needs_input_grad[4] ) else: grad_weight2 = None grad_bias2 = grad_output if ctx.needs_input_grad[4] else None if ctx.heuristic == -1: # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act) grad_output1 = F.linear(grad_output, weight2.t()) activation_grad_fn = ( gelu_bwd if activation == "gelu_approx" else (sqrelu_bwd if activation == "sqrelu" else relu_bwd) ) with torch.jit.fuser("fuser2"): grad_pre_act = activation_grad_fn(grad_output1, pre_act) else: # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't # just compute gelu/relu grad grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad( weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic ) if not ctx.needs_input_grad[2]: grad_bias1 = None if ctx.needs_input_grad[0]: if not ctx.return_residual: grad_input = F.linear(grad_pre_act, weight1.t()) else: grad_input = torch.addmm( grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1 ) grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) if process_group is not None: reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) else: grad_input = None if ctx.heuristic == -1: if ctx.needs_input_grad[1]: if process_group is not 
None and sequence_parallel and checkpoint_lvl != 2: handle_x.wait() grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act, ctx.needs_input_grad[2], ) else: grad_weight1 = None grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None else: if ctx.needs_input_grad[1]: if process_group is not None and sequence_parallel and checkpoint_lvl != 2: handle_x.wait() grad_weight1 = F.linear( grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t() ) else: grad_weight1 = None if process_group is not None and ctx.needs_input_grad[0]: handle_grad_input.wait() return ( grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None, None, None, None, None, None, ) def fused_mlp_func( x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None, bias2: Optional[Tensor] = None, activation: str = "gelu_approx", save_pre_act: bool = True, return_residual: bool = False, checkpoint_lvl: int = 0, heuristic: int = 0, process_group: Optional[ProcessGroup] = None, sequence_parallel: bool = True, ): assert activation in ["gelu_approx", "relu", "sqrelu"] dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( x.dtype == torch.float32 and torch.is_autocast_enabled() ) # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu) dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0) if ( x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda) and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible ): return FusedMLPFunc.apply( x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group, sequence_parallel, ) else: assert process_group is None pre_act = F.linear(x, weight1, bias1) # Use the same activation as the fused path (sqrelu must not fall back to plain relu) activation_fn = ( partial(F.gelu, approximate="tanh") if activation == "gelu_approx" else (sqrelu_fwd if activation == "sqrelu" else partial(F.relu, inplace=True)) ) output1 =
activation_fn(pre_act) output2 = F.linear(output1, weight2, bias2) return output2 if not return_residual else (output2, x) class FusedMLP(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, bias1=True, bias2=True, activation="gelu_approx", return_residual=False, checkpoint_lvl=0, heuristic="auto", device=None, dtype=None, ): """ If a process_group is passed to forward, we're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before doing the matmul, gelu, then matmul. Finally we do a reduce_scatter of the output. checkpoint_lvl (increasing lvl means slower but more memory saving): 0: no recomputation in the bwd 1: recompute gelu_out in the bwd 2: recompute pre_act and gelu_out in the bwd heuristic: -1: don't fuse gemm + gelu (separate kernel) 0..4: use this heuristic for the algo selection in the fused gemm + gelu 'auto': heuristic will be picked automatically: For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBLASLt implementation is slower than the unfused version. For relu, we set heuristic=0; sqrelu always uses heuristic=-1. return_residual: whether to return the input x along with the output. This is for performance reasons: for post-norm architectures, returning the input allows us to fuse the backward of nn.Linear with the residual connection.
""" assert checkpoint_lvl in [0, 1, 2] assert activation in ["gelu_approx", "relu", "sqrelu"] factory_kwargs = {"device": device, "dtype": dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features * 4 self.activation = activation self.return_residual = return_residual self.checkpoint_lvl = checkpoint_lvl self.heuristic = heuristic if activation != "sqrelu" else -1 self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) def forward(self, x, process_group=None): dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() if self.heuristic == "auto": if self.activation == "gelu_approx": if torch.cuda.get_device_capability("cuda") == (9, 0): heuristic = -1 else: cuda_ver = tuple(map(int, torch.version.cuda.split("."))) heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) else: heuristic = 0 else: heuristic = self.heuristic out = fused_mlp_func( x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, activation=self.activation, save_pre_act=self.training, return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic, process_group=process_group, ) if self.return_residual: out, x = out if process_group is not None: out = reduce_scatter(out, process_group) return out if not self.return_residual else (out, x) class ParallelFusedMLP(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, activation="gelu_approx", process_group: ProcessGroup = None, bias1=True, bias2=True, sequence_parallel=True, checkpoint_lvl=0, heuristic="auto", device=None, dtype=None, ): """ process_group is required. We're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before doing the matmul, gelu, then matmul. Finally we do a reduce_scatter of the output. 
checkpoint_lvl (increasing lvl means slower but more memory saving): 0: no recomputation in the bwd 1: recompute gelu_out in the bwd 2: recompute pre_act and gelu_out in the bwd heuristic: -1: don't fuse gemm + gelu (separate kernel) 0..4: use this heuristic for the algo selection in the fused gemm + gelu 'auto': heuristic will be picked automatically: For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. For relu, we set heuristic=0; sqrelu always uses heuristic=-1. """ assert checkpoint_lvl in [0, 1, 2] assert activation in ["gelu_approx", "relu", "sqrelu"] assert process_group is not None factory_kwargs = {"device": device, "dtype": dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features * 4 self.activation = activation self.process_group = process_group self.sequence_parallel = sequence_parallel self.checkpoint_lvl = checkpoint_lvl self.heuristic = heuristic if activation != "sqrelu" else -1 self.fc1 = ColumnParallelLinear( in_features, hidden_features, process_group, bias=bias1, **factory_kwargs ) self.fc2 = RowParallelLinear( hidden_features, out_features, process_group, bias=bias2, **factory_kwargs ) def forward(self, x): dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() if self.heuristic == "auto": if self.activation == "gelu_approx": cuda_ver = tuple(map(int, torch.version.cuda.split("."))) heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) else: heuristic = 0 else: heuristic = self.heuristic out = fused_mlp_func( x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, activation=self.activation, save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic, process_group=self.process_group, sequence_parallel=self.sequence_parallel, ) reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce return reduce_fn(out, self.process_group)
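The uneven split used by `ColumnParallelLinear` and `RowParallelLinear` above can be checked without GPUs or process groups. The sketch below reproduces only the size arithmetic from their constructors (`multiple_of`, `div`, `mod`); the function names `local_shard_sizes` and `shard_offsets` are hypothetical, and `shard_offsets` additionally assumes each rank owns a contiguous slice in rank order, which the layers above do not themselves enforce (they only size their local weights).

```python
from typing import List


def local_shard_sizes(features: int, world_size: int, multiple_of: int = 1) -> List[int]:
    """Hypothetical helper: per-rank feature counts for an uneven tensor-parallel split.

    Mirrors the constructor arithmetic of ColumnParallelLinear / RowParallelLinear:
    features are grouped into blocks of `multiple_of`, and the first `mod` ranks
    receive one extra block.
    """
    if features % multiple_of:
        raise ValueError(f"features ({features}) must be a multiple of {multiple_of}")
    multiple = features // multiple_of
    div, mod = divmod(multiple, world_size)
    # Ranks 0..mod-1 get div + 1 blocks; the remaining ranks get div blocks.
    return [(div + int(rank < mod)) * multiple_of for rank in range(world_size)]


def shard_offsets(sizes: List[int]) -> List[int]:
    """Hypothetical helper: start offset of each rank's slice, assuming contiguous
    slices laid out in rank order."""
    offsets, start = [], 0
    for size in sizes:
        offsets.append(start)
        start += size
    return offsets


if __name__ == "__main__":
    sizes = local_shard_sizes(1000, world_size=3, multiple_of=8)
    print(sizes)                 # [336, 336, 328]
    print(shard_offsets(sizes))  # [0, 336, 672]
```

With `features=1000`, `world_size=3`, and `multiple_of=8`, there are 125 blocks of 8, so `divmod(125, 3) == (41, 2)`. Ranks 0 and 1 each get 42 blocks (336 features) and rank 2 gets 41 blocks (328 features). The shard sizes always sum back to `features`, and every shard stays a multiple of `multiple_of`.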
================================================ FILE: flash_attn/ops/layer_norm.py ================================================ # Copyright (c) 2022, Tri Dao. # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py import dropout_layer_norm import torch from torch.nn import init def maybe_align(x, alignment_in_bytes=16): """Assume that x already has last dim divisible by alignment_in_bytes""" # TD [2023-07-04] I'm not 100% sure that clone will align the memory # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() def _dropout_add_layer_norm_forward( x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32=False, is_rms_norm=False, ): """Assume that arguments are contiguous and aligned to 16 bytes""" hidden_size = gamma.numel() x0mat = x0.view((-1, hidden_size)) residualmat = residual.view((-1, hidden_size)) if residual is not None else None rowscale = rowscale.view(-1) if rowscale is not None else None zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon, 1.0, 0, None, residual_in_fp32, is_rms_norm, ) # dmask is None if dropout_p == 0.0 # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype == input_dtype (then x == x0) return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma def _dropout_add_layer_norm_backward( dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual, is_rms_norm=False, ): """Assume that arguments are contiguous and aligned to 16 bytes dx == None means that it was a post-norm architecture (x = drop(x0) + residual was not returned in the fwd). x0 must not be None if we have colscale.
""" hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) dzmat = dz.view(xmat.shape) dxmat = dx.view(xmat.shape) if dx is not None else None x0mat = x0.view((-1, hidden_size)) if x0 is not None else None rowscale = rowscale.view(-1) if rowscale is not None else None if colscale is not None: assert x0 is not None, "x0 is required to compute the gradient of colscale" dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None, dropout_p, 1.0, 0, has_residual, is_rms_norm, ) # dresidualmat is None if not has_residual if colscale is None: return dx0mat, dresidualmat, dgamma, dbeta else: dcolscale = rest[0] return dx0mat, dresidualmat, dgamma, dbeta, dcolscale def _dropout_add_layer_norm_subset_forward( x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32=False, is_rms_norm=False, ): """Assume that arguments are contiguous and aligned to 16 bytes""" hidden_size = gamma.numel() x0mat = x0.view((-1, hidden_size)) residualmat = residual.view((-1, hidden_size)) if residual is not None else None x0_subset = x0_subset.view(-1) if x0_subset is not None else None out_subset = out_subset.view(-1) if out_subset is not None else None zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm, ) # dmask is None if dropout_p == 0.0 # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype == input_dtype (then x == x0) return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma def _dropout_add_layer_norm_subset_backward( dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm=False, ): """Assume that arguments are contiguous and
aligned to 16 bytes dx == None means that it was a post-norm architecture (x = drop(x0) + residual was not returned in the fwd). x0 must not be None if we have colscale. """ hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) dzmat = dz.view(-1, hidden_size) dxmat = dx.view(xmat.shape) if dx is not None else None x0mat = x0.view((-1, hidden_size)) if x0 is not None else None x0_subset = x0_subset.view(-1) if x0_subset is not None else None out_subset = out_subset.view(-1) if out_subset is not None else None if colscale is not None: assert x0 is not None, "x0 is required to compute the gradient of colscale" dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm, ) # dresidualmat is None if not has_residual if colscale is None: return dx0mat, dresidualmat, dgamma, dbeta else: dcolscale = rest[0] return dx0mat, dresidualmat, dgamma, dbeta, dcolscale def _dropout_add_layer_norm_parallel_residual_forward( x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, residual_in_fp32=False, is_rms_norm=False, ): """Assume that arguments are contiguous and aligned to 16 bytes""" hidden_size = gamma0.numel() x0mat = x0.view((-1, hidden_size)) x1mat = x1.view((-1, hidden_size)) if x1 is not None else None residualmat = residual.view((-1, hidden_size)) if residual is not None else None ( z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma, ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, None, residual_in_fp32, is_rms_norm, ) # dmask0 and dmask1 are None if dropout_p == 0.0 # xmat is None if dropout_p == 0.0 and x1 is None and residual is None and residual_dtype == input_dtype (then x == x0) return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma def
_dropout_add_layer_norm_parallel_residual_backward( dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm=False, ): """Assume that arguments are contiguous and aligned to 16 bytes dx == None means that it was a post-norm architecture (x = drop(x0) + residual was not returned in the fwd). """ hidden_size = gamma0.numel() xmat = x.view((-1, hidden_size)) dz0mat = dz0.view(xmat.shape) dz1mat = dz1.view(xmat.shape) if dz1 is not None else None dxmat = dx.view(xmat.shape) if dx is not None else None ( dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest, ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm, ) # dresidualmat is None if not has_residual return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 class DropoutAddLayerNormFn(torch.autograd.Function): @staticmethod def forward( ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False, ): x0 = maybe_align(x0.contiguous(), 16) residual = maybe_align(residual.contiguous(), 16) if residual is not None else None gamma = maybe_align(gamma.contiguous(), 16) beta = maybe_align(beta.contiguous(), 16) if beta is not None else None rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32, is_rms_norm, ) # Only need to save x0 if we need to compute gradient wrt colscale x0_saved = x0 if colscale is not None else None ctx.save_for_backward( xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale ) ctx.prenorm = prenorm ctx.dropout_p = dropout_p 
ctx.has_residual = residual is not None ctx.is_rms_norm = is_rms_norm ctx.has_beta = beta is not None if not return_dmask: return ( zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) ) else: dmask = ( dmask.view(x0.shape) if dropout_p > 0.0 else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) ) ctx.mark_non_differentiable(dmask) return ( (zmat.view(x0.shape), dmask) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) ) @staticmethod def backward(ctx, dz, *args): # assert dz.is_contiguous() dz = maybe_align(dz.contiguous(), 16) # this happens! dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors # x0 is None if colscale is None dropout_p = ctx.dropout_p has_residual = ctx.has_residual dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual, ctx.is_rms_norm, ) dx0 = dx0mat.view(x.shape) dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None dcolscale = rest[0] if colscale is not None else None return ( dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, None, None, None, None, None, ) class DropoutAddLayerNormSubsetFn(torch.autograd.Function): @staticmethod def forward( ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False, ): x0 = maybe_align(x0.contiguous(), 16) residual = maybe_align(residual.contiguous(), 16) if residual is not None else None gamma = maybe_align(gamma.contiguous(), 16) beta = maybe_align(beta.contiguous(), 16) if beta is not None else None colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( x0, residual, gamma, 
beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32, is_rms_norm, ) # Only need to save x0 if we need to compute gradient wrt colscale x0_saved = x0 if colscale is not None else None x_shape = (-1, *x0.shape[1:]) ctx.save_for_backward( xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset ) ctx.prenorm = prenorm ctx.dropout_p = dropout_p ctx.rowscale_const = rowscale_const ctx.x0_numrows = x0.shape[:-1].numel() ctx.has_residual = residual is not None ctx.is_rms_norm = is_rms_norm ctx.has_beta = beta is not None z_shape = (-1, *x0.shape[1:]) if not return_dmask: return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) else: z = zmat.view(z_shape) dmask = ( dmask.view(x0.shape) if dropout_p > 0.0 else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) ) ctx.mark_non_differentiable(dmask) return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) @staticmethod def backward(ctx, dz, *args): # assert dz.is_contiguous() dz = maybe_align(dz.contiguous(), 16) # this happens! 
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors # x0 is None if colscale is None dropout_p = ctx.dropout_p has_residual = ctx.has_residual dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p, ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm, ) dx0 = dx0mat.view(-1, *x.shape[1:]) dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None dcolscale = rest[0] if colscale is not None else None return ( dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, None, None, None, None, None, None, None, None, ) class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): @staticmethod def forward( ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False, ): x0 = maybe_align(x0.contiguous(), 16) x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None residual = maybe_align(residual.contiguous(), 16) if residual is not None else None gamma0 = maybe_align(gamma0.contiguous(), 16) beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None ( z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma, ) = _dropout_add_layer_norm_parallel_residual_forward( x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, residual_in_fp32, is_rms_norm, ) ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) ctx.prenorm = prenorm ctx.dropout_p = dropout_p ctx.has_x1 = x1 is not None ctx.has_residual = residual is not None ctx.is_rms_norm = is_rms_norm ctx.has_beta = beta0 is not None z = 
(z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) if not return_dmask: return z if not prenorm else (*z, xmat.view(x0.shape)) else: dmask0 = ( dmask0.view(x0.shape) if dropout_p > 0.0 else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) ) dmask1 = ( dmask1.view(x0.shape) if dropout_p > 0.0 and x1 is not None else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) ) ctx.mark_non_differentiable(dmask0) ctx.mark_non_differentiable(dmask1) return ( (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) ) @staticmethod def backward(ctx, dz0, dz1, *args): dz0 = maybe_align(dz0.contiguous(), 16) # this happens! dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors dropout_p = ctx.dropout_p has_x1 = ctx.has_x1 has_residual = ctx.has_residual ( dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, ) = _dropout_add_layer_norm_parallel_residual_backward( dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, ctx.is_rms_norm, ) dx0 = dx0mat.view(x.shape) dx1 = dx1mat.view(x.shape) if dx1mat is not None else None dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None return ( dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1, dbeta1 if ctx.has_beta else None, None, None, None, None, None, None, ) def layer_norm(x, weight, bias, epsilon): return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) def dropout_add_layer_norm( x0, residual, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None, prenorm=False, residual_in_fp32=False, return_dropout_mask=False, ): """residual_in_fp32 only has an effect if residual is None. Otherwise residual dtype is residual.dtype. 
""" return DropoutAddLayerNormFn.apply( x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm, False, return_dropout_mask, ) def dropout_add_layer_norm_subset( x0, residual, weight, bias, dropout_p, epsilon, layerscale=None, x0_subset=None, out_subset=None, rowscale_const=1.0, out_numrows=0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False, ): """residual_in_fp32 only has an effect if residual is None. Otherwise residual dtype is residual.dtype. """ return DropoutAddLayerNormSubsetFn.apply( x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask, ) def dropout_add_layer_norm_parallel_residual( x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False, ): """residual_in_fp32 only has an effect if residual is None. Otherwise residual dtype is residual.dtype. 
""" return DropoutAddLayerNormParallelResidualFn.apply( x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm, False, return_dropout_mask, ) class DropoutAddLayerNorm(torch.nn.Module): def __init__( self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.prenorm = prenorm self.p = p self.eps = eps self.residual_in_fp32 = residual_in_fp32 self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.reset_parameters() def reset_parameters(self): init.ones_(self.weight) init.zeros_(self.bias) def forward(self, x0, residual=None): return dropout_add_layer_norm( x0, residual, self.weight, self.bias, self.p if self.training else 0.0, self.eps, prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32, ) ================================================ FILE: flash_attn/ops/rms_norm.py ================================================ # Copyright (c) 2022, Tri Dao. # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py import torch from torch.nn import init from flash_attn.ops.layer_norm import ( DropoutAddLayerNormFn, DropoutAddLayerNormParallelResidualFn, DropoutAddLayerNormSubsetFn, ) def rms_norm(x, weight, epsilon): return DropoutAddLayerNormFn.apply( x, None, weight, None, None, None, 0.0, epsilon, False, False, True ) def dropout_add_rms_norm( x0, residual, weight, bias, dropout_p, epsilon, rowscale=None, layerscale=None, prenorm=False, residual_in_fp32=False, return_dropout_mask=False, ): """residual_in_fp32 only has an effect if residual is None. Otherwise residual dtype is residual.dtype. 
""" return DropoutAddLayerNormFn.apply( x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm, True, return_dropout_mask, ) def dropout_add_rms_norm_subset( x0, residual, weight, bias, dropout_p, epsilon, layerscale=None, x0_subset=None, out_subset=None, rowscale_const=1.0, out_numrows=0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False, ): """residual_in_fp32 only has an effect if residual is None. Otherwise residual dtype is residual.dtype. """ return DropoutAddLayerNormSubsetFn.apply( x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask, ) def dropout_add_rms_norm_parallel_residual( x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False, ): """residual_in_fp32 only has an effect if residual is None. Otherwise residual dtype is residual.dtype. 
""" return DropoutAddLayerNormParallelResidualFn.apply( x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm, True, return_dropout_mask, ) class RMSNorm(torch.nn.Module): def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self): init.ones_(self.weight) def forward(self, x): return rms_norm(x, self.weight, self.eps) class DropoutAddRMSNorm(torch.nn.Module): def __init__( self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False, device=None, dtype=None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.prenorm = prenorm self.p = p self.eps = eps self.residual_in_fp32 = residual_in_fp32 self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self): init.ones_(self.weight) def forward(self, x0, residual=None): return dropout_add_rms_norm( x0, residual, self.weight, None, self.p if self.training else 0.0, self.eps, prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32, ) ================================================ FILE: flash_attn/ops/triton/__init__.py ================================================ ================================================ FILE: flash_attn/ops/triton/cross_entropy.py ================================================ # Copyright (c) 2023, Tri Dao. from typing import Tuple, Optional, Union import torch import torch.nn.functional as F import triton import triton.language as tl # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent # version of PyTorch. 
The following 2 lines are for backward compatibility with # older PyTorch. if "all_gather_into_tensor" not in dir(torch.distributed): torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base @triton.heuristics( { "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, } ) @triton.jit def cross_entropy_fwd_kernel( loss_ptr, # data ptrs lse_ptr, z_loss_ptr, logits_ptr, labels_ptr, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes n_cols, # shapes logits_row_stride, # strides BLOCK_SIZE: tl.constexpr, HAS_SMOOTHING: tl.constexpr, # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE SPLIT: tl.constexpr, PRECOMPUTED_LSE: tl.constexpr, # If LSE is already computed (also no smoothing and logit_scale == 1.0) ): row_idx = tl.program_id(0) logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) sum_logits = 0.0 # For smoothing if not PRECOMPUTED_LSE: # Statistics for online softmax m_i = -float("inf") l_i = 0.0 for col_offset in range(0, n_cols, BLOCK_SIZE): cols = col_offset + tl.arange(0, BLOCK_SIZE) logits = tl.load(logits_ptr + cols, mask=cols < n_cols, other=-float("inf")).to( tl.float32 ) * logit_scale if HAS_SMOOTHING: sum_logits += tl.sum(tl.where(cols < n_cols, logits, 0.0)) m_i_new = tl.maximum(m_i, tl.max(logits)) l_i = tl.exp(m_i - m_i_new) * l_i + tl.sum(tl.exp(logits - m_i_new)) m_i = m_i_new lse = tl.log(l_i) + m_i tl.store(lse_ptr + row_idx, lse) else: lse = tl.load(lse_ptr + row_idx) label_idx = tl.load(labels_ptr + row_idx) if label_idx == ignore_index: loss = 0.0 z_loss = 0.0 else: label_idx -= class_start_idx if label_idx >= 0 and label_idx < n_cols: logits_label = tl.load(logits_ptr + label_idx) * logit_scale if HAS_SMOOTHING: loss = ( (lse if not SPLIT else 0.0) - smoothing * sum_logits / total_classes - (1 - smoothing) * logits_label ) else: loss = (lse if not SPLIT 
else 0.0) - logits_label else: # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss if HAS_SMOOTHING: loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) else: loss = 0.0 if not SPLIT: z_loss = lse_square_scale * lse * lse loss += z_loss else: z_loss = 0.0 tl.store(loss_ptr + row_idx, loss) if not SPLIT: tl.store(z_loss_ptr + row_idx, z_loss) @triton.heuristics( { "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0, } ) @triton.jit def cross_entropy_bwd_kernel( dlogits_ptr, # data ptrs dloss_ptr, logits_ptr, lse_ptr, labels_ptr, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes n_cols, # shapes logits_row_stride, # strides dlogits_row_stride, dloss_row_stride, BLOCK_SIZE: tl.constexpr, HAS_SMOOTHING: tl.constexpr, ): row_idx = tl.program_id(0) col_block_idx = tl.program_id(1) logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) label_idx = tl.load(labels_ptr + row_idx) if label_idx != ignore_index: dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) else: dloss = 0.0 logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( tl.float32 ) * logit_scale lse = tl.load(lse_ptr + row_idx) probs = tl.exp(logits - lse) probs += 2.0 * lse_square_scale * lse * probs label_idx -= class_start_idx if HAS_SMOOTHING: smooth_positive = 1.0 - smoothing smooth_negative = smoothing / total_classes probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative else: probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) class CrossEntropyLoss(torch.autograd.Function): @staticmethod def 
forward( ctx, logits, labels, precomputed_lse=None, smoothing=0.0, logit_scale=1.0, lse_square_scale=0.0, ignore_index=-100, inplace_backward=False, process_group=None, ): # For some reason Triton generates wrong code when labels has dtype long and its address # is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index. if labels.dtype == torch.long and labels.data_ptr() % 16 != 0: labels = F.pad(labels, (0, 1))[..., :-1] assert labels.data_ptr() % 16 == 0 assert logit_scale > 0.0 n_rows, n_cols = logits.shape assert labels.shape == (n_rows,) world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) total_classes = world_size * n_cols rank = 0 if process_group is None else torch.distributed.get_rank(process_group) class_start_idx = rank * n_cols use_precomputed_lse = precomputed_lse is not None and logit_scale == 1.0 and smoothing == 0.0 if logits.stride(-1) != 1: logits = logits.contiguous() MAX_BLOCK_SIZE = 16 * 1024 BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) num_warps = ( 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) ) losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) if use_precomputed_lse: assert precomputed_lse.shape == (n_rows,) lse = precomputed_lse.contiguous() else: lse = torch.empty(n_rows, dtype=torch.float, device=logits.device) z_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) # Need this, otherwise Triton tries to launch from cuda:0 and we get # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
with torch.cuda.device(logits.device.index): cross_entropy_fwd_kernel[(n_rows,)]( losses, # data ptrs lse, z_losses, logits, labels, smoothing, logit_scale, lse_square_scale, ignore_index, total_classes, class_start_idx, n_cols, # shapes logits.stride(0), # strides BLOCK_SIZE=BLOCK_SIZE, # constants SPLIT=world_size > 1, PRECOMPUTED_LSE=use_precomputed_lse, num_warps=num_warps, ) if world_size > 1: # If there's no smoothing, if labels are in the vocab of this partition, losses contains # - predicted logit, and 0 otherwise. # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains # -0.9 * predicted logit - 0.1 * sum logit / total_classes. # For labels not in the vocab of this partition, losses contains # -0.1 * sum logit / total_classes. if world_size > 1: lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) handle_losses = torch.distributed.all_reduce( losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True ) lse = torch.logsumexp(lse_allgather, dim=0) handle_losses.wait() # After the allreduce, if there's no smoothing, the total losses are - predicted_logit, # we just have to add the (global) lse. # If there's smoothing=0.1, the total losses are # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. # Again, we just have to add the (global) lse. 
losses += lse if lse_square_scale != 0.0: z_losses = lse_square_scale * lse.square() z_losses.masked_fill_(labels == ignore_index, 0.0) losses += z_losses else: z_losses = torch.zeros_like(losses) losses.masked_fill_(labels == ignore_index, 0.0) ctx.save_for_backward(logits, lse, labels) ctx.mark_non_differentiable(z_losses) ctx.smoothing = smoothing ctx.logit_scale = logit_scale ctx.lse_square_scale = lse_square_scale ctx.ignore_index = ignore_index ctx.total_classes = total_classes ctx.class_start_idx = class_start_idx ctx.inplace_backward = inplace_backward return losses, z_losses @staticmethod def backward(ctx, grad_losses, grad_z_losses): del grad_z_losses # z_losses are only for logging. logits, lse, labels = ctx.saved_tensors dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) n_rows, n_cols = logits.shape BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa # Need this, otherwise Triton tries to launch from cuda:0 and we get # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
with torch.cuda.device(logits.device.index): cross_entropy_bwd_kernel[grid]( dlogits, # data ptrs grad_losses, logits, lse, labels, ctx.smoothing, ctx.logit_scale, ctx.lse_square_scale, ctx.ignore_index, ctx.total_classes, ctx.class_start_idx, n_cols, # shapes logits.stride(0), # strides dlogits.stride(0), grad_losses.stride(0), BLOCK_SIZE=BLOCK_SIZE, # constants num_warps=num_warps, ) return dlogits, None, None, None, None, None, None, None, None, None def cross_entropy_loss( logits: torch.Tensor, labels: torch.Tensor, precomputed_lse: Optional[torch.Tensor] = None, label_smoothing: float = 0.0, logit_scale: float = 1.0, lse_square_scale: float = 0.0, ignore_index=-100, inplace_backward: bool = False, process_group=None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: logits: (batch, vocab_size) labels: (batch,) label_smoothing: float logit_scale: float. Multiply logits by this scale before calculating the loss. lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. This is also referred to as "z-loss". ignore_index: int. If labels == ignore_index, the loss is set to 0.0. inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. This saves memory. process_group: if not None, we're doing Tensor Parallel: each process is responsible for one part of the vocab. The loss will be aggregated across processes. Returns: losses: (batch,), float z_losses: (batch,), float """ return CrossEntropyLoss.apply( logits, labels, precomputed_lse, label_smoothing, logit_scale, lse_square_scale, ignore_index, inplace_backward, process_group, ) ================================================ FILE: flash_attn/ops/triton/k_activations.py ================================================ # Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
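The tensor-parallel path of `CrossEntropyLoss.forward` above rests on two identities. First, the per-rank partial losses (minus the predicted logit, plus the smoothing term) can simply be summed with `all_reduce`. Second, the global LSE is the `logsumexp` of the per-rank LSEs gathered with `all_gather_into_tensor`. The pure-Python sketch below checks on a toy row that these two identities reproduce the unsharded, label-smoothed loss. It makes a few assumptions. The helpers `logsumexp`, `full_ce`, and `sharded_ce` are hypothetical and not part of flash_attn. A plain loop over ranks stands in for the collectives. The vocabulary is assumed to split evenly across ranks.

```python
import math


def logsumexp(xs):
    # Numerically stable log(sum(exp(x))) for a list of floats.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def full_ce(logits, label, smoothing=0.0):
    # Reference loss over the whole vocabulary, using the same formula as the
    # non-split kernel: lse - smoothing * mean(logits) - (1 - smoothing) * logit[label].
    n = len(logits)
    return logsumexp(logits) - smoothing * sum(logits) / n - (1 - smoothing) * logits[label]


def sharded_ce(logits, label, world_size, smoothing=0.0):
    # Hypothetical emulation of the SPLIT path: each "rank" sees one vocab slice.
    assert len(logits) % world_size == 0, "sketch assumes an even vocab split"
    n_cols = len(logits) // world_size
    total_classes = n_cols * world_size
    partial_losses, partial_lses = [], []
    for rank in range(world_size):
        shard = logits[rank * n_cols:(rank + 1) * n_cols]
        local_label = label - rank * n_cols  # label_idx -= class_start_idx
        partial_lses.append(logsumexp(shard))
        # With SPLIT, the LSE is left out of the per-rank loss.
        loss = -smoothing * sum(shard) / total_classes
        if 0 <= local_label < n_cols:
            loss -= (1 - smoothing) * shard[local_label]
        partial_losses.append(loss)
    # all_reduce(SUM) over losses, then add logsumexp of the all-gathered LSEs.
    return sum(partial_losses) + logsumexp(partial_lses)
```

For example, `sharded_ce(logits, 4, world_size=3, smoothing=0.1)` should agree with `full_ce(logits, 4, 0.1)` up to floating-point rounding, because `logsumexp` of the per-shard LSEs equals the LSE of the full row.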
# # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. import math from enum import Enum from typing import Optional import triton import triton.language as tl _sqrt2pi = math.sqrt(2.0 / math.pi) _sqrt1_2 = math.sqrt(1.0 / 2) _gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) class Activation(str, Enum): SquaredReLU = "squared_relu" GeLU = "gelu" GeLUApprox = "gelu_approx" LeakyReLU = "leaky_relu" ReLU = "relu" def get_triton_activation_kernel(activation: Optional[Activation]): return ( { Activation.ReLU: relu, Activation.LeakyReLU: leaky_relu, Activation.GeLU: gelu, Activation.GeLUApprox: gelu_approx, Activation.SquaredReLU: squared_relu, }[activation] if activation else None ) def get_triton_activation_bwd_kernel(activation: Optional[Activation]): return ( { Activation.ReLU: relu_grad, Activation.LeakyReLU: leaky_relu_grad, Activation.GeLU: gelu_grad, Activation.GeLUApprox: gelu_approx_grad, Activation.SquaredReLU: squared_relu_grad, }[activation] if activation else None ) @triton.jit def tanh(x): # Tanh is just a scaled sigmoid return 2 * tl.sigmoid(2 * x) - 1 @triton.jit def cosh(x): exp_x = tl.exp(x) return (exp_x + 1.0 / exp_x) * 0.5 # a Triton implementation of the most used activations # See for instance http://arxiv.org/abs/1606.08415 for an overview # ReLU @triton.jit def relu(x): """ ReLU_ activation function .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html """ zero = 0.0 return tl.where(x >= 0, x, zero.to(x.dtype)) @triton.jit def relu_grad(x): # ReLU is different from other activations # in that it does not require the input to retrospectively compute its gradient # here the input is the downstream gradient, and we return the upstream gradient directly zero = 0.0 one = 1.0 return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) @triton.jit def squared_relu(x): """ Squared ReLU activation, as proposed in the Primer_ paper. .. 
_Primer: https://arxiv.org/abs/2109.08668 """ x_ = relu(x) return (x_ * x_).to(x.dtype) @triton.jit def squared_relu_grad(x): return tl.where(x >= 0, 2.0 * x, 0.0) # Leaky ReLU @triton.jit def leaky_relu(x): """ LeakyReLU_ activation .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html """ scale = 0.01 + 0.0 scale = scale.to(x.dtype) return tl.where(x >= 0, x, scale * x) @triton.jit def leaky_relu_grad(x): min_grad = 0.01 max_grad = 1 min_grad = min_grad.to(x.dtype) max_grad = max_grad.to(x.dtype) return tl.where(x >= 0, max_grad, min_grad) @triton.jit def gelu(x): """Gaussian Error Linear Unit (GELU)""" return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) @triton.jit def gelu_grad(x): cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization return cdf + x * pdf @triton.jit def gelu_approx(x): """ GeLU_ activation - Gaussian error linear unit, with tanh approximation .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf """ return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x))) @triton.jit def gelu_approx_grad(x): # CREDITS: Fast implementation proposed in # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( 1 + tanh_out ) ================================================ FILE: flash_attn/ops/triton/layer_norm.py ================================================ # Copyright (c) 2024, Tri Dao. # Implement dropout + residual + layer_norm / rms_norm. # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. 
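For reference, the pure-Python sketch below covers the plain LayerNorm case of `_layer_norm_bwd_kernel` further down (no RMS norm, dropout, residual, or rowscale). It shows the per-row input gradient `dx = (w*dy - (xhat*c1 + c2)) * rstd` and a two-stage reduction for `dw`/`db`: each "program" accumulates partial sums over `rows_per_program` rows, and the partials are then summed. Treat it as a sketch under stated assumptions. The final summation of partials is inferred from the kernel storing them at `DW + row_block_id * N`. The helpers `layer_norm_fwd_row` and `layer_norm_bwd` are hypothetical and not part of flash_attn.

```python
import math


def layer_norm_fwd_row(x, w, b, eps):
    # Forward for one row: y = (x - mean) * rstd * w + b, with rstd = 1 / sqrt(var + eps).
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    rstd = 1.0 / math.sqrt(var + eps)
    y = [(v - mean) * rstd * wi + bi for v, wi, bi in zip(x, w, b)]
    return y, mean, rstd


def layer_norm_bwd(X, dY, w, b, eps, rows_per_program):
    M, N = len(X), len(w)
    dX, partial_dw, partial_db = [], [], []
    for start in range(0, M, rows_per_program):  # one iteration per "program"
        dw, db = [0.0] * N, [0.0] * N             # per-program accumulators
        for row in range(start, min(start + rows_per_program, M)):
            x, dy = X[row], dY[row]
            _, mean, rstd = layer_norm_fwd_row(x, w, b, eps)
            xhat = [(v - mean) * rstd for v in x]
            wdy = [wi * g for wi, g in zip(w, dy)]
            for j in range(N):
                dw[j] += dy[j] * xhat[j]
                db[j] += dy[j]
            c1 = sum(xh * g for xh, g in zip(xhat, wdy)) / N
            c2 = sum(wdy) / N
            dX.append([(g - (xh * c1 + c2)) * rstd for g, xh in zip(wdy, xhat)])
        partial_dw.append(dw)
        partial_db.append(db)
    # Second stage (assumed): reduce the per-program partial sums over programs.
    dW = [sum(col) for col in zip(*partial_dw)]
    dB = [sum(col) for col in zip(*partial_db)]
    return dX, dW, dB
```

Because `rows_per_program` only changes how the `dw`/`db` partial sums are grouped, the result should not depend on it beyond floating-point rounding. In the Triton kernel these accumulators are kept in registers, which is the register-pressure trade-off the header comment describes.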
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. import math from typing import Optional, List import torch import torch.nn.functional as F from torch import Tensor import triton import triton.language as tl from flash_attn.utils.torch import custom_fwd, custom_bwd from flash_attn.utils.library import triton_op def maybe_contiguous_lastdim(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x def maybe_contiguous(x): return x.contiguous() if x is not None else None def triton_autotune_configs(): # Return configs with a valid warp count for the current device configs = [] # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 max_threads_per_block = 1024 # Default to warp size 32 if not defined by device warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block] # return [triton.Config({}, num_warps=8)] def layer_norm_ref( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, zero_centered_weight=False, dropout_mask=None, dropout_mask1=None, upcast=False, ): dtype = x.dtype if upcast: x = x.float() weight = weight.float() bias = bias.float() if bias is not None else None residual = residual.float() if residual is not None else residual x1 = x1.float() if x1 is not None else None weight1 = weight1.float() if weight1 is not None else None bias1 = bias1.float() if bias1 is not None else None if zero_centered_weight: weight = weight + 1.0 if weight1 is not None: weight1 = weight1 + 1.0 if x1 is not None: assert rowscale is None, "rowscale is not supported with parallel LayerNorm" if rowscale is not None: x = x * rowscale[..., None] if dropout_p 
> 0.0: if dropout_mask is not None: x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) else: x = F.dropout(x, p=dropout_p) if x1 is not None: if dropout_mask1 is not None: x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) else: x1 = F.dropout(x1, p=dropout_p) if x1 is not None: x = x + x1 if residual is not None: x = (x + residual).to(x.dtype) out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( dtype ) if weight1 is None: return out if not prenorm else (out, x) else: out1 = F.layer_norm( x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps ).to(dtype) return (out, out1) if not prenorm else (out, out1, x) def rms_norm_ref( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, zero_centered_weight=False, dropout_mask=None, dropout_mask1=None, upcast=False, ): dtype = x.dtype if upcast: x = x.float() weight = weight.float() bias = bias.float() if bias is not None else None residual = residual.float() if residual is not None else residual x1 = x1.float() if x1 is not None else None weight1 = weight1.float() if weight1 is not None else None bias1 = bias1.float() if bias1 is not None else None if zero_centered_weight: weight = weight + 1.0 if weight1 is not None: weight1 = weight1 + 1.0 if x1 is not None: assert rowscale is None, "rowscale is not supported with parallel LayerNorm" if rowscale is not None: x = x * rowscale[..., None] if dropout_p > 0.0: if dropout_mask is not None: x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) else: x = F.dropout(x, p=dropout_p) if x1 is not None: if dropout_mask1 is not None: x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) else: x1 = F.dropout(x1, p=dropout_p) if x1 is not None: x = x + x1 if residual is not None: x = (x + residual).to(x.dtype) rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) out = ((x * rstd * weight) + bias if bias is not None else 
(x * rstd * weight)).to(dtype) if weight1 is None: return out if not prenorm else (out, x) else: out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( dtype ) return (out, out1) if not prenorm else (out, out1, x) @triton.autotune( configs=triton_autotune_configs(), key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS", "HAS_X1", "HAS_W1", "HAS_B1"], ) # torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) # @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) # @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) # @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) @triton.jit def _layer_norm_fwd_1pass_kernel( X, # pointer to the input Y, # pointer to the output W, # pointer to the weights B, # pointer to the biases RESIDUAL, # pointer to the residual X1, W1, B1, Y1, RESIDUAL_OUT, # pointer to the residual ROWSCALE, SEEDS, # Dropout seeds for each row DROPOUT_MASK, DROPOUT_MASK1, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_res_row, stride_res_out_row, stride_x1_row, stride_y1_row, M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero dropout_p, # Dropout probability zero_centered_weight, # If true, add 1.0 to the weight IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_DROPOUT: tl.constexpr, STORE_DROPOUT_MASK: tl.constexpr, HAS_ROWSCALE: tl.constexpr, HAS_X1: tl.constexpr, HAS_W1: tl.constexpr, HAS_B1: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. 
row = tl.program_id(0) X += row * stride_x_row Y += row * stride_y_row if HAS_RESIDUAL: RESIDUAL += row * stride_res_row if STORE_RESIDUAL_OUT: RESIDUAL_OUT += row * stride_res_out_row if HAS_X1: X1 += row * stride_x1_row if HAS_W1: Y1 += row * stride_y1_row # Compute mean and variance cols = tl.arange(0, BLOCK_N) x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + row).to(tl.float32) x *= rowscale if HAS_DROPOUT: # Compute dropout mask # 7 rounds is good enough, and reduces register pressure keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) if STORE_DROPOUT_MASK: tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) if HAS_X1: x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) x1 *= rowscale if HAS_DROPOUT: # Compute dropout mask # 7 rounds is good enough, and reduces register pressure keep_mask = ( tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p ) x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) if STORE_DROPOUT_MASK: tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) x += x1 if HAS_RESIDUAL: residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) x += residual if STORE_RESIDUAL_OUT: tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) if not IS_RMS_NORM: mean = tl.sum(x, axis=0) / N tl.store(Mean + row, mean) xbar = tl.where(cols < N, x - mean, 0.0) var = tl.sum(xbar * xbar, axis=0) / N else: xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) if zero_centered_weight: w += 1.0 if HAS_BIAS: b = tl.load(B + cols, mask=mask).to(tl.float32) x_hat = (x - mean) * rstd if not 
IS_RMS_NORM else x * rstd y = x_hat * w + b if HAS_BIAS else x_hat * w # Write output tl.store(Y + cols, y, mask=mask) if HAS_W1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) if zero_centered_weight: w1 += 1.0 if HAS_B1: b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 tl.store(Y1 + cols, y1, mask=mask) def _layer_norm_fwd( x: Tensor, weight: Tensor, bias: Tensor, eps: float, residual: Optional[Tensor] = None, x1: Optional[Tensor] = None, weight1: Optional[Tensor] = None, bias1: Optional[Tensor] = None, dropout_p: float = 0.0, rowscale: Optional[Tensor] = None, out_dtype: Optional[torch.dtype] = None, residual_dtype: Optional[torch.dtype] = None, zero_centered_weight: bool = False, is_rms_norm: bool = False, return_dropout_mask: bool = False, out: Optional[Tensor] = None, residual_out: Optional[Tensor] = None ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): # Need to wrap to handle the case where residual_out is an alias of x, which makes torch.library # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None # so that _layer_norm_fwd_impl doesn't have to return them.
if out is None: out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) if residual is not None: residual_dtype = residual.dtype if residual_out is None and ( residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None ): residual_out = torch.empty_like( x, dtype=residual_dtype if residual_dtype is not None else x.dtype ) else: residual_out = None y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( x, weight, bias, eps, out, residual=residual, x1=x1, weight1=weight1, bias1=bias1, dropout_p=dropout_p, rowscale=rowscale, zero_centered_weight=zero_centered_weight, is_rms_norm=is_rms_norm, return_dropout_mask=return_dropout_mask, residual_out=residual_out, ) # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 if residual_out is None: residual_out = x return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 # [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema # since we're returning a tuple of tensors @triton_op("flash_attn::layer_norm_fwd_impl", mutates_args={"out", "residual_out"}, schema="(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? 
residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)") def _layer_norm_fwd_impl( x: Tensor, weight: Tensor, bias: Tensor, eps: float, out: Tensor, residual: Optional[Tensor] = None, x1: Optional[Tensor] = None, weight1: Optional[Tensor] = None, bias1: Optional[Tensor] = None, dropout_p: float = 0.0, rowscale: Optional[Tensor] = None, zero_centered_weight: bool = False, is_rms_norm: bool = False, return_dropout_mask: bool = False, residual_out: Optional[Tensor] = None ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): M, N = x.shape assert x.stride(-1) == 1 if residual is not None: assert residual.stride(-1) == 1 assert residual.shape == (M, N) assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) if x1 is not None: assert x1.shape == x.shape assert rowscale is None assert x1.stride(-1) == 1 if weight1 is not None: assert weight1.shape == (N,) assert weight1.stride(-1) == 1 if bias1 is not None: assert bias1.shape == (N,) assert bias1.stride(-1) == 1 if rowscale is not None: assert rowscale.is_contiguous() assert rowscale.shape == (M,) assert out.shape == x.shape assert out.stride(-1) == 1 if residual_out is not None: assert residual_out.shape == x.shape assert residual_out.stride(-1) == 1 if weight1 is not None: y1 = torch.empty_like(out) assert y1.stride(-1) == 1 else: y1 = None mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None rstd = torch.empty((M,), dtype=torch.float32, device=x.device) if dropout_p > 0.0: seeds = torch.randint( 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 ) else: seeds = None if return_dropout_mask and dropout_p > 0.0: dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) if x1 is not None: dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) else: dropout_mask1 = None else: dropout_mask, dropout_mask1 = 
None, None # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") with torch.cuda.device(x.device.index): torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( x, out, weight, bias, residual, x1, weight1, bias1, y1, residual_out, rowscale, seeds, dropout_mask, dropout_mask1, mean, rstd, x.stride(0), out.stride(0), residual.stride(0) if residual is not None else 0, residual_out.stride(0) if residual_out is not None else 0, x1.stride(0) if x1 is not None else 0, y1.stride(0) if y1 is not None else 0, M, N, eps, dropout_p, # Passing a bool makes torch inductor very unhappy since it then tries to compare to int_max int(zero_centered_weight), is_rms_norm, BLOCK_N, residual is not None, residual_out is not None, bias is not None, dropout_p > 0.0, dropout_mask is not None, rowscale is not None, HAS_X1=x1 is not None, HAS_W1=weight1 is not None, HAS_B1=bias1 is not None, ) return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 @triton.autotune( configs=triton_autotune_configs(), key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], ) # torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) # @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) # @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) # @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) # @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) # @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    W1,
    DY1,
    DX1,
    DW1,
    DB1,
    DRESIDUAL_IN,
    ROWSCALE,
    SEEDS,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    zero_centered_weight,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Do not early exit if row_start >= M, because we need to write DW and DB
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if zero_centered_weight:
        w += 1.0
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if zero_centered_weight:
            w1 += 1.0
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)


def _layer_norm_bwd(
    dy: Tensor,
    x: Tensor,
    weight: Tensor,
    bias: Tensor,
    eps: float,
    mean: Tensor,
    rstd: Tensor,
    dresidual: Optional[Tensor] = None,
    dy1: Optional[Tensor] = None,
    weight1: Optional[Tensor] = None,
    bias1: Optional[Tensor] = None,
    seeds: Optional[Tensor] = None,
    dropout_p: float = 0.0,
    rowscale: Optional[Tensor] = None,
    has_residual: bool = False,
    has_x1: bool = False,
    zero_centered_weight: bool = False,
    is_rms_norm: bool = False,
    x_dtype: Optional[torch.dtype] = None,
    recompute_output: bool = False,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
    # Need to wrap to handle the case where dresidual_in or dx1 are aliases of dx,
    # which makes torch.library unhappy
    dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl(
        dy,
        x,
        weight,
        bias,
        eps,
        mean,
        rstd,
        dresidual,
        dy1,
        weight1,
        bias1,
        seeds,
        dropout_p,
        rowscale,
        has_residual,
        has_x1,
        zero_centered_weight,
        is_rms_norm,
        x_dtype=x_dtype,
        recompute_output=recompute_output,
    )
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return dx, dw, db, dresidual_in, dx1, dw1, db1, y


@triton_op("flash_attn::layer_norm_bwd_impl", mutates_args={},
           schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)",
           allow_decomposition=False,  # Don't let torch.compile trace inside
           )
def _layer_norm_bwd_impl(
    dy: Tensor,
    x: Tensor,
    weight: Tensor,
    bias: Tensor,
    eps: float,
    mean: Tensor,
    rstd: Tensor,
    dresidual: Optional[Tensor] = None,
    dy1: Optional[Tensor] = None,
    weight1: Optional[Tensor] = None,
    bias1: Optional[Tensor] = None,
    seeds: Optional[Tensor] = None,
    dropout_p: float = 0.0,
    rowscale: Optional[Tensor] = None,
    has_residual: bool = False,
    has_x1: bool = False,
    zero_centered_weight: bool = False,
    is_rms_norm: bool = False,
    x_dtype: Optional[torch.dtype] = None,
    recompute_output: bool = False,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
    M, N = x.shape
    assert x.stride(-1) == 1
    dy = maybe_contiguous_lastdim(dy)
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        dresidual = maybe_contiguous_lastdim(dresidual)
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        dy1 = maybe_contiguous_lastdim(dy1)
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
    if recompute_output:
        assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
    # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
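    # Illustrative arithmetic (assuming a GPU with 108 SMs, e.g. an A100): the multiple of 8
    # gives sm_count = 108 * 8 = 864 programs. For M = 4096 rows, rows_per_program =
    # ceil(4096 / 864) = 5, so each program accumulates dw / db for at most 5 rows into its
    # own row of the (864, N) fp32 buffers _dw / _db, which are then reduced with .sum(0).
    # Programs with pid >= 820 start at row_start = pid * 5 >= 4100 > M, run zero loop
    # iterations and store zeros; because the buffers come from torch.empty, this is why
    # the kernel must not exit early.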
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    _db1 = torch.empty_like(_db) if bias1 is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            # Passing a bool makes torch inductor very unhappy since it then tries to compare it to int_max
            int(zero_centered_weight),
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
            HAS_ROWSCALE=rowscale is not None,
            HAS_DY1=dy1 is not None,
            HAS_DX1=dx1 is not None,
            HAS_B1=bias1 is not None,
            RECOMPUTE_OUTPUT=y is not None,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # dresidual_in and dx1 could be None, the wrapper will handle assigning them from dx
    return dx, dw, db, dresidual_in, dx1, dw1, db1, y


class LayerNormFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        zero_centered_weight=False,
        is_rms_norm=False,
        return_dropout_mask=False,
        out_dtype=None,
        out=None,
        residual_out=None
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
        weight = weight.contiguous()
        bias = maybe_contiguous(bias)
        weight1 = maybe_contiguous(weight1)
        bias1 = maybe_contiguous(bias1)
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        if out is not None:
            out = out.reshape(-1, out.shape[-1])
        if residual_out is not None:
            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            out_dtype=out_dtype,
            residual_dtype=residual_dtype,
            zero_centered_weight=zero_centered_weight,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
            out=out,
            residual_out=residual_out,
        )
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.zero_centered_weight = zero_centered_weight
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        if not return_dropout_mask:
            if weight1 is None:
                return y if not prenorm else (y, residual_out)
            else:
                return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
            if weight1 is None:
                return (
                    (y, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, residual_out, dropout_mask, dropout_mask1)
                )
            else:
                return (
                    (y, y1, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, y1, residual_out, dropout_mask, dropout_mask1)
                )

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.zero_centered_weight,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=False,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    zero_centered_weight=False,
    is_rms_norm=False,
    return_dropout_mask=False,
    out_dtype=None,
    out=None,
    residual_out=None
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        zero_centered_weight,
        is_rms_norm,
        return_dropout_mask,
        out_dtype,
        out,
        residual_out
    )


def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    zero_centered_weight=False,
    return_dropout_mask=False,
    out_dtype=None,
    out=None,
    residual_out=None
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        zero_centered_weight,
        True,
        return_dropout_mask,
        out_dtype,
        out,
        residual_out
    )


class RMSNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
                 device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.zero_centered_weight = zero_centered_weight
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        if not self.zero_centered_weight:
            torch.nn.init.ones_(self.weight)
        else:
            torch.nn.init.zeros_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            zero_centered_weight=self.zero_centered_weight,
        )


class LayerNormLinearFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
        norm_weight = norm_weight.contiguous()
        norm_bias = maybe_contiguous(norm_bias)
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        dy = maybe_contiguous_lastdim(dy)
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1]))
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )


def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormLinearFn.apply(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual,
        eps,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
    )


================================================
FILE: flash_attn/ops/triton/linear.py
================================================
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional

import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (
    gelu,
    gelu_approx,
    gelu_approx_grad,
    gelu_grad,
    squared_relu,
    squared_relu_grad,
)

# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications


def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()


def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config(
                            {
                                "BLOCK_M": block_m,
                                "BLOCK_N": block_n,
                                "BLOCK_K": block_k,
                                "SPLIT_K": 1,
                            },
                            num_stages=num_stages,
                            num_warps=num_warps,
                        )
                    )
                    # split_k not used
                    # for split_k in [2, 4, 8, 16]:
                    #     configs.append(triton.Config(
                    #         {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                    #         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
    return configs


@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2
        ),
        # good for int8
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2
        ),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_fwd(
    C,  # Pointers to matrices
    ACT_INPUT,
    A,
    B,
    bias,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    A_ROWMAJOR: tl.constexpr,
    B_COLMAJOR: tl.constexpr,
    BIAS: tl.constexpr,
    SAVE_ACT_INPUT: tl.constexpr,
    ACTIVATION: tl.constexpr,
):

    """
    Kernel for computing Out = activation(A x W + bias)
    - Input has shape (M, K)
    - Weight has shape (K, N)
    - Bias has shape (N,)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N)
    'ActInputs' optionally saves the A x W + bias intermediate for backward computations
    This kernel will consolidate over K
    """

    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # trick to avoid masking on M and N axis
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    if A_ROWMAJOR:
        A = A + (ram[:, None] * stride_am + rk[None, :])
    else:
        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    if B_COLMAJOR:
        B = B + (rk[:, None] + rbn[None, :] * stride_bn)
    else:
        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        if A_ROWMAJOR:
            A += BLOCK_K
        else:
            A += BLOCK_K * stride_ak
        if B_COLMAJOR:
            B += BLOCK_K
        else:
            B += BLOCK_K * stride_bk

    # Putting bias after the matmul (instead of before) is faster, idk why
    if BIAS:
        bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # optional: save the activation inputs
    if SAVE_ACT_INPUT:
        # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        tl.store(act_in_ptrs, acc)

    # optional: fused activation (while the accumulator is still in registers)
    if ACTIVATION == "gelu":
        acc = gelu(acc)
    elif ACTIVATION == "gelu_approx":
        acc = gelu_approx(acc)
    elif ACTIVATION == "squared_relu":
        acc = squared_relu(acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result
    # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # rm / rn are not wrapped here, so the store must be masked to stay inside the (M, N) output
    tl.store(C, acc, mask=mask)


def triton_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: str = "id",
    save_act_input: bool = False,
) -> torch.Tensor:
    """
    Compute e = activation(x @ weight.T + bias).
    This wrapper launches the `kernel_fwd` Triton kernel
    :param x: input tensor
    :param weight: weight matrix
    :param bias: an optional bias tensor
    :param activation: Activation name. Needs to be a Triton kernel.
    :param save_act_input: if True, also return the activation inputs (x @ weight.T + bias) for backward
    :return: result tensor, or a tuple (result, act_input) if save_act_input is True
    """
    # if torch.is_autocast_enabled():
    #     dtype = torch.get_autocast_gpu_dtype()
    #     x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]

    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]

    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()
    x_reshaped = x.reshape(batch_dim, n)

    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
        x_reshaped = x_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    bias = bias.contiguous() if bias is not None else None

    assert (
        x.dtype == weight.dtype
    ), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
    if bias is not None:
        assert (
            x.dtype == bias.dtype
        ), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
    assert (
        x_reshaped.shape[1] == weight.shape[1]
    ), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"

    assert (
        bias is None or bias.shape[0] == weight.shape[0]
    ), "Incompatible dimensions in between weight and bias"

    M, K = x_reshaped.shape
    N, K = weight.shape

    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    act_input = torch.empty_like(output) if save_act_input else None

    # 1D launch kernel where each block gets its own program.
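    # Illustrative arithmetic (sizes assumed for the example, not taken from this file):
    # with M = 1000, N = 512 and an autotuned config of BLOCK_M = 128, BLOCK_N = 64, the
    # grid below launches cdiv(1000, 128) * cdiv(512, 64) = 8 * 8 = 64 programs. With
    # GROUP_M = 8, kernel_fwd computes width = GROUP_M * grid_n = 64, so all 64 programs
    # form a single group and map to pid_m = pid % 8, pid_n = pid // 8: consecutive pids
    # walk down one column of output tiles and share the same BLOCK_N slice of the weight,
    # which is the reuse the "better L2 performance" reordering is meant to exploit.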
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_fwd[grid](
        output,
        act_input,
        x_reshaped,
        weight,  # data ptrs
        bias if bias is not None else x,  # auto skip bias if not present
        M,  # shapes
        N,
        K,
        M // 32,  # key for triton cache (limit number of compilations)
        N // 32,
        K // 32,
        stride_cm=output.stride(0),  # strides
        # stride_cn=output.stride(1),
        stride_am=x_reshaped.stride(0),
        stride_ak=x_reshaped.stride(1),
        stride_bk=weight.stride(1),
        stride_bn=weight.stride(0),
        BIAS=bias is not None,  # optional fused bias
        SAVE_ACT_INPUT=save_act_input,  # optional save activation inputs
        ACTIVATION=activation,  # optional fused activation
        A_ROWMAJOR=x_reshaped.stride(1) == 1,
        B_COLMAJOR=weight.stride(1) == 1,
        GROUP_M=8,  # speed optimization: group the programs
    )

    if not save_act_input:
        return output.reshape(*batch_shape, output.shape[-1])
    else:
        return (
            output.reshape(*batch_shape, output.shape[-1]),
            act_input.reshape(*batch_shape, act_input.shape[-1]),
        )


@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2
        ),
        # good for int8
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4
        ),
        triton.Config(
            {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2
        ),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={
        "early_config_prune": early_config_prune,
        "perf_model": estimate_matmul_time,
        "top_k": 10,
    },
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_bwd(
    C,  # Pointers to matrices
    ACT_INPUT,
    A,
    B,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ACTIVATION: tl.constexpr,
):

    """
    Kernel for computing Out = (A x W) * activation_grad(ActInputs)
    - Input has shape (M, K)
    - Weight has shape (K, N)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N)
    'ActInputs' holds the pre-activation values saved by the forward pass; it is only read when ACTIVATION != "id"
    This kernel will consolidate over K
    """

    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # trick to avoid masking on M and N axis
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # optional: fused activation (while the accumulator is still in registers)
    if ACTIVATION != "id":
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        act_input = tl.load(act_in_ptrs).to(acc.dtype)
    if ACTIVATION == "gelu":
        acc *= gelu_grad(act_input)
    elif ACTIVATION == "gelu_approx":
        acc *= gelu_approx_grad(act_input)
    elif ACTIVATION == "squared_relu":
        acc *= squared_relu_grad(act_input)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)


def triton_dgrad_act(
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    activation: str = "id",
    act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute e = (grad_output @ weight) * activation_grad(act_input).
    This wrapper launches the `kernel_bwd` Triton kernel
    :param grad_output: input tensor
    :param weight: weight matrix
    :param activation: Activation name. Needs to be a Triton kernel.
    :param act_input: the activation inputs saved by the forward pass; required unless activation is "id"
    :return: result tensor
    """
    assert activation in ["id", "gelu", "gelu_approx", "squared_relu"]

    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
    batch_dim = batch_shape.numel()
    grad_output_reshaped = grad_output.reshape(batch_dim, n)

    if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
        grad_output_reshaped = grad_output_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()

    assert (
        grad_output.dtype == weight.dtype
    ), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
    assert (
        grad_output_reshaped.shape[1] == weight.shape[0]
    ), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
    if activation != "id":
        assert act_input is not None, f"act_input is required for activation {activation}"

    # M, N, K in bwd are different from M, N, K in fwd
    M, K = grad_output_reshaped.shape
    K, N = weight.shape

    grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)

    # 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa kernel_bwd[grid]( grad_input, act_input, grad_output_reshaped, weight, # data ptrs M, # shapes N, K, M // 32, # key for triton cache (limit number of compilations) N // 32, K // 32, stride_cm=grad_input.stride(0), # strides # stride_cn=grad_input.stride(1), stride_am=grad_output_reshaped.stride(0), stride_ak=grad_output_reshaped.stride(1), stride_bk=weight.stride(0), stride_bn=weight.stride(1), ACTIVATION=activation, # optional fused activation gradient GROUP_M=8, # speed optimization: group the programs ) return grad_input.reshape(*batch_shape, grad_input.shape[-1]) ================================================ FILE: flash_attn/ops/triton/mlp.py ================================================ # The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared # to the naive implementation. import fused_dense_lib as fused_dense_cuda import torch import torch.nn as nn import torch.nn.functional as F from flash_attn.utils.torch import custom_fwd, custom_bwd from flash_attn.ops.activations import sqrelu_bwd, sqrelu_fwd from flash_attn.ops.triton.linear import triton_dgrad_act, triton_linear_act class FusedDenseSqreluDenseFunc(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0): """checkpoint_lvl: 0: no recomputation in the bwd 1: recompute output1 (the squared-ReLU output) in the bwd 2: recompute act_input and output1 in the bwd """ if torch.is_autocast_enabled(): dtype = torch.get_autocast_gpu_dtype() x, weight1, bias1, weight2, bias2 = [ a.to(dtype=dtype) for a in [x, weight1, bias1, weight2, bias2] ] is_bf16 = x.dtype == torch.bfloat16 assert checkpoint_lvl in [0, 1, 2] x = x.contiguous() weight1 = weight1.contiguous() bias1 = bias1.contiguous() weight2 = weight2.contiguous() bias2 = bias2.contiguous() batch_shape, n = x.shape[:-1], x.shape[-1] batch_dim = batch_shape.numel() if is_bf16: act_input =
fused_dense_cuda.linear_bias_forward( x.reshape(batch_dim, n), weight1, bias1 ) output1 = sqrelu_fwd(act_input) else: save_act_input = checkpoint_lvl != 2 result = triton_linear_act( x.reshape(batch_dim, n), weight1, bias1, activation="squared_relu", save_act_input=save_act_input, ) if save_act_input: output1, act_input = result else: output1 = result output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2) ctx.checkpoint_lvl = checkpoint_lvl if checkpoint_lvl == 0: ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1) elif checkpoint_lvl == 1: ctx.save_for_backward(x, weight1, bias1, weight2, act_input) elif checkpoint_lvl == 2: ctx.save_for_backward(x, weight1, bias1, weight2) return output2.reshape(*batch_shape, output2.shape[-1]) @staticmethod @custom_bwd def backward(ctx, grad_output): grad_output = grad_output.contiguous() checkpoint_lvl = ctx.checkpoint_lvl x, weight1, bias1, weight2, *rest = ctx.saved_tensors batch_shape, n = x.shape[:-1], x.shape[-1] batch_dim = batch_shape.numel() is_bf16 = x.dtype == torch.bfloat16 if checkpoint_lvl == 0: act_input, output1 = rest elif checkpoint_lvl == 1: (act_input,) = rest output1 = sqrelu_fwd(act_input) elif checkpoint_lvl == 2: if is_bf16: act_input = fused_dense_cuda.linear_bias_forward( x.reshape(batch_dim, n), weight1, bias1 ) output1 = sqrelu_fwd(act_input) else: output1, act_input = triton_linear_act( x.reshape(batch_dim, n), weight1, bias1, activation="squared_relu", save_act_input=True, ) if is_bf16: grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) grad_output1 = grad_output @ weight2 grad_act_input = sqrelu_bwd(grad_output1, act_input) grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( x.reshape(batch_dim, n), weight1, grad_act_input ) else: grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) grad_weight2, grad_bias2 = 
fused_dense_cuda.linear_bias_wgrad(output1, grad_output) grad_act_input = triton_dgrad_act( grad_output, weight2, activation="squared_relu", act_input=act_input ) grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( x.reshape(batch_dim, n), weight1, grad_act_input ) return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply class FusedDenseSqreluDense(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, bias1=True, bias2=True, checkpoint_lvl=0, device=None, dtype=None, ): """ checkpoint_lvl (increasing lvl means slower but more memory saving): 0: no recomputation in the bwd 1: recompute the squared-ReLU output in the bwd 2: recompute the squared-ReLU input and output in the bwd """ assert checkpoint_lvl in [0, 1, 2] factory_kwargs = {"device": device, "dtype": dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features * 4 assert bias1 == True, "DenseSqreluDense module without bias is currently not supported" assert bias2 == True, "DenseSqreluDense module without bias is currently not supported" self.checkpoint_lvl = checkpoint_lvl self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) def forward(self, x): assert x.is_cuda return fused_dense_sqrelu_dense_function( x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias, self.checkpoint_lvl ) ================================================ FILE: flash_attn/ops/triton/rotary.py ================================================ # Copyright (c) 2025, Tri Dao.
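# Illustrative sketch (hypothetical helper, not part of the original module and not
# called anywhere): a pure-Python reference for the per-position math that
# `rotary_kernel` below appears to implement for a single head vector. Assumptions:
# `cos` and `sin` hold the rotary_dim / 2 values for one position, and dims past
# rotary_dim pass through unchanged (as `apply_rotary` does when not in-place).
# Non-interleaved mode pairs dim i with dim i + rotary_dim / 2; interleaved mode
# pairs dims 2i and 2i + 1; conjugate=True negates sin, i.e. rotates by the negated
# angle, which undoes the forward rotation.
from typing import List


def _rotary_reference_sketch(
    x: List[float],
    cos: List[float],
    sin: List[float],
    interleaved: bool = False,
    conjugate: bool = False,
) -> List[float]:
    half = len(cos)
    assert len(sin) == half and 2 * half <= len(x), "rotary_dim must be <= headdim"
    out = list(x)  # copy, so dims past rotary_dim are kept as-is
    for i in range(half):
        c = cos[i]
        s = -sin[i] if conjugate else sin[i]
        # Pick the pair of dims that rotate together.
        j0, j1 = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        x0, x1 = x[j0], x[j1]
        out[j0] = x0 * c - x1 * s
        out[j1] = x0 * s + x1 * c
    return out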
# As of 2025-04-23, we require triton >= 3.0 from typing import Optional, Union import torch import triton import triton.language as tl @triton.jit def rotary_kernel( OUT, # Pointers to matrices X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, # this could be int or a pointer # Matrix dimensions seqlen, nheads, seqlen_ro, # strides stride_out_batch, stride_out_seqlen, stride_out_nheads, stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads, stride_x_headdim, # Meta-parameters # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128 ROTARY_DIM: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_H: tl.constexpr, BLOCK_M: tl.constexpr, ): BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) ROTARY_DIM_HALF = ROTARY_DIM // 2 pid_head = tl.program_id(axis=0) pid_m = tl.program_id(axis=1) pid_batch = tl.program_id(axis=2) if not IS_VARLEN: X = X + pid_batch * stride_x_batch OUT = OUT + pid_batch * stride_out_batch else: start_idx = tl.load(CU_SEQLENS + pid_batch) seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx X = X + start_idx * stride_x_seqlen OUT = OUT + start_idx * stride_out_seqlen if pid_m * BLOCK_M >= seqlen: return rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) if not IS_SEQLEN_OFFSETS_TENSOR: rm_cs = rm + SEQLEN_OFFSETS else: rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) rk_half = tl.arange(0, BLOCK_K // 2) COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) if CONJUGATE: sin = -sin if not INTERLEAVED: # Load the 1st and 
2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim) OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim) mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF) x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0,).to(tl.float32) o0 = x0 * cos - x1 * sin o1 = x0 * sin + x1 * cos tl.store(OUT, o0, mask=mask) tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) else: rk = tl.arange(0, BLOCK_K) X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim) OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim) mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM) x = tl.load(X, mask=mask, other=0.0).to(tl.float32) x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) o0 = x0 * cos - x1 * sin o1 = x0 * sin + x1 * cos o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) tl.store(OUT, o, mask=mask) def apply_rotary( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor] = 0, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, interleaved=False, inplace=False, conjugate=False, ) -> torch.Tensor: """ Arguments: x: (batch, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim). 
cos: (seqlen_ro, rotary_dim / 2) sin: (seqlen_ro, rotary_dim / 2) seqlen_offsets: integer or integer tensor of size (batch,) cu_seqlens: (batch + 1,) or None max_seqlen: int Returns: y: (batch, seqlen, nheads, headdim) """ is_varlen = cu_seqlens is not None if not is_varlen: batch, seqlen, nheads, headdim = x.shape else: assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" total_seqlen, nheads, headdim = x.shape batch_p_1 = cu_seqlens.shape[0] batch = batch_p_1 - 1 seqlen = max_seqlen seqlen_ro, rotary_dim = cos.shape assert sin.shape == cos.shape rotary_dim *= 2 assert rotary_dim <= headdim, "rotary_dim must be <= headdim" assert headdim <= 256, "Only support headdim <= 256" assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" cos, sin = cos.contiguous(), sin.contiguous() if isinstance(seqlen_offsets, torch.Tensor): assert seqlen_offsets.shape == (batch,) assert seqlen_offsets.dtype in [torch.int32, torch.int64] seqlen_offsets = seqlen_offsets.contiguous() else: assert seqlen_offsets + seqlen <= seqlen_ro output = torch.empty_like(x) if not inplace else x if rotary_dim < headdim and not inplace: output[..., rotary_dim:].copy_(x[..., rotary_dim:]) grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch) # noqa BLOCK_M = 8 if rotary_dim <= 128 else 4 # Need this, otherwise Triton tries to launch from cuda:0 and we get # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) 
with torch.cuda.device(x.device.index): torch.library.wrap_triton(rotary_kernel)[grid]( output, # data ptrs x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, # shapes nheads, seqlen_ro, output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 output.stride(-3), # seqlen_stride or total_seqlen_stride output.stride(-2), # nheads_stride output.stride(-1), # headdim_stride x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 x.stride(-3), # seqlen stride or total_seqlen_stride x.stride(-2), # nheads stride x.stride(-1), # headdim stride rotary_dim, isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, BLOCK_M=BLOCK_M, BLOCK_H=2, ) return output ================================================ FILE: flash_attn/pyproject.toml ================================================ [tool.black] line-length = 100 target-version = 'py39' [tool.ruff] line-length = 100 target-version = 'py39' ================================================ FILE: flash_attn/utils/__init__.py ================================================ ================================================ FILE: flash_attn/utils/benchmark.py ================================================ # Copyright (c) 2023, Tri Dao. """ Useful functions for writing test code. 
""" import torch import torch.utils.benchmark as benchmark def benchmark_forward( fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs ): """Use Pytorch Benchmark on the forward pass of an arbitrary function.""" if verbose: print(desc, "- Forward pass") def amp_wrapper(*inputs, **kwinputs): with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): fn(*inputs, **kwinputs) t = benchmark.Timer( stmt="fn_amp(*inputs, **kwinputs)", globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, num_threads=torch.get_num_threads(), ) m = t.timeit(repeats) if verbose: print(m) return t, m def benchmark_backward( fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs, ): """Use Pytorch Benchmark on the backward pass of an arbitrary function.""" if verbose: print(desc, "- Backward pass") with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): y = fn(*inputs, **kwinputs) if type(y) is tuple: y = y[0] if grad is None: grad = torch.randn_like(y) else: if grad.shape != y.shape: raise RuntimeError("Grad shape does not match output shape") def f(*inputs, y, grad): # Set .grad to None to avoid extra operation of gradient accumulation for x in inputs: if isinstance(x, torch.Tensor): x.grad = None y.backward(grad, retain_graph=True) t = benchmark.Timer( stmt="f(*inputs, y=y, grad=grad)", globals={"f": f, "inputs": inputs, "y": y, "grad": grad}, num_threads=torch.get_num_threads(), ) m = t.timeit(repeats) if verbose: print(m) return t, m def benchmark_combined( fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs, ): """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" if verbose: print(desc, "- Forward + Backward pass") with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): y = fn(*inputs, **kwinputs) if type(y) is tuple: y = y[0] if grad is None: grad = 
torch.randn_like(y) else: if grad.shape != y.shape: raise RuntimeError("Grad shape does not match output shape") def f(grad, *inputs, **kwinputs): for x in inputs: if isinstance(x, torch.Tensor): x.grad = None with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): y = fn(*inputs, **kwinputs) if type(y) is tuple: y = y[0] y.backward(grad, retain_graph=True) t = benchmark.Timer( stmt="f(grad, *inputs, **kwinputs)", globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs}, num_threads=torch.get_num_threads(), ) m = t.timeit(repeats) if verbose: print(m) return t, m def benchmark_fwd_bwd( fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs, ): """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" return ( benchmark_forward( fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs, ), benchmark_backward( fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs, ), ) def benchmark_all( fn, *inputs, grad=None, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs, ): """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.""" return ( benchmark_forward( fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs, ), benchmark_backward( fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs, ), benchmark_combined( fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype, **kwinputs, ), ) def pytorch_profiler( fn, *inputs, trace_filename=None, backward=False, amp=False, amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs, ): """Wrap benchmark functions in Pytorch profiler to see CUDA information.""" if backward: with torch.autocast(device_type="cuda", dtype=amp_dtype, 
enabled=amp): out = fn(*inputs, **kwinputs) if type(out) is tuple: out = out[0] g = torch.randn_like(out) for _ in range(30): # Warm up if backward: for x in inputs: if isinstance(x, torch.Tensor): x.grad = None with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): out = fn(*inputs, **kwinputs) if type(out) is tuple: out = out[0] # Backward should be done outside autocast if backward: out.backward(g, retain_graph=True) activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [ torch.profiler.ProfilerActivity.CUDA ] with torch.profiler.profile( activities=activities, record_shapes=True, # profile_memory=True, with_stack=True, ) as prof: if backward: for x in inputs: if isinstance(x, torch.Tensor): x.grad = None with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): out = fn(*inputs, **kwinputs) if type(out) is tuple: out = out[0] if backward: out.backward(g, retain_graph=True) if verbose: # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) print(prof.key_averages().table(row_limit=50)) if trace_filename is not None: prof.export_chrome_trace(trace_filename) def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() fn(*inputs, **kwinputs) torch.cuda.synchronize() mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000) if verbose: print(f"{desc} max memory: {mem}GB") torch.cuda.empty_cache() return mem ================================================ FILE: flash_attn/utils/distributed.py ================================================ from typing import Optional import torch from torch import Tensor from torch.distributed import ProcessGroup # `all_gather_into_tensor` and `reduce_scatter_tensor` are the newer replacements for # `_all_gather_base` and `_reduce_scatter_base`. They require a recent # version of PyTorch. The following 4 lines are for backward compatibility with # older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed): torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base if "reduce_scatter_tensor" not in dir(torch.distributed): torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base # Raw operation, does not support autograd, but does support async def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) output = torch.empty( world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device ) handle = torch.distributed.all_gather_into_tensor( output, input_.contiguous(), group=process_group, async_op=async_op ) return output, handle # Raw operation, does not support autograd, but does support async def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): world_size = torch.distributed.get_world_size(process_group) assert input_.shape[0] % world_size == 0 output = torch.empty( input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device ) handle = torch.distributed.reduce_scatter_tensor( output, input_.contiguous(), group=process_group, async_op=async_op ) return output, handle # Raw operation, does not support autograd, but does support async def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): input_ = input_.contiguous() handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) return input_, handle class AllGatherFunc(torch.autograd.Function): """Gather the input from sequence parallel region and concatenate.""" @staticmethod def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: ctx.process_group = process_group output, _ = all_gather_raw(input_, process_group) return output @staticmethod def backward(ctx, grad_output: Tensor): grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group) return 
grad_input, None # Supports autograd, but does not support async all_gather = AllGatherFunc.apply class ReduceScatterFunc(torch.autograd.Function): """Reduce (sum) the input across the process group and scatter the result along the first dimension.""" @staticmethod def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: ctx.process_group = process_group output, _ = reduce_scatter_raw(input_, process_group) return output @staticmethod def backward(ctx, grad_output: Tensor): grad_input, _ = all_gather_raw(grad_output, ctx.process_group) return grad_input, None # Supports autograd, but does not support async reduce_scatter = ReduceScatterFunc.apply class AllReduceFunc(torch.autograd.Function): """All-reduce (sum) the input across the process group.""" @staticmethod def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: ctx.process_group = process_group output, _ = all_reduce_raw(input_, process_group) return output @staticmethod def backward(ctx, grad_output: Tensor): return grad_output, None # Supports autograd, but does not support async all_reduce = AllReduceFunc.apply def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup): # We want to iterate over parameters with _shared_params=True in the same order, # as different ranks might have a different number of parameters (e.g., only rank 0 has bias).
params_shared = { name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False) } for _, p in sorted(params_shared.items()): with torch.no_grad(): # Broadcast needs src to be global rank, not group rank torch.distributed.broadcast( p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group ) # Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256 def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup): # We want to iterate over parameters with _sequence_parallel=True in the same order, # as different ranks might have a different number of parameters (e.g., only rank 0 has bias). params_seqparallel = { name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False) } grads = [p.grad for _, p in sorted(params_seqparallel.items())] if grads: with torch.no_grad(): coalesced = torch._utils._flatten_dense_tensors(grads) torch.distributed.all_reduce(coalesced, group=process_group) for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int: """Get the dim for the local rank derived from splitting dim on world_size processes. The split may not be even across the world_size processes. """ multiple = dim // multiple_of div = multiple // world_size mod = multiple % world_size local_multiple = div + int(local_rank < mod) return local_multiple * multiple_of ================================================ FILE: flash_attn/utils/generation.py ================================================ # Copyright (c) 2023, Tri Dao.
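# Illustrative sketch (hypothetical helper, not part of the original module and not
# called anywhere): a pure-Python version of the rule used by
# `modify_logits_for_top_p_filtering` below, for a single row of logits. Tokens are
# visited in ascending probability order, and a token is removed while the running
# cumulative probability is still <= 1 - top_p. This drops the lowest-probability
# tokens as long as their combined mass stays <= 1 - top_p, so the kept set covers at
# least top_p of the mass and the highest-probability token is normally kept.
import math
from typing import List


def _top_p_removal_mask_sketch(logits: List[float], top_p: float) -> List[bool]:
    """Return True for each token whose logit would be set to -inf."""
    if top_p <= 0.0 or top_p >= 1.0:
        return [False] * len(logits)  # no filtering, mirroring the early return below
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]  # numerically stable softmax
    total = sum(exps)
    removed = [False] * len(logits)
    cumulative = 0.0
    for i in sorted(range(len(logits)), key=lambda k: logits[k]):  # ascending order
        cumulative += exps[i] / total
        removed[i] = cumulative <= 1.0 - top_p
    return removed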
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31 import gc import time from collections import namedtuple from dataclasses import dataclass, field from functools import partial from typing import Callable, Optional, Sequence, Union import torch import torch.nn.functional as F from einops import rearrange, repeat from torch import Tensor from torch.profiler import ProfilerActivity, profile, record_function try: from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput except ImportError: GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores"]) SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"]) @dataclass class InferenceParams: """Inference parameters that are passed to the main model in order to efficiently calculate and store the context during inference.""" max_seqlen: int max_batch_size: int seqlen_offset: int = 0 batch_size_offset: int = 0 key_value_memory_dict: dict = field(default_factory=dict) lengths_per_sample: Optional[Tensor] = None def reset(self, max_seqlen, max_batch_size): self.max_seqlen = max_seqlen self.max_batch_size = max_batch_size self.seqlen_offset = 0 if self.lengths_per_sample is not None: self.lengths_per_sample.zero_() # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 def modify_logits_for_top_k_filtering(logits, top_k): """Set the logits for non-top-k values to -inf.
Done in-place.""" indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] logits.masked_fill_(indices_to_remove, float("-Inf")) # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 def modify_logits_for_top_p_filtering(logits, top_p): """Set the logits for non-top-p values to -inf. Done in-place.""" if top_p <= 0.0 or top_p >= 1.0: return # First sort and calculate cumulative sum of probabilities. sorted_logits, sorted_indices = torch.sort(logits, descending=False) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Remove the lowest-probability tokens whose cumulative probability is <= 1 - top_p (the highest-probability token is kept) sorted_indices_to_remove = cumulative_probs <= (1 - top_p) # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter( 1, sorted_indices, sorted_indices_to_remove ) logits.masked_fill_(indices_to_remove, float("-inf")) def sample(logits, top_k=1, top_p=0.0, temperature=1.0): """Sample from top-k logits. Arguments: logits: Tensor of shape (batch_size, vocab_size) """ if top_k == 1: # Short-circuit for greedy decoding return logits.argmax(dim=-1) else: if top_p > 0.0: assert top_p <= 1.0, "top-p should be in (0, 1]."
if top_k > 0: top_k = min(top_k, logits.size(-1)) # Safety check logits_top, indices = torch.topk(logits, top_k, dim=-1) if temperature != 1.0: logits_top /= temperature modify_logits_for_top_p_filtering(logits_top, top_p) return indices[ torch.arange(indices.shape[0], device=indices.device), torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), ] else: # Clone so that when we modify for top_p we don't change the original logits logits_top = logits / temperature if temperature != 1.0 else logits.clone() modify_logits_for_top_p_filtering(logits_top, top_p) return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( dim=-1 ) @torch.inference_mode() def decode( input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1, cg=False, enable_timing=False, ): """Decoding, either greedy or with top-k or top-p sampling. If top-k = 0, don't limit the number of candidates (pure sampling). Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, then top-p. We assume that all sequences in the same batch have the same length. Arguments: input_ids: (batch, seq_len) max_length: int teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the logits, the next token is taken from the teacher_outputs. Useful for testing. 
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: sequences: (batch, max_length) scores: tuples of (batch, vocab_size) """ batch_size, seqlen_og = input_ids.shape teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 if cg: if not hasattr(model, "_decoding_cache"): model._decoding_cache = None model._decoding_cache = update_graph_cache( model, model._decoding_cache, batch_size, seqlen_og, max_length, tensor_parallel=tensor_parallel, ) inference_params = model._decoding_cache.inference_params inference_params.reset(max_length, batch_size) else: inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) def get_logits(input_ids, inference_params): decoding = inference_params.seqlen_offset > 0 if decoding: position_ids = torch.full( (batch_size, 1), inference_params.seqlen_offset, dtype=torch.long, device=input_ids.device, ) else: position_ids = None if not cg or not decoding: logits = model( input_ids, position_ids=position_ids, inference_params=inference_params, num_last_tokens=1, ).logits.squeeze(dim=1) else: logits = model._decoding_cache.run( input_ids, position_ids, inference_params.seqlen_offset ).squeeze(dim=1) return logits[..., :vocab_size] if vocab_size is not None else logits def sample_tokens(logits, inference_params): if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) else: token = teacher_outputs[:, inference_params.seqlen_offset] # return rearrange(token, "b -> b 1") return token.unsqueeze(1) def should_stop(current_token, inference_params): if inference_params.seqlen_offset == 0: return False if eos_token_id is not None and (current_token == eos_token_id).all(): return True if inference_params.seqlen_offset >= max_length - 1: return True return False start = torch.cuda.Event(enable_timing=enable_timing) end = 
torch.cuda.Event(enable_timing=enable_timing) if enable_timing: if tensor_parallel > 1: torch.distributed.barrier() start.record() scores, sequences = [], [input_ids] while not should_stop(sequences[-1], inference_params): scores.append(get_logits(sequences[-1], inference_params)) inference_params.seqlen_offset += sequences[-1].shape[1] sequences.append(sample_tokens(scores[-1], inference_params)) if enable_timing: end.record() if tensor_parallel > 1: torch.distributed.barrier() torch.cuda.synchronize() print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0): """Algorithm 1 from [1] [1] Fast Inference from Transformers via Speculative Decoding Yaniv Leviathan, Matan Kalman, Yossi Matias https://arxiv.org/abs/2211.17192 Arguments: logits: Tensor of shape (batch_size, seqlen + 1, vocab_size) logits_draft: Tensor of shape (batch_size, seqlen, vocab_size) tokens_draft: Tensor of shape (batch_size, seqlen) Return: tokens: Tensor of shape (batch_size, seqlen + 1) num_generated_tokens: Tensor of shape (batch_size), with value in [1, seqlen + 1]. For each sequence in the batch, the number of valid tokens that were sampled by speculative sampling. """ batch, seqlen_p_1, vocab_size = logits.shape seqlen = seqlen_p_1 - 1 assert logits_draft.shape == (batch, seqlen, vocab_size) assert tokens_draft.shape == (batch, seqlen) assert tokens_draft.dtype in [torch.int64, torch.int32] # TODO: if top_k = 1 we can simplify things and only work with indices if top_p > 0.0: assert top_p <= 1.0, "top-p should be in (0, 1]." 
# Clone so that when we modify for top_p we don't change the original logits logits = logits / temperature if temperature != 1.0 else logits.clone() logits_draft = logits_draft / temperature if temperature != 1.0 else logits_draft.clone() if top_k > 0: top_k = min(top_k, logits.size(-1)) # Safety check modify_logits_for_top_k_filtering(logits, top_k) modify_logits_for_top_k_filtering(logits_draft, top_k) modify_logits_for_top_p_filtering(logits, top_p) modify_logits_for_top_p_filtering(logits_draft, top_p) probs = torch.softmax(logits, dim=-1) probs_draft = torch.softmax(logits_draft, dim=-1) gather = lambda probs, tokens: rearrange( probs.gather(dim=-1, index=rearrange(tokens, "... -> ... 1")), "... 1 -> ..." ) # (batch, seqlen) accepted = torch.rand(batch, seqlen, device=probs.device) * gather( probs_draft, tokens_draft ) <= gather(probs[:, :-1], tokens_draft) accepted_all = accepted.all(dim=-1) # (batch,) first_rejected_idx = torch.where(accepted_all, seqlen, accepted.int().argmin(dim=-1)) probs_diff = torch.clamp(probs[:, :-1] - probs_draft, min=0.0) # torch.multinomial can deal with unnormalized probabilities # probs_diff /= probs_diff.sum(dim=-1, keepdim=True) resample_probs = torch.cat([probs_diff, probs[:, -1:]], dim=1) resample_probs = rearrange( resample_probs.gather(dim=1, index=repeat(first_rejected_idx, "b -> b 1 d", d=vocab_size)), "b 1 d -> b d", ) resample = torch.multinomial(resample_probs, num_samples=1).squeeze(dim=-1) # (batch,) tokens = F.pad(tokens_draft, (0, 1)) tokens[:, first_rejected_idx] = resample return tokens, first_rejected_idx + 1 @torch.inference_mode() def decode_speculative( input_ids, model, model_draft, max_length, speculative_lookahead=3, top_k=1, top_p=0.0, temperature=1.0, eos_token_id=None, vocab_size=None, tensor_parallel=1, cg=False, enable_timing=False, debug=False, ): """ TD: WIP, for my own understanding, lightly tested. Only supports batch_size == 1 for now.
    Speculative decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: (batch, max_length - seqlen_og, vocab_size), concatenated into a single tensor
    """
    batch_size, seqlen_og = input_ids.shape
    assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
    assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
    if cg:
        if not hasattr(model_draft, "_decoding_cache"):
            model_draft._decoding_cache = None
        model_draft._decoding_cache = update_graph_cache(
            model_draft,
            model_draft._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            # draft model needs to process either 1 or 2 tokens at a time
            decoding_seqlens=(1, 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params_draft = model_draft._decoding_cache.inference_params
        inference_params_draft.reset(max_length, batch_size)
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
            decoding_seqlens=range(1, speculative_lookahead + 2),
            tensor_parallel=tensor_parallel,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params_draft = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params, model, num_last_tokens=1, cg=False):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            seqlen = input_ids.shape[1]
            # if inference_params.lengths_per_sample is None:
            # TODO: in the case of batched decoding where each sequence has a different length,
            # we need to compute the position_ids for each sequence using lengths_per_sample
            if True:
                cache_seqlens = torch.full(
                    (input_ids.shape[0],),
                    inference_params.seqlen_offset,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            else:
                cache_seqlens = inference_params.lengths_per_sample
            position_ids = cache_seqlens[:, None] + torch.arange(
                seqlen, dtype=torch.long, device=input_ids.device
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=num_last_tokens,
            ).logits
        else:
            # NOTE: careful, CUDA graph is set to have num_last_tokens=input_ids.shape[1].
            # This might not be compatible with the num_last_tokens used here.
            assert num_last_tokens <= input_ids.shape[1]
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            )[:, -num_last_tokens:]
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(input_ids, get_logits_fn, inference_params, sample_fn, num_tokens=1):
        """Sample `num_tokens` tokens from the model, starting from `input_ids`.
        Also return the logits used to sample each token.
        Arguments:
            input_ids: (batch, seqlen)
        Return:
            tokens: (batch, num_tokens)
            scores: (batch, num_tokens, vocab_size), which contains the logits at the last
                position of `input_ids` and the logits of the next (num_tokens - 1) tokens.
                The logits following the last sampled token aren't computed.
        """
        assert num_tokens >= 1
        sequences, scores = [input_ids], []
        for i in range(num_tokens):
            scores.append(get_logits_fn(sequences[-1], inference_params)[:, -1])
            inference_params.seqlen_offset += sequences[-1].shape[1]
            sequences.append(sample_fn(scores[-1]).unsqueeze(1))
        return torch.cat(sequences[1:], dim=1), torch.stack(scores, dim=1)

    sampling_kwargs = dict(top_k=top_k, top_p=top_p, temperature=temperature)
    sample_fn = partial(sample, **sampling_kwargs)
    get_logits_main = partial(get_logits, model=model, cg=cg)
    get_logits_draft = partial(get_logits, model=model_draft, cg=cg)
    sample_tokens_main = partial(
        sample_tokens,
        get_logits_fn=get_logits_main,
        sample_fn=sample_fn,
        inference_params=inference_params,
    )
    sample_tokens_draft = partial(
        sample_tokens,
        get_logits_fn=get_logits_draft,
        sample_fn=sample_fn,
        inference_params=inference_params_draft,
    )

    if debug:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        start = time.time()

    sequences, scores = [input_ids], []
    num_main_model_calls = 0
    num_draft_tokens = 0
    num_accepted_tokens_history = []
    if seqlen_og >= max_length - 1:
        # Don't do speculative sampling, just sample 1 token from the model
        tokens, scores_new = sample_tokens_main(input_ids, num_tokens=1)
        sequences.append(tokens)
        scores.append(scores_new)
    else:
        # Sample @n_spec_tokens tokens from the draft model; @model then uses them to produce
        # between 1 and 1 + @n_spec_tokens tokens.
        # We want seqlen_og + 1 + @n_spec_tokens to be <= @max_length.
        n_spec_tokens = min(speculative_lookahead, max_length - seqlen_og - 1)
        tokens_draft, scores_draft = sample_tokens_draft(input_ids, num_tokens=n_spec_tokens)
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())

        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([input_ids, tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([input_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        # TODO: we're using the fact that batch_size == 1
        # TODO: check eos_token_id
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # Note that @model has not evaluated the last sampled token yet, so we'll need to pass
        # that in the next time we call @model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset = seqlen_og + num_generated - 1
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([input_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    while True:
        # seqlen_offset is total length generated - 1
        if inference_params.seqlen_offset >= max_length - 1:
            break
        if inference_params.seqlen_offset >= max_length - 2:
            # Don't do speculative sampling, just sample 1 token from the model
            tokens, scores_new = sample_tokens_main(sequences[-1][:, -1:], num_tokens=1)
            sequences.append(tokens)
            scores.append(scores_new)
            break
        # Sample from draft model
        n_spec_tokens = min(
            speculative_lookahead, max_length - inference_params_draft.seqlen_offset - 2
        )
        # If the main model accepts all the draft tokens, plus it samples one new token,
        # then at the next iteration the draft model needs to evaluate the logits of the last
        # draft token and the logits of the newly sampled token. So here we pass in the last 2
        # tokens of sequences[-1].
        # The exception is when the main model rejects all the draft tokens, in which case we
        # will only have 1 token to pass in.
        tokens_draft, scores_draft = sample_tokens_draft(
            sequences[-1][:, -2:], num_tokens=n_spec_tokens
        )
        num_draft_tokens += n_spec_tokens
        if debug:
            scores_draft_ref = model_draft(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((scores_draft - scores_draft_ref[:, :-1]).abs().max())
            # breakpoint()
        # Evaluate the draft tokens with the model
        logits = get_logits_main(
            torch.cat([sequences[-1][:, -1:], tokens_draft], dim=1),
            inference_params,
            num_last_tokens=n_spec_tokens + 1,
        )  # (batch, n_spec_tokens + 1, vocab_size)
        num_main_model_calls += 1
        if debug:
            logits_ref = model(
                torch.cat([cur_ids, tokens_draft], dim=1), num_last_tokens=n_spec_tokens + 1
            ).logits
            print((logits - logits_ref).abs().max())
            # breakpoint()
        tokens, num_generated_tokens = sample_speculative(
            logits, scores_draft, tokens_draft, **sampling_kwargs
        )
        num_accepted_tokens_history.append(num_generated_tokens - 1)
        if debug:
            print(tokens)
            print(num_generated_tokens)
            # breakpoint()
        sequences.append(tokens[:1, : num_generated_tokens[0]])
        scores.append(logits[:1, : num_generated_tokens[0]])
        # We've evaluated 1 token from sequences[-1][:, -1:] above, plus
        # num_generated_tokens[0].item() - 1 tokens from the draft model.
        num_generated = num_generated_tokens[0].item()
        inference_params.seqlen_offset += num_generated
        inference_params_draft.seqlen_offset = (
            inference_params.seqlen_offset - 1
            if num_generated > 1
            else inference_params.seqlen_offset
        )
        if debug:
            cur_ids = torch.cat([cur_ids, sequences[-1]], dim=1)
            scores_ref = model(cur_ids, num_last_tokens=num_generated_tokens[0].item() + 1).logits
            print((scores[-1] - scores_ref[:, :-1]).abs().max())
            # breakpoint()

    if enable_timing:
        if tensor_parallel > 1:
            torch.distributed.barrier()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        print(f"Number of calls to main model: {num_main_model_calls}")
        print(
            f"Acceptance rate: {torch.cat(num_accepted_tokens_history).sum().item() / num_draft_tokens * 100:.2f}%"
        )
    sequences = torch.cat(sequences, dim=1)
    scores = torch.cat(scores, dim=1)
    if debug:
        scores_ref = model(sequences).logits
        print((scores - scores_ref[:, seqlen_og - 1 : -1]).abs().max())
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=sequences, scores=scores)


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences


def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers: Union[int, Sequence],
    device,
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    tensor_parallel=1,
    dtype=None,
    n_warmups=2,
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):
        # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        else:
            headdim = getattr(
                model.config,
                "head_dim",
                model.config.hidden_size // model.config.num_attention_heads,
            )
            inf_cache = allocate_inference_cache(
                batch_size,
                max_seqlen,
                model.config.num_attention_heads // tensor_parallel,
                headdim,
                model.config.num_hidden_layers,
                device,
                dtype,
            )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache


def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launches and non-captured launches not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits.clone()

    inference_params.seqlen_offset = seqlen_offset_og
    return run


================================================
FILE: flash_attn/utils/library.py
================================================
# Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/_library/triton.py
# The PyTorch implementation ignores the schema argument; we modify it to use schema.
from typing import Optional, Callable, Iterable, Union

from torch.library import custom_op, CustomOpDef
from torch._library.triton import set_wrap_triton_enabled


def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
    # If allow_decomposition=True, this matches torch.library.triton_op behavior. If set to False,
    # then it behaves like torch.library.custom_op instead, which doesn't decompose the operator
    # and so inductor can't trace inside.
    allow_decomposition=True,
) -> Callable:
    def dec(fn: Callable[..., object]) -> CustomOpDef:
        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
            # Optimization: we're passing regular Tensors into the triton kernel, so
            # no need to go through HOP dispatch
            with set_wrap_triton_enabled(False):
                return fn(*args, **kwargs)

        result = custom_op(
            name,
            backend_fn,
            mutates_args=mutates_args,
            # This is the only difference from the PyTorch implementation
            schema=schema,
        )
        from torch._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx traceable,
        # so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        if allow_decomposition:
            # We decompose the operator when FunctionalTensorMode is active.
            # The goal is to decompose the operator in AOTDispatcher.
            # - With torch.compile, this means that the backend (usually Inductor)
            #   can see a call to the triton kernel(s) and so it can directly optimize
            #   them by inlining them into the lowering process.
            def functional_decomp(  # type: ignore[no-untyped-def]
                mode, op, types, args, kwargs
            ):
                from torch.export._trace import custom_triton_ops_decomposition_disabled

                if custom_triton_ops_decomposition_disabled():
                    return mode.__torch_dispatch__(op, types, args, kwargs)
                else:
                    with mode:
                        return fn(*args, **kwargs)

            result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)

        return result

    if fn is None:
        return dec
    else:
        return dec(fn)


================================================
FILE: flash_attn/utils/pretrained.py
================================================
import os
from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    resolved_archive_file = None

    weights_path = os.path.join(model_name, WEIGHTS_NAME)
    weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
    safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
    safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)

    if os.path.isfile(weights_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
    elif os.path.isfile(weights_index_path):
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
    elif os.path.isfile(safe_weights_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        load_safe = True
    elif os.path.isfile(safe_weights_index_path):
        resolved_archive_file = cached_file(
            model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
        )
        is_sharded = True
        load_safe = True
    else:
        # Try loading from HF hub instead of from local files
        resolved_archive_file = cached_file(
            model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
        )
        if resolved_archive_file is None:
            resolved_archive_file = cached_file(
                model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
            )
            if resolved_archive_file is not None:
                is_sharded = True

    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")

    if load_safe:
        loader = partial(safe_load_file, device=mapped_device)
    else:
        loader = partial(torch.load, map_location=mapped_device)

    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different
        # checkpoint shards in this case.
        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            model_name, resolved_archive_file
        )
        state_dict = {}
        for sharded_file in resolved_archive_file:
            state_dict.update(loader(sharded_file))
    else:
        state_dict = loader(resolved_archive_file)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict


================================================
FILE: flash_attn/utils/testing.py
================================================
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
import math
from typing import Optional

import torch
from einops import rearrange, repeat

from flash_attn.bert_padding import pad_input, unpad_input


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)

    if zero_lengths:
        # Generate zero-lengths every 5 batches and the last batch.
        for i in range(batch_size):
            if i % 5 == 0:
                lengths[i] = 0
        lengths[-1] = 0
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask


def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    qv=None,
    kvpacked=False,
    qkvpacked=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d_v)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    d_v = v.shape[-1]
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q, query_padding_mask, query_unused_mask
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            qv_unpad.detach() if qv is not None else None,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            qv.detach() if qv is not None else None,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(None, None),
    sink_token_length=0,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] is None:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),
        )


def construct_chunk_mask(
    seqlen_q,
    seqlen_k,
    attention_chunk,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
    # Subtract remainder instead of divide and then multiply to take care of negative values
    col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk
    return torch.logical_or(
        col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk
    )


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(None, None),
    attention_chunk=0,
    sink_token_length=0,
    learnable_sink: Optional[torch.Tensor] = None,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads, head_dim)
        v: (batch_size, seqlen_k, nheads, head_dim_v)
        qv: (batch_size, seqlen_q, nheads, head_dim_v)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim_v)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax probabilities before dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
        qv = qv.float() if qv is not None else None
    if q_descale is not None:
        q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2])
        q = (q.float() * q_descale).to(q.dtype)
        qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None
    if k_descale is not None:
        k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype)
    if v_descale is not None:
        v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype)
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    dv = v.shape[-1]
    softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv)
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
    if qv is not None:
        scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v)
    if softcap > 0:
        scores = torch.tanh(scores / softcap) * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    local_mask = None
    if window_size[0] is not None or window_size[1] is not None:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            sink_token_length,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
    if attention_chunk > 0:
        chunk_mask = construct_chunk_mask(
            seqlen_q,
            seqlen_k,
            attention_chunk,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
        local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
    if local_mask is not None:
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    if learnable_sink is None:
        attention = torch.softmax(scores, dim=-1).to(v.dtype)
    else:
        scores_fp32 = scores.to(torch.float32)
        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
        learnable_sink = rearrange(learnable_sink, "h -> h 1 1")
        logits_or_sinks_max = torch.maximum(learnable_sink, logits_max)
        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(learnable_sink - logits_or_sinks_max)
        attention = (unnormalized_scores / normalizer).to(v.dtype)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    # Without this we might get NaN in dv
    if key_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if local_mask is not None:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    if intermediate_dtype is not None:
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
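The `learnable_sink` branch of `attention_ref` adds one extra per-head logit to the softmax denominator without a matching value vector, so each attention row sums to less than 1. Below is a minimal pure-Python sketch of that normalization for a single query row. The helper name `softmax_with_sink` is hypothetical and not part of `flash_attn`. It only mirrors the arithmetic of the reference and does not reproduce its tensor shapes or masking.

```python
import math


def softmax_with_sink(scores, sink):
    """Softmax over `scores` with one extra 'sink' logit in the denominator.

    Hypothetical helper mirroring the learnable_sink branch of attention_ref for a
    single query row: the sink absorbs probability mass but has no value vector.
    """
    # Subtract the max over both the scores and the sink for numerical stability.
    m = max(max(scores), sink)
    unnormalized = [math.exp(s - m) for s in scores]
    normalizer = sum(unnormalized) + math.exp(sink - m)
    return [u / normalizer for u in unnormalized]


row = softmax_with_sink([1.0, 2.0, 3.0], sink=0.0)
print(sum(row))  # less than 1: the remainder is the sink's share
# A very negative sink recovers the ordinary softmax.
print(softmax_with_sink([0.0, 0.0], sink=-1e9))
```

Taking the max over the sink as well as the scores corresponds to `logits_or_sinks_max` in the reference and keeps every `exp` argument non-positive.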
================================================
FILE: flash_attn/utils/torch.py
================================================
import torch
from typing import Callable


def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)
    return decorator


if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
    deprecated = True
    from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
else:
    deprecated = False
    from torch.cuda.amp import custom_fwd, custom_bwd

custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
custom_bwd = custom_amp_decorator(custom_bwd, deprecated)


================================================
FILE: hopper/__init__.py
================================================
__version__ = "3.0.0"


================================================
FILE: hopper/benchmark_attn.py
================================================
from collections import namedtuple
from functools import partial
import math
import os
from typing import NamedTuple
import torch
import torch.nn as nn
import torch.nn.functional as F

import time

try:
    import cudnn
except ImportError:
    cudnn = None
# cudnn = None

Timing = NamedTuple('timing', [('mean', float)])


from einops import rearrange, repeat

# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
# from flash_attn_interface import flash_attn_with_kvcache as flash_attn_func_v3
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3

from triton.testing import do_bench
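The `custom_amp_decorator` shim in `flash_attn/utils/torch.py` exists, as far as we understand it, because of a difference between two PyTorch APIs:

- Newer releases expose `torch.amp.custom_fwd`/`custom_bwd`, which take a keyword-only `device_type` argument.
- The older `torch.cuda.amp` versions do not accept that argument.

The sketch below reproduces the wrapper and applies it to two hypothetical stand-in factories, `new_style_fwd` and `old_style_fwd`, so the keyword injection can be checked without PyTorch. These stand-ins (and the `amp_kwargs` attribute they set) only record what they receive; they are not PyTorch code. The sketch covers both the bare `@custom_fwd` form and the `@custom_fwd(cast_inputs=...)` form.

```python
from typing import Callable


def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    # Same shape as flash_attn/utils/torch.py: forward every call to `dec`,
    # adding device_type="cuda" only when the newer torch.amp API is in use.
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)
    return decorator


# Hypothetical stand-ins for the two decorator factories (not PyTorch code).
def new_style_fwd(fwd=None, *, device_type, cast_inputs=None):
    def wrap(fn):
        fn.amp_kwargs = {"device_type": device_type, "cast_inputs": cast_inputs}
        return fn
    return wrap if fwd is None else wrap(fwd)


def old_style_fwd(fwd=None, *, cast_inputs=None):
    def wrap(fn):
        fn.amp_kwargs = {"cast_inputs": cast_inputs}
        return fn
    return wrap if fwd is None else wrap(fwd)


custom_fwd_new = custom_amp_decorator(new_style_fwd, cuda_amp_deprecated=True)
custom_fwd_old = custom_amp_decorator(old_style_fwd, cuda_amp_deprecated=False)


@custom_fwd_new  # bare form: device_type is injected for the new-style API
def forward_a(x):
    return x


@custom_fwd_old(cast_inputs="float32")  # called form: no device_type for the old API
def forward_b(x):
    return x


print(forward_a.amp_kwargs)  # {'device_type': 'cuda', 'cast_inputs': None}
print(forward_b.amp_kwargs)  # {'cast_inputs': 'float32'}
```

Setting the `deprecated` flag from `hasattr(torch.amp, "custom_fwd")` matters. Injecting `device_type` into the old-style factory would raise a `TypeError` for an unexpected keyword argument.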
try:
    from triton_fused_attention import attention as triton_attention
except ImportError:
    triton_attention = None
triton_attention = None

DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"


def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
    # # Warmup
    # for _ in range(5):
    #     func(*args, **kwargs)
    # time.sleep(1)
    # return benchmark_forward(func, *args, **kwargs, repeats=repeats, verbose=verbose, desc=desc)[1]
    # s = torch.cuda.Stream()
    # s.wait_stream(torch.cuda.current_stream())
    # with torch.cuda.stream(s):
    #     for _ in range(2):
    #         out = func(*args, **kwargs)
    # torch.cuda.current_stream().wait_stream(s)
    # graph = torch.cuda.CUDAGraph()
    # with torch.cuda.graph(graph):
    #     out = func(*args, **kwargs)
    # time_f = benchmark_forward(lambda: graph.replay(), repeats=repeats, verbose=verbose, desc=desc)
    # # return time_f[1].mean
    # return time_f[1]
    return Timing(do_bench(lambda: func(*args, **kwargs), warmup=3, rep=repeats) * 1e-3)


def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)):
    if causal:
        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
    else:
        if window_size == (-1, -1):
            avg_seqlen = seqlen_k
        else:
            row_idx = torch.arange(seqlen_q, device='cuda')
            col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
            col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1))
            avg_seqlen = (col_right - col_left + 1).float().mean().item()
    return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)


def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")
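To make the `flops` helper above concrete, here is a dependency-free re-implementation of its counting. The names `avg_keys_per_query` and `attn_flops` are illustrative, and a plain Python loop stands in for the `torch.arange` vectorization.

Each attended query/key pair is charged `2 * (headdim + headdim_v)` FLOPs: a multiply-add per element for `Q K^T` and another for `P V`. The sketch also shows that the causal closed form `(max(0, seqlen_k - seqlen_q) + seqlen_k) / 2` gives `s / 2` when `seqlen_q == seqlen_k == s`. The exact count of unmasked keys is `(s + 1) / 2`, so the closed form is a close approximation rather than an exact count.

```python
def avg_keys_per_query(seqlen_q, seqlen_k, causal=False, window_size=(-1, -1)):
    # Mirrors the branches of `flops`: closed form for causal, full row for
    # no window, and an explicit per-row count for a (left, right) window.
    if causal:
        return (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
    if window_size == (-1, -1):
        return seqlen_k
    total = 0
    for row in range(seqlen_q):
        left = max(row + seqlen_k - seqlen_q - window_size[0], 0)
        right = min(row + seqlen_k - seqlen_q + window_size[1], seqlen_k - 1)
        total += right - left + 1
    return total / seqlen_q


def attn_flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v,
               causal=False, window_size=(-1, -1)):
    avg = avg_keys_per_query(seqlen_q, seqlen_k, causal, window_size)
    # 2 FLOPs (multiply + add) per element of QK^T (headdim) and PV (headdim_v).
    return batch * nheads * 2 * seqlen_q * avg * (headdim + headdim_v)


# Causal closed form vs. the exact count from an equivalent (s - 1, 0) window.
print(avg_keys_per_query(8, 8, causal=True), avg_keys_per_query(8, 8, window_size=(7, 0)))  # 4.0 4.5
print(attn_flops(2, 16, 8192, 8192, 128, 128))  # 1099511627776 (= 2**40)
```

The benchmark turns these counts into throughput as `nFLOPS / seconds * 1e-12`. For the backward pass it uses `2.5 * nFLOPS`, which amounts to assuming that backward costs 2.5 times the forward matmul FLOPs.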
def cudnn_spda_setup(q, k, v, causal=False, window_size_left=-1): b, nheads, seqlen_q, headdim = q.shape _, nheads_k, seqlen_k, _ = k.shape assert v.shape == (b, nheads_k, seqlen_k, headdim) assert cudnn is not None, 'CUDNN is not available' q_gpu, k_gpu, v_gpu = q, k, v o_gpu = torch.empty_like(q_gpu) stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device) graph = cudnn.pygraph( io_data_type=convert_to_cudnn_type(q.dtype), intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, ) q = graph.tensor_like(q_gpu.detach()) k = graph.tensor_like(k_gpu.detach()) v = graph.tensor_like(v_gpu.detach()) o, stats = graph.sdpa( name="sdpa", q=q, k=k, v=v, is_inference=False, attn_scale=1.0 / math.sqrt(headdim), # use_causal_mask_bottom_right=causal or window_size_left >= 0, use_causal_mask=causal or window_size_left >= 0, sliding_window_length=window_size_left if window_size_left >= 0 and not causal else None, ) o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) graph.validate() graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() graph.build_plans() variant_pack = { q: q_gpu, k: k_gpu, v: v_gpu, o: o_gpu, stats: stats_gpu, } workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) def run(*args, **kwargs): graph.execute(variant_pack, workspace) return o_gpu return run def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=-1): b, nheads, seqlen_q, headdim = q.shape _, nheads_k, seqlen_k, _ = k.shape assert v.shape == (b, nheads_k, seqlen_k, headdim) assert g.shape == (b, nheads, seqlen_q, headdim) assert o.shape == (b, nheads, seqlen_q, headdim) assert lse.shape == (b, nheads, seqlen_q, 1) assert cudnn is not None, 'CUDNN is not available' q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g dq_gpu = 
torch.empty_like(q_gpu) dk_gpu = torch.empty_like(k_gpu) dv_gpu = torch.empty_like(v_gpu) graph = cudnn.pygraph( io_data_type=convert_to_cudnn_type(q.dtype), intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, ) q = graph.tensor_like(q_gpu.detach()) k = graph.tensor_like(k_gpu.detach()) v = graph.tensor_like(v_gpu.detach()) o = graph.tensor_like(o_gpu.detach()) g = graph.tensor_like(g_gpu.detach()) stats = graph.tensor_like(lse.detach()) dq, dk, dv = graph.sdpa_backward( name="sdpa_backward", q=q, k=k, v=v, o=o, dO=g, stats=stats, attn_scale=1.0 / math.sqrt(headdim), # use_causal_mask_bottom_right=causal or window_size_left >= 0, use_causal_mask=causal or window_size_left >= 0, sliding_window_length=window_size_left if window_size_left >= 0 and not causal else None, ) dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride()) dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride()) dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride()) graph.validate() graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() graph.build_plans() variant_pack = { q: q_gpu, k: k_gpu, v: v_gpu, o: o_gpu, g: g_gpu, stats: lse, dq: dq_gpu, dk: dk_gpu, dv: dv_gpu, } workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) def run(*args, **kwargs): graph.execute(variant_pack, workspace) return dq_gpu, dk_gpu, dv_gpu return run torch.manual_seed(0) repeats = 10 dropout_p = 0.0 causal = False dtype = torch.bfloat16 # dtype = torch.float8_e4m3fn dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype device = 'cuda' verbose = True varlen = False page_size = None softcap = 0.0 V_colmajor = False deterministic = False batch_size = 2 # seqlen = 2048 seqlen = 8192 # seqlen = 4096 # seqlen = 2047 dim = 2048 # headdim = 128 # headdim = 64 headdim = 256 # for headdim in [64, 128, 256]: # bs_seqlen_vals = [(32, 
512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] # bs_seqlen_vals = [(16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] # bs_seqlen_vals = [(32, 512), (16, 1024)] # bs_seqlen_vals = [(2, 64 * 132)] bs_seqlen_vals = [(2, 8192)] # bs_seqlen_vals = [(1, 16 * 1024)] time_f = {} time_b = {} # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192]: # for headdim in [64, 96, 128, 192, 256]: # for headdim in [64, 96, 128]: # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: for headdim in [128]: nheads = dim // headdim # nheads = 128 # headdim = 64 # batch_size = 64 # seqlen = 512 # nheads = 8 # headdim = 128 nheads_kv = nheads # nheads_kv = nheads // 4 # nheads_kv = 1 headdim_v = headdim # headdim_v = 512 has_qv = headdim == 64 and headdim_v == 512 # has_qv = False for batch_size, seqlen in bs_seqlen_vals: num_splits = 0 window_size = (-1, -1) # window_size = (seqlen // 2 - 1, 0) pack_gqa = None # seqlen_q = 64 seqlen_q = seqlen leftpad_k = None # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() v_fa3 = v if not V_colmajor else v_colmajor qv = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen) if has_qv else None # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, 
headdim_v), device=device, dtype=torch.int32).to(dtype) g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) if varlen: q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_() for x in [q, k, v]] cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen # cu_seqlens_q = torch.tensor([0, 248, 249, 250, 251, 252, 253, 254, 255, 256], device=device, dtype=torch.int32) # q_unpad = q_unpad[:256] # seqlen_q = 256 # cu_seqlens_q = torch.tensor([0, 376, 377, 378, 379, 380, 381, 382, 383, 384], device=device, dtype=torch.int32) # q_unpad = q_unpad[:384] # seqlen_q = 384 if page_size is not None: assert seqlen % page_size == 0 k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]] page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32), "(b s) -> b s", s=seqlen // page_size) else: page_table = None for causal in [False, True]: # for causal in [True]: print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: # if False: if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0]) cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, 
window_size_left=window_size[0]) # _, m0 = benchmark_forward(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2') if dtype != torch.float8_e4m3fn and headdim == headdim_v: # if False: if not varlen: m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') else: m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean time.sleep(1) if not varlen: _, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav2') else: _, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav2') time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True) if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: if triton_attention is not None: qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]] time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark m3 = time_fwd(triton_attention, qt, kt, vt, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton') time_f[(causal, headdim, batch_size, seqlen), "Triton"] = m3.mean # if causal: # triton bwd only works w causal for now # time.sleep(1) # _, m3b = benchmark_backward(triton_attention, qt, kt, vt, causal, 1 / math.sqrt(headdim), 
repeats=repeats, verbose=verbose, desc='Triton') # time_b[(causal, headdim, batch_size, seqlen), "Triton"] = m3b.mean # # pytorch_profiler(triton_attention, q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), causal, 1 / math.sqrt(headdim), backward=True) if cudnn is not None: # if False: if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN') time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean time.sleep(1) m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN') time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean # pytorch_profiler(cudnn_spda, backward=False) # pytorch_profiler(cudnn_spda_bwd, backward=False) time.sleep(1) if not varlen: # m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, qv=qv, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa) else: m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, 
desc='Fav3') # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD: time.sleep(1) if not varlen: _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav3') else: _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav3') time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean # time.sleep(1) # if not varlen: # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True) # else: # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True) # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy') if dtype != torch.float8_e4m3fn and headdim == headdim_v: # if False: print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: if triton_attention is not None: print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS') # if causal: # print(f'Triton bwd: {m3b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m3b.mean * 1e-12):.1f} TFLOPS') if cudnn is not None: print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') print(f'CuDNN bwd: 
{m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') if dtype != torch.float8_e4m3fn and headdim == headdim_v and not DISABLE_BACKWARD: print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') # benchmark_forward(torch.square, k) # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') # print(time_f) # print(time_b) # import pickle # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_cudnn_triton_20241208.plk', 'wb') as fp: # with open(f'flash3_attn_time_h100_fa3_20250313.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_fa3_fp8_20250313.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_fp8_hdim{headdim}.plk', 'wb') as fp: # # with open(f'flash3_attn_time_h100_hdim{headdim}_1031.plk', 'wb') as fp: # pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL) ================================================ FILE: hopper/benchmark_flash_attention_fp8.py ================================================ # Install the newest triton version with # pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python" import pickle import math import time import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined from flash_attn import flash_attn_qkvpacked_func from flash_attn_interface import flash_attn_func, _flash_attn_forward try: from triton_fused_attention import attention as attention_triton except ImportError: attention_triton = None try: import xformers.ops as xops except ImportError: xops = None try: import cudnn except ImportError: cudnn = None def convert_to_cudnn_type(torch_type): if 
torch_type == torch.float16: return cudnn.data_type.HALF elif torch_type == torch.bfloat16: return cudnn.data_type.BFLOAT16 elif torch_type == torch.float32: return cudnn.data_type.FLOAT elif torch_type == torch.int32: return cudnn.data_type.INT32 elif torch_type == torch.int64: return cudnn.data_type.INT64 elif torch_type == torch.float8_e4m3fn: return cudnn.data_type.FP8_E4M3 elif torch_type == torch.float8_e5m2: return cudnn.data_type.FP8_E5M2 else: raise ValueError("Unsupported tensor data type.") def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False): b, _, _, nheads, headdim = qkv.shape assert cudnn is not None, 'CUDNN is not available' o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device) o_gpu_transposed = torch.as_strided( o_gpu, [b, nheads, seqlen_q, headdim], [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1], ) stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device) amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) graph = cudnn.pygraph( io_data_type=convert_to_cudnn_type(qkv.dtype), intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, ) new_q = torch.as_strided( qkv, [b, nheads, seqlen_q, headdim], [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], storage_offset=0, ) q = graph.tensor( name = "Q", dim = list(new_q.shape), stride = list(new_q.stride()), data_type=convert_to_cudnn_type(qkv.dtype) ) new_k = torch.as_strided( qkv, [b, nheads, seqlen_k, headdim], [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], storage_offset=nheads * headdim, ) k = graph.tensor( name = "K", dim = list(new_k.shape), stride = list(new_k.stride()), data_type=convert_to_cudnn_type(qkv.dtype) ) new_v = torch.as_strided( qkv, [b, nheads, seqlen_k, headdim], [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], 
storage_offset=nheads * headdim * 2, ) v = graph.tensor( name = "V", dim = list(new_v.shape), stride = list(new_v.stride()), data_type=convert_to_cudnn_type(qkv.dtype) ) def get_default_scale_tensor(): return graph.tensor( dim = [1, 1, 1, 1], stride = [1, 1, 1, 1], data_type=cudnn.data_type.FLOAT ) default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda") descale_q = get_default_scale_tensor() descale_k = get_default_scale_tensor() descale_v = get_default_scale_tensor() descale_s = get_default_scale_tensor() scale_s = get_default_scale_tensor() scale_o = get_default_scale_tensor() o, _, amax_s, amax_o = graph.sdpa_fp8( q=q, k=k, v=v, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v, descale_s=descale_s, scale_s=scale_s, scale_o=scale_o, is_inference=True, attn_scale=1.0 / math.sqrt(headdim), use_causal_mask=causal, name="sdpa", ) o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride()) amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride()) amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride()) # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) graph.validate() graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() graph.build_plans() variant_pack = { q: new_q, k: new_k, v: new_v, descale_q: default_scale_gpu, descale_k: default_scale_gpu, descale_v: default_scale_gpu, descale_s: default_scale_gpu, scale_s: default_scale_gpu, scale_o: default_scale_gpu, o: o_gpu_transposed, amax_s: amax_s_gpu, amax_o: amax_o_gpu, } workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) def run(*args, **kwargs): graph.execute(variant_pack, workspace) return o_gpu, amax_o_gpu return run def attention_pytorch(qkv, dropout_p=0.0, causal=True): """ Arguments: qkv: (batch_size, seqlen, 3, nheads, head_dim) dropout_p: float Output: output: (batch_size, 
seqlen, nheads, head_dim) """ batch_size, seqlen, _, nheads, d = qkv.shape q, k, v = qkv.unbind(dim=2) q = rearrange(q, 'b t h d -> (b h) t d') k = rearrange(k, 'b s h d -> (b h) d s') softmax_scale = 1.0 / math.sqrt(d) # Preallocate attn_weights for `baddbmm` scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), '(b h) t s -> b h t s', h=nheads) if causal: # "triu_tril_cuda_template" not implemented for 'BFloat16' # So we have to construct the mask in float causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) scores = scores + causal_mask.to(dtype=scores.dtype) attention = torch.softmax(scores, dim=-1) attention_drop = F.dropout(attention, dropout_p) output = torch.einsum('bhts,bshd->bthd', attention_drop , v) return output.to(dtype=qkv.dtype) def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): assert mode in ["fwd", "bwd", "fwd_bwd"] f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) def efficiency(flop, time): return (flop / time / 10**12) if not math.isnan(time) else 0.0 def time_fwd(func, *args, **kwargs): time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark time_f = benchmark_forward(func, *args, **kwargs) return time_f[1].mean torch.manual_seed(0) repeats = 30 device = 'cuda' # dtype = torch.float16 dtype = torch.float8_e4m3fn # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)] bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)] # bs_seqlen_vals = [(4, 4096), (2, 8192), (1, 8192 * 2)] # bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)] causal_vals = [False, True] headdim_vals = [64, 128, 256] dim = 2048 # dim = 
256 dropout_p = 0.0 methods = (["Pytorch", "Flash3"] + (["cuDNN"] if cudnn is not None else []) # + (["Triton"] if attention_triton is not None else []) # + (["xformers.c"] if xops is not None else []) # + (["xformers.f"] if xops is not None else []) ) time_f = {} time_b = {} time_f_b = {} speed_f = {} speed_b = {} speed_f_b = {} for causal in causal_vals: for headdim in headdim_vals: for batch_size, seqlen in bs_seqlen_vals: torch.cuda.empty_cache() config = (causal, headdim, batch_size, seqlen) nheads = dim // headdim q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16, requires_grad=False) for _ in range(3)] qkv = torch.stack([q, k, v], dim=2) qkv = qkv.to(torch.bfloat16) f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False) time_f[config, "Pytorch"] = f res_baseline = attention_pytorch(qkv, dropout_p, causal=causal) if attention_triton is not None: q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn) k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn) v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn) scale = 1 / math.sqrt(headdim) f = time_fwd( attention_triton, q_transposed, k_transposed, v_transposed, causal, scale, repeats=5, verbose=False, desc='Triton' ) f = time_fwd( attention_triton, q_transposed, k_transposed, v_transposed, causal, scale, repeats=repeats, verbose=False, desc='Triton' ) time_f[config, "Triton"] = f res = attention_triton( q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2), causal, scale ).half().transpose(1, 2) torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5) # out = torch.empty_like(q) q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) softmax_scale = q.shape[-1] ** (-0.5) descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda') descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda') descale_v = torch.tensor([1.0], 
dtype=torch.float32, device='cuda') # f = time_fwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False) f = time_fwd( _flash_attn_forward, q, k, v, softmax_scale, causal=causal, window_size=(-1,-1), descale_q=descale_q, descale_k=descale_k, descale_v=descale_v, repeats=repeats, verbose=False ) # res = flash_attn_func(q, k, v, causal=causal) # torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05) time_f[config, "Flash3"] = f if cudnn is not None: qkv_fp8 = qkv.to(dtype) time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark f = time_fwd( cudnn_spda_setup( qkv_fp8, seqlen, seqlen, causal=causal ), repeats=repeats, verbose=False ) time_f[config, "cuDNN"] = f # res, amax_o = cudnn_spda_setup( # qkv_fp8, seqlen, seqlen, # causal=causal # )() # res = res.half() # TODO: CUDNN has numerics issues when # num_heads=16, dim=128, seq_len=1024, batch_size=2 # or larger sizes. # res_cpu = res.cpu().reshape(-1) # res_baseline_cpu = res_baseline.cpu().reshape(-1) # print(amax_o) # print(res) # print(res_baseline) # for i in range(len(res_cpu)): # item = res_cpu[i] # item_baseline = res_baseline_cpu[i] # if abs(item - item_baseline) > 0.5: # print(i) # print(item) # print(item_baseline) # torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05) print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###") for method in methods: speed_f[config, method] = efficiency( flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"), time_f[config, method] ) #print (time_f[config,method]) print( f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, " ) # with open('flash3_attn_time.plk', 'wb') as fp: # pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL) ================================================ FILE: hopper/benchmark_mla_decode.py ================================================ # Copyright (c) 2025, Ted 
Zadouri, Tri Dao.
# We recommend locking GPU clocks before running the benchmark to ensure consistent results.
# This can be done using the following commands (1830 MHz is the clock for H100):
# sudo nvidia-smi -i 0 -pm 1
# sudo nvidia-smi -i 0 --lock-gpu-clocks 1830,1830
# See more here: https://github.com/triton-lang/triton/blob/d9f10ebdc5da53f73eb852fde73d8d7d80b679d1/python/triton/testing.py#L487

import time
import torch
import torch.nn.functional as F

from triton.testing import do_bench, do_bench_cudagraph

from einops import rearrange

from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata

try:
    from flash_mla import flash_mla_with_kvcache, get_mla_metadata
except ImportError:
    flash_mla_with_kvcache, get_mla_metadata = None, None

try:
    from flash_attn.utils.benchmark import pytorch_profiler
except ImportError:
    pytorch_profiler = None


device = "cuda"
dtype = torch.bfloat16
seqlen = 8192
seqlen_q = 1
# nheads_q = 16
nheads_q = 128

use_bench_cudagraph = False

attn_variants = ["mha", "gqa", "mqa", "mla", "gla"]
# for attn_variant in attn_variants:
for attn_variant in attn_variants[3:5]:
    nheads_kv = nheads_q if attn_variant == "mha" else (max(nheads_q // 8, 1) if attn_variant == "gqa" else (1 if attn_variant == "mla" else 2))
    headdim = 64 if attn_variant in ["mla", "gla"] else 128
    headdim_v = 512 if attn_variant == "mla" else (256 if attn_variant == "gla" else headdim)
    has_qv = headdim == 64 and headdim_v > 64
    # page_size = None
    page_size = 64 if attn_variant in ["mla", "gla"] else 128

    should_run_flashmla = attn_variant == "mla" and page_size == 64 and flash_mla_with_kvcache is not None

    torch.manual_seed(0)

    batch_size = 128
    cache_seqlens = None
    # cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int)
    # cache_seqlens = torch.tensor([seqlen - 1, 1024, 1024, 1024], device=device, dtype=torch.int32)
    # cache_seqlens = torch.tensor([1024] * batch_size, device=device, dtype=torch.int)
    # cache_seqlens = torch.tensor([4500, 45000, 1800, 1800], dtype=torch.int32, device=device)

    print(f"\n{attn_variant.upper()}, nheads_q = {nheads_q}, nheads_kv = {nheads_kv}, headdim = {headdim}, headdim_v = {headdim_v}, page_size = {page_size}")

    for seqlen in [s * 1024 for s in [1, 2, 4, 8, 16, 32, 64]]:
    # for seqlen in [s * 1024 for s in [8]]:
        cache_seqlens = torch.tensor([seqlen] * batch_size, device=device, dtype=torch.int)
        num_splits = 0
        q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, dtype=dtype, device=device)
        try:
            v_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, dtype=dtype, device=device)
            k_cache = torch.randn(batch_size, seqlen, nheads_kv, headdim, dtype=dtype, device=device)
            if page_size is not None:
                assert seqlen % page_size == 0
                k_cache, v_cache = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k_cache, v_cache]]
                page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32),
                                       "(b s) -> b s", s=seqlen // page_size)
            else:
                page_table = None
        except torch.OutOfMemoryError:
            continue
        qv = torch.randn(batch_size, seqlen_q, nheads_q, headdim_v, dtype=dtype, device=device) if has_qv else None

        # Precomputing this saves ~2us
        scheduler_metadata = get_scheduler_metadata(
            batch_size, seqlen_q, seqlen, nheads_q, nheads_kv, headdim,
            cache_seqlens, q.dtype, headdim_v=headdim_v, page_size=page_size, causal=True
        )
        # scheduler_metadata = None
        # breakpoint()
        fn0 = lambda: flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens, num_splits=num_splits, qv=qv, page_table=page_table, causal=True, scheduler_metadata=scheduler_metadata)
        time.sleep(1)  # to avoid power throttling
        # Time in ms
        if not use_bench_cudagraph:
            t0 = do_bench(fn0, warmup=1, rep=10)
        else:
            torch.cuda.synchronize()  # Gotta wait, otherwise e.g. k_cache might not be ready
            with torch.cuda.stream(torch.cuda.Stream()):
                t0 = do_bench_cudagraph(fn0, rep=10)
        # exit(0)
        if should_run_flashmla:
            # Separate out the preprocessing since this can be done once and reused for all layers
            mla_metadata = get_mla_metadata(cache_seqlens, seqlen_q * nheads_q // nheads_kv, nheads_kv)
            q_concat = torch.concat([q, qv], dim=-1) if has_qv else q
            kv_cache_concat = torch.concat([v_cache, k_cache], dim=-1)
            fn1 = lambda: flash_mla_with_kvcache(q_concat, kv_cache_concat, page_table, cache_seqlens, headdim_v, *mla_metadata, causal=True)
            time.sleep(1)  # to avoid power throttling
            if not use_bench_cudagraph:
                t1 = do_bench(fn1, warmup=1, rep=10)
            else:
                torch.cuda.synchronize()  # Gotta wait, otherwise e.g. k_cache might not be ready
                with torch.cuda.stream(torch.cuda.Stream()):
                    t1 = do_bench_cudagraph(fn1, rep=10)

        total_seqlen = seqlen * batch_size if cache_seqlens is None else cache_seqlens.sum().item()
        mem_io = total_seqlen * nheads_kv * (headdim + headdim_v) * 2 + q.numel() * 2 + (qv.numel() * 2 if has_qv else 0) + q.numel() * headdim_v // headdim * 2  # last term is for the output
        flops = seqlen_q * total_seqlen * nheads_q * (headdim + headdim_v * (2 if has_qv else 1)) * 2
        ideal_h100_time_mem = mem_io / 3.35e12 * 1e6
        ideal_h100_time_flop = flops / 989e12 * 1e6
        ideal_h100_time = max(ideal_h100_time_mem, ideal_h100_time_flop)
        print(f"Seqlen = {seqlen}, FA3 time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t0 * 1e3:.1f} us, {mem_io * 1e-9 / (t0 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t0 * 1e-3):.0f} TFLOPS/s")
        if should_run_flashmla:
            print(f"Seqlen = {seqlen}, FlashMLA time{'' if not use_bench_cudagraph else ' w CUDA Graph'}: {t1 * 1e3:.1f} us, {mem_io * 1e-9 / (t1 * 1e-3):.0f} GB/s, {flops * 1e-12 / (t1 * 1e-3):.0f} TFLOPS/s")
        print(f"Arithmetic intensity: {flops / mem_io:.1f}")
        print(f"Ideal time: {ideal_h100_time:.0f} us")

        # if pytorch_profiler is not None:
        #     time.sleep(1)  # to avoid power throttling
        #     pytorch_profiler(fn0)
        #     if should_run_flashmla:
        #         time.sleep(1)  # to avoid power throttling
        #         pytorch_profiler(fn1)



================================================
FILE: hopper/benchmark_split_kv.py
================================================
import torch
import flash_attn
import flash_attn_interface
import itertools
import time
import math

import torch.utils.benchmark as benchmark

def round_up_to_power_of_2(x):
    if x <= 1:
        return 1
    return 1 << (x - 1).bit_length()

def timeit(fn, *args, **kwargs):
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        fn(*args, **kwargs)

    # Benchmark using PyTorch Timer
    t = benchmark.Timer(
        stmt='fn(*args, **kwargs)',
        globals={'fn': fn, 'args': args, 'kwargs': kwargs}
    )

    # Measure execution time
    measurement = t.timeit(20)  # Runs the function 20 times
    # measurement = t.blocked_autorange(min_run_time=1)
    avg_time = measurement.mean  # Average time in seconds

    return avg_time

def main():
    num_sms = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count

    max_splits = 129
    check_all_splits = True

    causal = True
    # causal = False
    # dtype=torch.float16
    dtype=torch.bfloat16
    tp_degree = 1

    torch.manual_seed(42)

    model_configs = [
        # ("Gemma-2-2B", 8, 4, 256),
        # ("Gemma-2-9B", 16, 8, 256),
        # ("Gemma-2-27B", 32, 16, 128),
        # ("Qwen-2.5-0.5B", 14, 2, 64),
        # ("Qwen-2.5-1.5B", 12, 2, 128),
        # ("Qwen-2.5-7B", 28, 4, 128),
        # ("Llama-3.1-8B", 32, 8, 128),
        ("Llama-3.1-70B", 64, 8, 128),
        # ("Mistral Large", 96, 8, 128),
        # ("Llama-3.1-405B", 128, 8, 128),
        # ("Llama-3.2-1B", 32, 8, 64),
        # ("Llama-3.2-3B", 24, 8, 128),
        # ("Nemotron-4-15B", 48, 8, 128),
    ]

    all_batch_configs = []

    all_batch_configs.extend(itertools.product(
        # [1024, 2048, 4096, 8192, 16384, 32768, 131072],  # context_seqlen
        # [4096, 16384, 65536],  # context_seqlen
        [131072],  # context_seqlen
        # [i for i in range(1, (num_sms) + 1)],  # num_requests
        [1, 4, 8, 16],  # num_requests
        # [1],  # num_requests
        # [1, 4, 8, 16],  # query_seqlen
        [1],  # query_seqlen
    ))

    num_caches = max(reqs for _, reqs, _ in all_batch_configs)
    cache_seqlen = max(seqlen for seqlen, _, _ in all_batch_configs)

    for model_name, nheads_q, nheads_kv, headdim in model_configs:
        assert nheads_kv % tp_degree == 0
        print(f"***{model_name}***")
        print(f"QHEADS:{nheads_q}, KVHEADS:{nheads_kv}, HEADDIM:{headdim}, TP:{tp_degree}")
        nheads_q //= tp_degree
        nheads_kv //= tp_degree

        k_cache = torch.randn(
            (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
        )
        v_cache = torch.randn(
            (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
        )

        if check_all_splits is False:
            print(f"{'CONTEXT':<9}{'BSZ':<5}{'QLEN':<6}{'FA2':<10}{'FA3':<9}{'RATIO':<7}{'GB/s':<10}")

        for context_seqlen, num_requests, query_seqlen in all_batch_configs:
            bytes_kv = (context_seqlen * num_requests * nheads_kv * headdim * 4)
            bytes_q = (query_seqlen * num_requests * nheads_q * headdim * 4)
            blockH = round_up_to_power_of_2(nheads_q//nheads_kv)
            blockM = 128  # true for hdim 128 causal and hdim 64
            blockM_div_H = blockM//blockH
            num_work_tiles = nheads_kv * num_requests * math.ceil(query_seqlen/blockM_div_H)

            q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=dtype)
            cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests]
            cache_seqlens = torch.tensor(
                [context_seqlen] * num_requests, dtype=torch.int32, device="cuda"
            )

            fa2_time_heuristic = timeit(
                flash_attn.flash_attn_with_kvcache,
                q=q,
                k_cache=k_cache,
                v_cache=v_cache,
                cache_seqlens=cache_seqlens,
                cache_batch_idx=cache_idxs,
                causal=causal,
            ) * 1000. * 1000.
            # fastest_splitk_time = float("inf")
            # fastest_splitk = 0
            # for i in range(1, max_splits):
            #     t = timeit(
            #         flash_attn.flash_attn_with_kvcache,
            #         q=q,
            #         k_cache=k_cache,
            #         v_cache=v_cache,
            #         cache_seqlens=cache_seqlens,
            #         cache_batch_idx=cache_idxs,
            #         causal=causal,
            #         num_splits=i,
            #     ) * 1000. * 1000.
            #     if t < fastest_splitk_time:
            #         fastest_splitk_time = t
            #         fastest_splitk = i

            fa3_time_one_split = timeit(
                flash_attn_interface.flash_attn_with_kvcache,
                q=q,
                k_cache=k_cache,
                v_cache=v_cache,
                cache_seqlens=cache_seqlens,
                cache_batch_idx=cache_idxs,
                causal=causal,
                pack_gqa=False,
                num_splits=1,
            ) * 1000. * 1000.

            fa3_time_gqa_heuristic = timeit(
                flash_attn_interface.flash_attn_with_kvcache,
                q=q,
                k_cache=k_cache,
                v_cache=v_cache,
                cache_seqlens=cache_seqlens,
                cache_batch_idx=cache_idxs,
                causal=causal,
                pack_gqa=True,
                num_splits=0,
                # max_seqlen_k_hint=context_seqlen
            ) * 1000. * 1000.

            if check_all_splits:

                fa3_fastest_num_splits = 0
                fa3_fastest_splitk_time = float("inf")

                for num_splits in range(1, max_splits):
                    t = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        pack_gqa=False,
                        num_splits=num_splits
                    ) * 1000. * 1000.

                    out0 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        pack_gqa=False,
                        num_splits=num_splits
                    )

                    out1 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        pack_gqa=False,
                        num_splits=1
                    )

                    max_diff = (out0 - out1).abs().max().item()
                    mean_diff = (out0 - out1).abs().mean().item()
                    # print (f"splits {num_splits}, out diff-max, {max_diff}, out diff-mean, {mean_diff}, time {t:.2f}")
                    # print (f"splits {num_splits}, time {t:.2f}")

                    if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
                        print(f"Numerical error too high: Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")

                    if t < fa3_fastest_splitk_time:
                        fa3_fastest_splitk_time = t
                        fa3_fastest_num_splits = num_splits

                fa3_fastest_num_splits_gqa = 0
                fa3_fastest_splitk_time_gqa = float("inf")
                for num_splits in range(1, max_splits):

                    t = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        pack_gqa=True,
                        num_splits=num_splits
                    ) * 1000. * 1000.

                    out0 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        pack_gqa=True,
                        num_splits=num_splits
                    )

                    out1 = flash_attn_interface.flash_attn_with_kvcache(
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        pack_gqa=True,
                        num_splits=1
                    )

                    max_diff = (out0 - out1).abs().max().item()
                    mean_diff = (out0 - out1).abs().mean().item()
                    # print (f"gqa splits {num_splits}, out gqa diff-max {max_diff}, out gqa diff-mean {mean_diff}, time {t:.2f}")
                    # print (f"gqa splits {num_splits}, time {t:.2f}")

                    if math.isnan(max_diff) or math.isnan(mean_diff) or max_diff > 2e-3 or mean_diff > 1e-4:
                        print(f"Numerical error too high (gqa): Splits: {num_splits}, Max: {max_diff}, Mean: {mean_diff}")

                    if t < fa3_fastest_splitk_time_gqa:
                        fa3_fastest_splitk_time_gqa = t
                        fa3_fastest_num_splits_gqa = num_splits

                efficiency = (num_work_tiles * fa3_fastest_num_splits_gqa)/num_sms
                heuristic_ratio = fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa
                # remeasure to smooth anomalies
                if heuristic_ratio > 1.1:

                    fa3_time_gqa_heuristic = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        pack_gqa=True,
                        # num_splits=num_splits_select,
                        # num_splits=1,
                        num_splits=0,
                        # max_seqlen_k_hint=context_seqlen
                    ) * 1000. * 1000.

                    fa3_fastest_splitk_time_gqa = timeit(
                        flash_attn_interface.flash_attn_with_kvcache,
                        q=q,
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_seqlens=cache_seqlens,
                        cache_batch_idx=cache_idxs,
                        causal=causal,
                        pack_gqa=True,
                        num_splits=fa3_fastest_num_splits_gqa
                    ) * 1000. * 1000.
            if check_all_splits is True:
                print(
                    f"CONTEXT:{context_seqlen}, BSZ:{num_requests}, QLEN:{query_seqlen}, "
                    f"FA2:{fa2_time_heuristic:.2f}, "
                    # f"FA2 MANUAL:{fastest_splitk_time:.2f}, "
                    # f"FA2 NUM SPLITS:{fastest_splitk}, "
                    # f"FA3 NOGQA NOSPLIT:{fa3_time_one_split:.2f}, "
                    # f"FA3 NOGQA SPLIT MANUAL:{fa3_fastest_splitk_time:.2f}, "
                    # f"FA3 NOSPLIT:{fa3_time_one_split_gqa:.2f}, "
                    f"FA3 SPLIT MANUAL:{fa3_fastest_splitk_time_gqa:.2f}, "
                    f"FA3:{fa3_time_gqa_heuristic:.2f}, "
                    # f"FA3 RATIO (NONSPLIT/SPLIT):{fa3_time_one_split_gqa/fa3_time_gqa_heuristic:.2f}, "
                    # f"FA2 NUM SPLITS:{fastest_splitk}, "
                    # f"FA3 NOGQA NUM SPLITS:{fa3_fastest_num_splits}, "
                    f"FA3 NUM SPLITS:{fa3_fastest_num_splits_gqa}, "
                    # f"RATIO (FA2/3):{fa2_time_heuristic/fa3_time_gqa_heuristic:.2f}, "
                    f"RATIO:{fa3_time_gqa_heuristic/fa3_fastest_splitk_time_gqa:.2f}, "
                    f"EFF:{efficiency:.2f}, "
                    f"GB/s:{bytes_kv/fa3_time_gqa_heuristic * 1e-3:.2f}"
                )

            if check_all_splits is False:
                print(
                    f"{context_seqlen:<9}{num_requests:<5}{query_seqlen:<6}"
                    f"{fa2_time_heuristic:<10.2f}{fa3_time_gqa_heuristic:<9.2f}"
                    f"{fa2_time_heuristic/fa3_time_gqa_heuristic:<7.2f}"
                    f"{bytes_kv/fa3_time_gqa_heuristic * 1e-3:<10.2f}"
                )


if __name__ == "__main__":
    main()



================================================
FILE: hopper/block.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

namespace flash {

template
struct BlockMN {

    static
    CUTLASS_DEVICE
    cute::tuple get_n_block_min_max(
            SeqlenInfo_t const& seqlen_info,
            int const m_block, int const bidb, int const split_idx, int const num_splits,
            int const window_size_left, int const window_size_right,
            cutlass::FastDivmod const& attention_chunk_divmod,
            cutlass::FastDivmod const& qhead_per_khead_divmod) {

        int const seqlen_k = seqlen_info.seqlen_k;
        int const seqlen_q = seqlen_info.seqlen_q;
        int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
        if constexpr (Is_causal || Is_local) {
            int m_idx_max = (m_block + 1) * kBlockM;
            // TODO: check off-by-1 error
            if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
            int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
            int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;
            if (Is_local && attention_chunk_divmod.divisor > 0) {
                n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
            }
            n_block_max = std::min(n_block_max, cute::ceil_div(n_idx_right, kBlockN));
        }
        int n_block_min = 0;
        if constexpr (Is_local) {
            int m_idx_min = m_block * kBlockM;
            if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
            int const n_idx = m_idx_min + seqlen_k - seqlen_q;
            int n_idx_left = n_idx - window_size_left;
            if (attention_chunk_divmod.divisor > 0) {
                n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
            }
            n_block_min = std::max(int(0), n_idx_left / kBlockN);
        }
        // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
        if constexpr (Split) {
            uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits
            int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u);
            int split_idx_actual = split_idx & 0x0000FFFF;
            int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
            int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);
            n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;
            n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
            // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
        }
        // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
        return {n_block_min, n_block_max};
    }

    static
    CUTLASS_DEVICE
    cute::tuple get_n_block_k_new_min_max(
            SeqlenInfo_t const& seqlen_info,
            int const m_block, int const bidb, int const split_idx, int const num_splits,
            int const window_size_left, int const window_size_right,
            cutlass::FastDivmod const& attention_chunk_divmod,
            cutlass::FastDivmod const& qhead_per_khead_divmod) {

        auto [n_block_min, n_block_max] = get_n_block_min_max(
            seqlen_info, m_block, bidb, split_idx, num_splits,
            window_size_left, window_size_right, attention_chunk_divmod, qhead_per_khead_divmod);
        int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
        int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
        int const n_block_new_min = idx_k_new_min / kBlockN;
        int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
        // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
        return {n_block_new_min, n_block_new_max};
    }

    static
    CUTLASS_DEVICE
    cute::tuple get_m_block_min_max(
            SeqlenInfo_t const& seqlen_info,
            int const n_block, int const bidb,
            int const window_size_left, int const window_size_right, int const sink_token_length) {
        // TODO: support attention_chunk
        int const seqlen_q = seqlen_info.seqlen_q;
        int const seqlen_k = seqlen_info.seqlen_k;
        int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
        if constexpr (Is_local) {
            if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
                m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
            }
        }
        int m_block_min = 0;
        if constexpr (Is_causal || Is_local) {
            m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
        }
        return {m_block_min, m_block_max};
    }

    // If we have separate iterations with causal or local masking at the start, where do we stop
    static
    CUTLASS_DEVICE
    int get_n_block_min_causal_local_mask(
            SeqlenInfo_t const& seqlen_info,
            int const m_block, int const n_block_min, int const window_size_right,
            cutlass::FastDivmod const& attention_chunk_divmod,
            cutlass::FastDivmod const& qhead_per_khead_divmod) {
        int const m_idx_min = !PackGQA ? m_block * kBlockM : qhead_per_khead_divmod.divide(m_block * kBlockM);
        int const n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
        int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;
        if (Is_local && attention_chunk_divmod.divisor > 0) {
            n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
        }
        return std::max(n_block_min, n_idx_right / kBlockN);
    }

    // If we have separate iterations with local masking at the end, where do we stop the non-masked iterations
    static
    CUTLASS_DEVICE
    int get_n_block_min_before_local_mask(
            SeqlenInfo_t const& seqlen_info,
            int const m_block, int const n_block_min, int const window_size_left,
            cutlass::FastDivmod const& attention_chunk_divmod,
            cutlass::FastDivmod const& qhead_per_khead_divmod) {
        int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
        int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
        int n_idx_left = !Is_local ? n_idx : n_idx - window_size_left;
        if (Is_local && attention_chunk_divmod.divisor > 0) {
            n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
        }
        return !Is_local ? n_block_min : std::max(n_block_min, cute::ceil_div(n_idx_left, kBlockN));
    }

};

} // namespace flash



================================================
FILE: hopper/copy_sm90_bulk_reduce.hpp
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include

namespace cute {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct SM90_BULK_REDUCE_ADD
{
  CUTE_HOST_DEVICE static void
  copy(float const* smem_ptr,
       float      * gmem_ptr, int32_t store_bytes)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
                     :
                     : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
                     : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }

  CUTE_HOST_DEVICE static void
  copy(float const* smem_ptr,
       float      * gmem_ptr, int32_t store_bytes, uint64_t cache_hint)
  {
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
    asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n"
                     :
                     : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint)
                     : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // end namespace cute



================================================
FILE: hopper/cuda_check.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include
#include
#include

#define CHECK_CUDA(call) \
    do { \
        cudaError_t status_ = call; \
        if (status_ != cudaSuccess) { \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1); \
        } \
    } while(0)

#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())


#define CHECK_CUTLASS(call) \
    do { \
        cutlass::Status status_ = (call); \
        if (status_ != cutlass::Status::kSuccess) { \
            fprintf(stderr, "CUTLASS error (%s:%d): %s\n", __FILE__, __LINE__, cutlass::cutlassGetStatusString(status_)); \
            exit(1); \
        } \
    } while(0)



================================================
FILE: hopper/epilogue_bwd.hpp
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/barrier.h"
#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/builders/sm90_common.inl"

#include "seqlen.h"
#include "named_barrier.hpp"
#include "utils.h"

namespace flash {

using namespace cute;

template
struct CollectiveEpilogueBwd {

    using TileShape_MNK = TileShape_MNK_;
    using Element = Element_;
    using ArchTag = ArchTag_;
    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
    static constexpr bool Varlen = Varlen_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90;

    static_assert(ArchTag::kMinComputeCapability >= 80);

    using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;

    // These are for storing the output tensor without TMA (e.g., for setting output to zero)
    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    static constexpr int kHeadDim = get<2>(TileShape_MNK{});
    static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
    static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout, Int>,
                                  Stride, _1>>;
    using GmemTiledCopydKV = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        GmemLayoutAtom{},
                        Layout>>{}));  // Val layout, 8 or 16 vals per store

    using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int(TileShape_MNK{})) / AtomLayoutKdKV>>());
    using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutdKVtTMA =
        decltype(cute::composition(SmemLayoutdKVTMA{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                                               make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));

    // If we don't use TMA
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
    static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
    using SmemLayoutAtomdKVSTG =
        decltype(composition(Swizzle{},
                             Layout, Int>,
                                    Stride, _1>>{}));

    using SmemLayoutAtomdKV = std::conditional_t;
    using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
    using SmemLayoutdKVt =
        decltype(cute::composition(SmemLayoutdKV{},
                                   make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
                                               make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));

    using SmemCopyAtomdKV = Copy_Atom<
        std::conditional_t<
            ArchTag::kMinComputeCapability >= 90,
            std::conditional_t,
            AutoVectorizingCopyWithAssumedAlignment<128>
        >,
        Element>;

    static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128;
    static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");

    struct TensorStorage : cute::aligned_struct {
        cute::array_aligned, SmemAlignmentdKV> smem_dk;
        cute::array_aligned, SmemAlignmentdKV> smem_dv;
    };

    using ShapedKV = cute::Shape;  // (seqlen_k, d, head, batch)
    using StridedKV = cute::Stride;

    using TMA_dKV = std::conditional_t<
        Use_TMA,
        decltype(make_tma_copy(
            GmemTiledCopydKVTMA{},
            make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapedKV{}, StridedKV{}),
            SmemLayoutdKVTMA{},
            select<1, 2>(TileShape_MNK{}),
            _1{})),  // no mcast for dKV
        std::nullptr_t
        >;

    // Host side kernel arguments
    struct Arguments {
        Element* ptr_dK;
        ShapedKV const shape_dK;
        StridedKV const stride_dK;
        Element* ptr_dV;
        ShapedKV const shape_dV;
        StridedKV const stride_dV;
        int const num_batch;
        int const num_heads_q;
        int* dk_semaphore;
        int* dv_semaphore;
        int const* cu_seqlens;
        int const* seqused;
    };

    // Device side kernel params
    struct Params {
        Element* ptr_dK;
        ShapedKV const shape_dK;
        StridedKV const stride_dK;
        Element* ptr_dV;
        ShapedKV const shape_dV;
        StridedKV const stride_dV;
        TMA_dKV tma_store_dK, tma_store_dV;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
        Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV);
        TMA_dKV tma_store_dK = [&] {
            if constexpr (Use_TMA) {
                return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
            } else {
                return nullptr;
            }
        }();
        TMA_dKV tma_store_dV = [&] {
            if constexpr (Use_TMA) {
                return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
            } else {
                return nullptr;
            }
        }();
        return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV,
                tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        if constexpr (Use_TMA) {
            cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
            cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
        }
    }

    template
    CUTLASS_DEVICE void
    store(Params const& params,
          FrgTensorO const& tdKrdK,
          FrgTensorO const& tdVrdV,
          SharedStorage& shared_storage,
          TiledMma tiled_mma,
          int thread_idx,
          cute::tuple const& block_coord
          ) {

        auto [n_block, bidh, bidb] = block_coord;
        Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
        Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
        Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
        Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
        auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
        auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);

        Tensor tdVrdV_out = make_tensor_like(tdVrdV);
        flash::convert_type_out(tdVrdV, tdVrdV_out);
        Tensor tdKrdK_out = make_tensor_like(tdKrdK);
        flash::convert_type_out(tdKrdK, tdKrdK_out);
        Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out);  // ((Atom,AtomNum), MMA_M, MMA_N)
        Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out);  // ((Atom,AtomNum), MMA_M, MMA_N)
        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
        Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return(sdK, sdKt));
        // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return(sdV, sdVt));  // ((Atom,AtomNum),PIPE_M,PIPE_N)

        // Make sure all WGs have finished reading K and V
        flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
        cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
        cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
        if constexpr (Use_TMA) {
            cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
            cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);

            Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
            Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV);
            Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
            Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
            auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
            auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
            Tensor tdKgdK = block_tma_dK.partition_D(gdK);  // (TMA, TMA_M, TMA_K)
            Tensor tdKsdK = block_tma_dK.partition_S(sdK);  // (TMA, TMA_M, TMA_K)
            Tensor tdVgdV = block_tma_dV.partition_D(gdV);  // (TMA, TMA_M, TMA_K)
            Tensor tdVsdV = block_tma_dV.partition_S(sdV);  // (TMA, TMA_M, TMA_K)
            int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
            if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
                cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
                                                  cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
                if (cute::elect_one_sync()) {
                    cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
                    cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
                    tma_store_arrive();
                }
            }
            tma_store_wait<0>();
            // // Tell warp 0 that smem_k and smem_v are ready
            // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/);

        } else {
            flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
            static constexpr int kBlockN = get<1>(TileShape_MNK{});
            flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
            bool const is_varlen = Varlen && params.cu_seqlens;
            Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
            Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
            Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
            Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)

            GmemTiledCopydKV gmem_tiled_copy_dKV;
            auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
            Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
            Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
            Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
            Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
            Tensor tdKVrdV = make_fragment_like(tdKVgdV);
            Tensor tdKVrdK = make_fragment_like(tdKVgdK);
            Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
            // Repeat the partitioning with identity layouts
            Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
            Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV)));
            Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK)));
            #pragma unroll
            for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }
            #pragma unroll
            for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
            // Need to check OOB when reading from smem if kBlockN isn't evenly tiled
            static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
            flash::copy(
                gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdV, kBlockN);
            flash::copy(
                gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdK, kBlockN);
            // // Tell warp 0 that smem_k and smem_v are ready
            // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v
            // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/);
            // Construct identity layout for gdKV
            // Clear_OOB_K must be false since we don't want to write zeros to gmem
            flash::copy(
                gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
            );
            flash::copy(
                gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
            );
        }
    }

    CUTLASS_DEVICE void
    store_tail() {
        // if constexpr (Use_TMA) { tma_store_wait<0>(); }
    }

    // Write 0 to dK and dV
    CUTLASS_DEVICE void
    store_zero(
         Params const& params,
         int thread_idx,
         cute::tuple const& block_coord
         ) {
        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        auto [n_block, bidh, bidb] = block_coord;
        flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
        bool const is_varlen = Varlen && params.cu_seqlens;
        Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
        Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)
        Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
        Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (M, K)

        GmemTiledCopydKV gmem_tiled_copy_dKV;
        auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
        Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
        Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
        Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
        clear(tdKVrdKV);
        // Construct identity layout for gdKV
        Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
        Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK)));
        Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV)));
        #pragma unroll
        for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
        #pragma unroll
        for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy(
            gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN
        );
        flash::copy(
            gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN
        );
    }

};

template
struct CollectiveEpilogueBwdGQA {

    using TileShape_MNK = TileShape_MNK_;
    using Element = ElementAccum;
    using ArchTag = ArchTag_;
    static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
    static constexpr bool Varlen = Varlen_;
    static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90;

    static_assert(ArchTag::kMinComputeCapability >= 80);

    static constexpr int kBlockN = get<1>(TileShape_MNK{});
    static constexpr int kHeadDim = get<2>(TileShape_MNK{});
    static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp");
static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup; // Thread layout, 256 or 384 threads per row // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ using R2SLayoutAtomdKVaccum = Layout, Int>>; using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdKVaccum{}, Layout>{})); // Val layout, 4 vals per store // For Sm80 using R2GLayoutAtomdKVaccum = Layout>>; using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2GLayoutAtomdKVaccum{}, Layout>{})); // Val layout, 1 val per store using SmemLayoutdKVaccum = Layout, Int>>; using SmemLayoutdKVaccumFlat = Layout>>; // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue. static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256); struct TensorStorageTMA : cute::aligned_struct { cute::array_aligned, SmemAlignment> smem_dkv; }; struct TensorStorageSTG { cute::array smem_dkv; }; using TensorStorage = std::conditional_t; using ShapedKV = cute::Shape; // (seqlen_k_rounded * d, head, batch) using StridedKV = cute::Stride<_1, int64_t, int64_t>; // Host side kernel arguments struct Arguments { ElementAccum* ptr_dKaccum; ShapedKV const shape_dKaccum; StridedKV const stride_dKaccum; ElementAccum* ptr_dVaccum; ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; int const num_batch; int const num_heads_q; int* dk_semaphore; int* dv_semaphore; int const* cu_seqlens; int const* seqused; }; // Device side kernel params struct Params { ElementAccum* ptr_dKaccum; ShapedKV const shape_dKaccum; StridedKV const stride_dKaccum; ElementAccum* ptr_dVaccum; ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; cutlass::FastDivmod qhead_per_khead_divmod; int* dk_semaphore; int* dv_semaphore; int const num_batch; int
const* cu_seqlens = nullptr; int const* seqused = nullptr; }; static Params to_underlying_arguments(Arguments const& args) { if constexpr (Deterministic) { assert(args.dk_semaphore != nullptr); assert(args.dv_semaphore != nullptr); } return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum, cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))), args.dk_semaphore, args.dv_semaphore, args.num_batch, args.cu_seqlens, args.seqused}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& params) { } template CUTLASS_DEVICE void store(Params const& params, FrgTensorO const& tdKrdK, FrgTensorO const& tdVrdV, SharedStorage& shared_storage, TiledMma tiled_mma, int thread_idx, cute::tuple const& block_coord ) { auto [n_block, bidh, bidb] = block_coord; int bidh_idx_in_group; int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh); Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{}); Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{}); static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum); flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0); Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? 
bidb : 0); Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum; auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx); Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV); // Only used if !Use_TMA R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum; auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx); // Make sure all WGs have finished reading K and V, otherwise we get racy dQ // because smem_q could be changed. flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); if constexpr (Use_TMA) { Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N) cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum); } int const num_batch = params.num_batch; // int const num_batch = get<2>(params.shape_dKaccum); // erroneously returns 1 for varlen int const num_head_kv = get<1>(params.shape_dKaccum); int *lock_ptr = !Deterministic ? 
nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv; using Barrier = cutlass::GenericBarrier; // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);} if constexpr (Deterministic) { Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group); } // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);} if constexpr (Use_TMA) { cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); if (thread_idx == 0) { SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); tma_store_arrive(); tma_store_wait<0>(); } } else { Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV); Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum); static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic))); #pragma unroll for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); } } if constexpr (Deterministic) { Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv); } if constexpr (Use_TMA) { cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N) cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum); } lock_ptr = 
!Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv; // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);} if constexpr (Deterministic) { Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group); } // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);} if constexpr (Use_TMA) { cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); if (thread_idx == 0) { SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST)); tma_store_arrive(); tma_store_wait<0>(); } } else { Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK); Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum); static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic))); #pragma unroll for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); } } if constexpr (Deterministic) { Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv); } // // Tell warp 0 that smem_k and smem_v are ready // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); } CUTLASS_DEVICE void store_tail() { } // Write 0 to dK and dV CUTLASS_DEVICE void store_zero( Params const& params, int thread_idx, cute::tuple const& 
block_coord ) { // Don't need to do anything since dKaccum and dVaccum are already zero-initialized } }; } // namespace flash ================================================ FILE: hopper/epilogue_fwd.hpp ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include #include // For FastDivMod #include "cute/tensor.hpp" #include "cutlass/gemm/collective/builders/sm90_common.inl" #include "cutlass/epilogue/collective/builders/sm90_common.inl" #include "seqlen.h" #include "named_barrier.hpp" #include "pack_gqa.h" #include "utils.h" namespace flash { using namespace cute; template struct CollectiveEpilogueFwd { using TileShape_MNK_PV = TileShape_MNK_PV_; using ClusterShape = ClusterShape_; using Element = Element_; using ElementPartial = float; using ArchTag = ArchTag_; static constexpr int NumEpilogueThreads = NumEpilogueThreads_; static constexpr bool Varlen = Varlen_; static constexpr bool PackGQA = PackGQA_; static constexpr bool Split = Split_; static constexpr bool Use_smem = !(Split && !Varlen); static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA; static_assert(ArchTag::kMinComputeCapability >= 80); static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); static_assert(sizeof(Element) <= 2); static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); static constexpr bool LargeHeadDimV = kHeadDimV > 256; using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; // These are for storing the output tensor without TMA (e.g., for setting output to zero) static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDimV 
% kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times // we need to call divmod. static constexpr int kBytePerRow = kHeadDimV * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; // If PackGQA, we split the work of computing O_ptr among threads in the same row, so we need this to be within a warp static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow"); using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per store using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>()); using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{}))); static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ?
3 : 4); using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{}))); using SmemLayoutO = std::conditional_t= 90, SmemLayoutOTMA, SmemLayoutOSTS>; using ShapeO = cute::Shape; // (seqlen_q, d, head, batch, num_splits) using StrideO = cute::Stride; using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits) // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) using ShapeOPacked = std::conditional_t, int32_t, int32_t, int32_t, int32_t>>; using StrideOPacked = std::conditional_t, _1, int64_t, int64_t, int64_t>>; // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits) using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; using StrideLSEPacked = std::conditional_t, int64_t, int64_t, int64_t>>; using EpilogueTile_MN = decltype(select<0, 1>(TileShape_MNK_PV{})); using CopyOpR2S = std::conditional_t< ArchTag::kMinComputeCapability >= 90, // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16) decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), AutoVectorizingCopyWithAssumedAlignment<128> >; using SmemCopyAtomO = Copy_Atom; // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{}); // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment"); // struct TensorStorage : cute::aligned_struct { // cute::array_aligned : 0, SmemAlignmentO> smem_o; // }; struct TensorStorage : cute::aligned_struct<128> { cute::array_aligned : 0> smem_o; }; using TMA_O = std::conditional_t< Use_TMA_O, decltype(make_tma_copy( GmemTiledCopyOTMA{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), SmemLayoutOTMA{}, select<0, 1>(TileShape_MNK_PV{}), _1{})), // no mcast for O std::nullptr_t >; // Host side kernel arguments struct Arguments { Element* ptr_O; ShapeO const 
shape_O; StrideO const stride_O; ElementPartial* ptr_O_partial; StrideO const stride_O_partial; float* ptr_LSE; StrideLSE const stride_LSE; float* ptr_LSE_partial; StrideLSE const stride_LSE_partial; int32_t const nheads_kv; int const* cu_seqlens = nullptr; int const* seqused = nullptr; }; // Device side kernel params struct Params { Element* ptr_O; ShapeO const shape_O; StrideO const stride_O; ShapeOPacked const shape_O_packed; StrideOPacked const stride_O_packed; ElementPartial* ptr_O_partial; StrideO const stride_O_partial; StrideOPacked const stride_O_partial_packed; float* ptr_LSE; StrideLSE const stride_LSE; ShapeLSEPacked const shape_LSE_packed; StrideLSEPacked const stride_LSE_packed; float* ptr_LSE_partial; StrideLSE const stride_LSE_partial; StrideLSEPacked const stride_LSE_partial_packed; cutlass::FastDivmod qhead_per_khead_divmod; TMA_O tma_store_O; int const* cu_seqlens = nullptr; int const* seqused = nullptr; }; static Params to_underlying_arguments(Arguments const& args) { Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); TMA_O tma_store_O = [&]{ if constexpr (Use_TMA_O) { return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast } else { return nullptr; } }(); // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits) int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv); auto const shape_O_packed = cute::conditional_return( args.shape_O, make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) ); auto const stride_O_packed = cute::conditional_return( args.stride_O, make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O)) ); auto const stride_O_partial_packed = cute::conditional_return( args.stride_O_partial, make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial)) ); // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits) auto const shape_LSE_packed = cute::conditional_return( select<0, 2, 3, 4>(args.shape_O), make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) ); auto const stride_LSE_packed = cute::conditional_return( args.stride_LSE, make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE)) ); auto const stride_LSE_partial_packed = cute::conditional_return( args.stride_LSE_partial, make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial)) ); return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed, args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed, args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed, args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed, 
cutlass::FastDivmod(qhead_per_khead), tma_store_O, args.cu_seqlens, args.seqused}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& params) { if constexpr (Use_TMA_O) { cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor()); } } template CUTLASS_DEVICE void store(Params const& params, FrgTensorO& tOrO, FrgTensorLSE const& lse, SharedStorage& shared_storage, TiledMma tiled_mma, int thread_idx, cute::tuple const& block_coord ) { auto [m_block, bidh, bidb, split_idx] = block_coord; int num_splits = get<4>(params.shape_O_packed); if constexpr (Split && Varlen) { uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // upper 16 bits are for num_splits int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx } bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{}); // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO); static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4); // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion. // Otherwise we can permute after conversion.
if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); } Tensor tOrO_out = make_tensor_like(tOrO); flash::convert_type_out(tOrO, tOrO_out); if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); } // Make sure all WGs have finished reading V // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with // cp.async if we need). flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); // Step 1: Write O from rmem -> smem if constexpr (Use_smem) { auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N) cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); if constexpr (Use_TMA_O) { cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); } else { flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); } } else { if constexpr (ArchTag::kMinComputeCapability >= 90) { #pragma unroll for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { shared_storage.pipelines.barrier_O.arrive(cta_id); } } } flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool is_varlen = Varlen && params.cu_seqlens; int offset_o = seqlen_info.offset; int seqlen_o = seqlen_info.seqlen; int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / 
cutlass::NumThreadsPerWarpGroup, 0); // Step 2: Write LSE from rmem -> gmem auto thread_mma = tiled_mma.get_thread_slice(thread_idx); // (MMA,MMA_M,MMA_K) Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); static_assert(decltype(size<0, 0>(taccOcO))::value == 2); static_assert(decltype(size<0, 1>(taccOcO))::value == 2); Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; using PackGQApartial_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>; Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), params.shape_LSE_packed, !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 
0 : split_idx); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } if (!LargeHeadDimV || warp_group_idx == 0) { if constexpr (!PackGQA) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } } } else { PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } // Step 3: Write O from smem -> gmem if constexpr (Use_TMA_O) { Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) auto block_tma_O = params.tma_store_O.get_slice(_0{}); Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0); if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) { cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); if (cute::elect_one_sync()) { cute::copy(params.tma_store_O, tOsO, tOgO); tma_store_arrive(); tma_store_wait<0>(); #pragma unroll for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { shared_storage.pipelines.barrier_O.arrive(cta_id); } } } } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence if (!is_split) { Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? 
bidb : 0, _0{}); Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tOrO = make_fragment_like(tOsO); cute::copy(gmem_tiled_copy_O, tOsO, tOrO); if constexpr (ArchTag::kMinComputeCapability >= 90) { cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v #pragma unroll for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { shared_storage.pipelines.barrier_O.arrive(cta_id); } } if constexpr (!PackGQA) { // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); Tensor tOpO = make_tensor(make_shape(size<2>(tOsO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } Tensor tOgO = gmem_thr_copy_O.partition_D(gO); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM ); } else { // If PackGQA, we split the work of computing O_ptr among threads in the same row PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } else { Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ?
bidb : 0, split_idx); Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) // We already arrived on barrier_O earlier if !Use_smem if constexpr (Use_smem) { if constexpr (ArchTag::kMinComputeCapability >= 90) { #pragma unroll for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { shared_storage.pipelines.barrier_O.arrive(cta_id); } } } if constexpr (!PackGQA) { static constexpr int kGmemElemsPerStoreDirect = 2; cute::Copy_Atom, ElementPartial> gmem_copy_direct; // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); Tensor tOgO = thread_mma.partition_C(gOpartial); Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout())); Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int>{}); Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); #pragma unroll for (int m = 0; m < size(taccOcO_row); ++m) { if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) { #pragma unroll for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) { if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) { cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k)); } } } } } else { PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } } } CUTLASS_DEVICE void store_tail() { // Don't need to do tma_store_wait<0>() here since we already did in @store } // Write 0 to output and -inf to LSE CUTLASS_DEVICE void store_zero( Params const& params, int thread_idx, cute::tuple const& block_coord ) { static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); auto [m_block, bidh, bidb, split_idx] = block_coord; int num_splits = get<4>(params.shape_O_packed); if constexpr (Split 
&& Varlen) { uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // upper 16 bits are for num_splits int num_splits_dynamic = reinterpret_cast(num_splits_dynamic_u); num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits; split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx } bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1); flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; int offset_o = seqlen_info.offset; int seqlen_o = seqlen_info.seqlen; int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor; Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)), params.shape_LSE_packed, !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx); Tensor gLSE = local_tile(mLSE, Shape>{}, make_coord(m_block)); static_assert(kBlockM <= NumEpilogueThreads); if (thread_idx < kBlockM) { const int row = m_block * kBlockM + thread_idx; if constexpr (!PackGQA) { if (row < seqlen_o) { mLSE(row) = -INFINITY; } } else { if (row < seqlen_o * qhead_per_khead) { int m_idx, h_idx; m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row); // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord" mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY; } } } // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used, // since it will not use the value of O if LSE is -inf. if (!is_split) { Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ?
bidb : 0, _0{}); GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); if constexpr (!PackGQA) { Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); Tensor tOrO = make_fragment_like(tOgO); cute::clear(tOrO); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM ); } else { // If PackGQA, we split the work of computing O_ptr among threads in the same row using PackGQA_t = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); cute::clear(tOrO); PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } } } }; } // namespace flash ================================================ FILE: hopper/flash.h ================================================ /****************************************************************************** * Copyright (c) 2023, Tri Dao. ******************************************************************************/ #pragma once #include #include //////////////////////////////////////////////////////////////////////////////////////////////////// struct Qkv_params { using index_t = int64_t; // The QKV matrices. void *__restrict__ q_ptr; void *__restrict__ k_ptr; void *__restrict__ v_ptr; // The stride between rows of the Q, K and V matrices.
index_t q_batch_stride; index_t k_batch_stride; index_t v_batch_stride; index_t q_row_stride; index_t k_row_stride; index_t v_row_stride; index_t q_head_stride; index_t k_head_stride; index_t v_head_stride; index_t v_dim_stride; // The number of heads. int h, h_k; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_fwd_params : public Qkv_params { using index_t = int64_t; // The O matrix (output). void * __restrict__ o_ptr; void * __restrict__ oaccum_ptr; // The stride between rows of O. index_t o_batch_stride; index_t o_row_stride; index_t o_head_stride; // The pointer to the softmax sum. void * __restrict__ softmax_lse_ptr; void * __restrict__ softmax_lseaccum_ptr; // For FP8 scaling float * __restrict__ q_descale_ptr; float * __restrict__ k_descale_ptr; float * __restrict__ v_descale_ptr; index_t q_descale_batch_stride; index_t q_descale_head_stride; index_t k_descale_batch_stride; index_t k_descale_head_stride; index_t v_descale_batch_stride; index_t v_descale_head_stride; // The dimensions. int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; int total_q, total_k, total_knew; int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim // The scaling factors for the kernel. float scale_softmax; float softcap; // array of length b+1 holding starting offset of each sequence. int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; int * __restrict__ cu_seqlens_knew; int * __restrict__ leftpad_k; // If provided, the actual length of each q/k sequence. int *__restrict__ seqused_q; int *__restrict__ seqused_k; // The stride between rows of Oaccum. index_t oaccum_split_stride; index_t oaccum_batch_stride; index_t oaccum_row_stride; index_t oaccum_head_stride; // The stride between rows of LSEaccum. 
index_t lseaccum_split_stride; index_t lseaccum_batch_stride; index_t lseaccum_head_stride; // The K_new and V_new matrices. void * __restrict__ knew_ptr; void * __restrict__ vnew_ptr; // The strides of the K_new and V_new matrices. index_t knew_batch_stride; index_t vnew_batch_stride; index_t knew_row_stride; index_t vnew_row_stride; index_t knew_head_stride; index_t vnew_head_stride; void *__restrict__ qv_ptr; index_t qv_batch_stride; index_t qv_row_stride; index_t qv_head_stride; // The cos and sin matrices for rotary embedding. void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; int *__restrict__ seqlens_rotary; // The indices to index into the KV cache. int * __restrict__ kv_batch_idx; // Paged KV cache int * __restrict__ page_table; index_t page_table_batch_stride; int page_size; int num_pages; bool pagedkv_tma; // The dropout probability (probability of keeping an activation). float p_dropout; // uint32_t p_dropout_in_uint; // uint16_t p_dropout_in_uint16_t; uint8_t p_dropout_in_uint8_t; // Scale factor of 1 / p_dropout, i.e. 1 / (1 - drop probability). float rp_dropout; // Local window size int window_size_left, window_size_right; int attention_chunk; // Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state; bool is_bf16; bool is_fp32; bool is_e4m3; bool is_causal; bool is_local; bool is_rotary_interleaved; int num_splits; // For split-KV version bool pack_gqa; int * __restrict__ tile_count_semaphore; int * __restrict__ num_m_blocks_ptr; // int * __restrict__ num_n_blocks_ptr; int * __restrict__ num_splits_dynamic_ptr; int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual int * __restrict__ num_nheads_in_l2_ptr; bool skip_scheduler_metadata_computation; bool varlen_sort_batches; int tile_count_semaphore_offset; bool head_swizzle; bool prepare_varlen_pdl; int arch; int num_sm; }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_bwd_params : public Flash_fwd_params { using index_t = int64_t; // The dO and dQKV matrices. void *__restrict__ do_ptr; void *__restrict__ dq_ptr; void *__restrict__ dk_ptr; void *__restrict__ dv_ptr; // To accumulate dQ void *__restrict__ dq_accum_ptr; void *__restrict__ dk_accum_ptr; void *__restrict__ dv_accum_ptr; // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ // dv_accum_ptr; // The stride between rows of the dO, dQ, dK and dV matrices. index_t do_batch_stride; index_t do_row_stride; index_t do_head_stride; index_t dq_batch_stride; index_t dk_batch_stride; index_t dv_batch_stride; index_t dq_row_stride; index_t dk_row_stride; index_t dv_row_stride; index_t dq_head_stride; index_t dk_head_stride; index_t dv_head_stride; // The pointer to the softmax d sum. 
void *__restrict__ dsoftmax_sum; void *__restrict__ softmax_lse_log2_ptr; int *__restrict__ dq_semaphore; int *__restrict__ dk_semaphore; int *__restrict__ dv_semaphore; bool deterministic; index_t dq_accum_split_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream); void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl); template void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl); ================================================ FILE: hopper/flash_api.cpp ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #include #include #include #include #include #include "flash.h" #include "static_switch.h" #include "tile_size.h" #include "heuristics.h" #include "cuda_check.h" extern "C" { /* Creates a dummy empty _C module that can be imported from Python. The import from Python will load the .so consisting of this file in this extension, so that the TORCH_LIBRARY static initializers below are run. */ PyObject* PyInit__C(void) { static struct PyModuleDef module_def = { PyModuleDef_HEAD_INIT, "_C", /* name of module */ NULL, /* module documentation, may be NULL */ -1, /* size of per-interpreter state of the module, or -1 if the module keeps state in global variables. */ NULL, /* methods */ }; return PyModule_Create(&module_def); } } #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_SHAPE(x, ...)
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 namespace { inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor& t) { return at::cuda::CUDAGuard(static_cast(t.get_device())); } } // namespace void set_params_fprop(Flash_fwd_params &params, // sizes const size_t b, const size_t seqlen_q, const size_t seqlen_k, const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, // device pointers const at::Tensor q, const at::Tensor k, const at::Tensor v, at::Tensor out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *seqused_q, void *seqused_k, void *softmax_lse_d, float p_dropout, float softmax_scale, int window_size_left, int window_size_right, int attention_chunk, const float softcap=0.f, const int sm_margin=0) { // Reset the parameters params = {}; params.is_bf16 = q.dtype() == torch::kBFloat16; params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn; // Set the pointers and strides. params.q_ptr = q.data_ptr(); params.k_ptr = k.data_ptr(); params.v_ptr = v.data_ptr(); // All strides are in elements, not bytes.
params.q_row_stride = q.stride(-3); params.k_row_stride = k.stride(-3); params.v_row_stride = v.stride(-3); params.q_head_stride = q.stride(-2); params.k_head_stride = k.stride(-2); params.v_head_stride = v.stride(-2); params.v_dim_stride = v.stride(-1); params.o_ptr = out.data_ptr(); params.o_row_stride = out.stride(-3); params.o_head_stride = out.stride(-2); if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = q.stride(0); params.o_batch_stride = out.stride(0); } if (cu_seqlens_k_d == nullptr) { params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); } params.cu_seqlens_q = static_cast(cu_seqlens_q_d); params.cu_seqlens_k = static_cast(cu_seqlens_k_d); params.seqused_q = static_cast(seqused_q); params.seqused_k = static_cast(seqused_k); // Softmax sum params.softmax_lse_ptr = softmax_lse_d; // Set the dimensions. params.b = b; params.h = h; params.h_k = h_k; params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; params.seqlen_q_rounded = seqlen_q_rounded; params.seqlen_k_rounded = seqlen_k_rounded; params.d = d; params.d_rounded = d_rounded; // Set the different scale values. params.scale_softmax = softmax_scale; params.softcap = softcap; // Set this to probability of keeping an element to simplify things. params.p_dropout = 1.f - p_dropout; // Convert p from float to int so we don't have to convert the random uint to float to compare. 
// [Minor] We want to round down since when we do the comparison we use <= instead of < // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); params.rp_dropout = 1.f / params.p_dropout; TORCH_CHECK(p_dropout < 1.f); #ifdef FLASHATTENTION_DISABLE_DROPOUT TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); #endif // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; // TODO: check this if (window_size_left < 0) { window_size_left = seqlen_k - 1; } if (window_size_right < 0) { window_size_right = seqlen_q - 1; } if (attention_chunk > 0) { window_size_left = std::min(window_size_left, attention_chunk - 1); window_size_right = std::min(window_size_right, attention_chunk - 1); } params.window_size_left = window_size_left; params.window_size_right = window_size_right; params.attention_chunk = attention_chunk; params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; #ifdef FLASHATTENTION_DISABLE_LOCAL TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); #endif } void set_params_dgrad(Flash_bwd_params &params, // sizes const size_t b, const size_t seqlen_q, const size_t seqlen_k, const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, // device pointers const at::Tensor q,
const at::Tensor k, const at::Tensor v, const at::Tensor out, const at::Tensor dout, at::Tensor dq, at::Tensor dk, at::Tensor dv, void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *seqused_q, void *seqused_k, void *dq_accum_d, void *dk_accum_d, void *dv_accum_d, void *softmax_lse_d, void *dsoftmax_sum_d, float p_dropout, float softmax_scale, int window_size_left, int window_size_right, int attention_chunk, const float softcap=0.f, bool deterministic=false, int const sm_margin=0) { set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, q, k, v, out, cu_seqlens_q_d, cu_seqlens_k_d, seqused_q, seqused_k, softmax_lse_d, p_dropout, softmax_scale, window_size_left, window_size_right, attention_chunk, softcap, sm_margin); // Set the pointers and strides. params.do_ptr = dout.data_ptr(); params.do_row_stride = dout.stride(-3); params.do_head_stride = dout.stride(-2); params.dq_ptr = dq.data_ptr(); params.dk_ptr = dk.data_ptr(); params.dv_ptr = dv.data_ptr(); params.dq_row_stride = dq.stride(-3); params.dk_row_stride = dk.stride(-3); params.dv_row_stride = dv.stride(-3); params.dq_head_stride = dq.stride(-2); params.dk_head_stride = dk.stride(-2); params.dv_head_stride = dv.stride(-2); if (cu_seqlens_q_d == nullptr) { params.do_batch_stride = dout.stride(0); params.dq_batch_stride = dq.stride(0); params.dk_batch_stride = dk.stride(0); params.dv_batch_stride = dv.stride(0); } params.dq_accum_ptr = dq_accum_d; params.dk_accum_ptr = dk_accum_d; params.dv_accum_ptr = dv_accum_d; // Softmax sum params.dsoftmax_sum = dsoftmax_sum_d; params.deterministic = deterministic; } template void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) { if (!params.is_e4m3) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if constexpr (Arch == 90) { if (params.dv > 256) { return run_mha_fwd_(params, stream); } else if (params.dv > 64) { return
run_mha_fwd_(params, stream); } } #endif return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_(params, stream); } } #endif return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif } else { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if constexpr (Arch == 90) { if (params.dv > 256) { return run_mha_fwd_(params, stream); } else if (params.dv > 64) { return run_mha_fwd_(params, stream); } } #endif return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_(params, stream); } } #endif return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP16."); #endif } } else { #ifndef FLASHATTENTION_DISABLE_FP8 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 if (params.d <= 96) { return run_mha_fwd_<90, 
cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } } #endif return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP8."); #endif } } void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) { // HEADDIM_SWITCH(params.d, [&] { // run_mha_fwd_(params, stream); // }); TORCH_CHECK(params.num_splits >= 1); ARCH_SWITCH(params.arch, Arch, [&] { SPLIT_SWITCH(params.num_splits > 1, Split, [&] { PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] { PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] { // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split; SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] { run_mha_fwd_constexpr(params, stream); }); }); }); }); }); } void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl=false) { #ifndef FLASHATTENTION_DISABLE_SPLIT // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively // so that kBlockM is smaller and we have more parallelism.
if (params.is_fp32) { if (params.dv <= 64) { run_mha_fwd_combine_(params, stream, enable_pdl); } else { run_mha_fwd_combine_(params, stream, enable_pdl); } } else if (params.is_bf16) { if (params.dv <= 64) { run_mha_fwd_combine_(params, stream, enable_pdl); } else { run_mha_fwd_combine_(params, stream, enable_pdl); } } else { if (params.dv <= 64) { run_mha_fwd_combine_(params, stream, enable_pdl); } else { run_mha_fwd_combine_(params, stream, enable_pdl); } } #else TORCH_CHECK(false, "This flash attention build does not support combine kernels."); #endif } inline bool get_pagedkv_tma(Flash_fwd_params const& params) { if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } // This needs to match the kernel configs auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, // at least for MLA. return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM; } inline bool get_pack_gqa(Flash_fwd_params const& params) { // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. // Has little effect on speed. if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } #ifdef FLASHATTENTION_DISABLE_PACKGQA return false; #else // params.page_table must already be set if (params.h == params.h_k) { return false; } // This needs to match the kernel configs auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); #endif } inline int get_num_splits(Flash_fwd_params const& params) { #ifdef FLASHATTENTION_DISABLE_SPLIT return 1; #else // Always enable PackGQA for Split // params.page_table must already be set // This needs to match the kernel configs bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits // has not been set here. It's OK though because we might just underestimate kBlockN a bit auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); // If is_local, we're not going to load all of seqlen_k int const seqlen_k_loaded = !params.is_local ? 
params.seqlen_k : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2); // Always enable PackGQA for Split // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits. // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending // that batch = 1. int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks; return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); #endif } inline int get_max_headdim() { #ifndef FLASHATTENTION_DISABLE_HDIM256 return 256; #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 return 192; #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 return 128; #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 return 96; #endif #ifndef FLASHATTENTION_DISABLE_HDIM64 return 64; #endif return 0; } inline int round_up_headdim(int head_size) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (head_size <= 64) { return 64; } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 if (head_size <= 96) { return 96; } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 if (head_size <= 128) { return 128; } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (head_size <= 192) { return 192; } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 if (head_size <= 256) { return 256; } #endif return 256; } inline int round_up_headdimv(int head_size) { if (head_size <= 64) { return 64; } if (head_size <= 96) { return 96; } if (head_size <= 128) { return 128; } if (head_size <= 192) { return 192; } if (head_size <= 256) { return 256; } return 512; } // Only applicable to the case where seqused_k (i.e. 
cache_seqlens) is available at::Tensor mha_fwd_get_scheduler_metadata( int64_t batch_size, int64_t max_seqlen_q, int64_t max_seqlen_k, int64_t num_heads, int64_t num_heads_k, int64_t headdim, int64_t headdim_v, at::ScalarType qkv_dtype, at::Tensor seqused_k, // b std::optional cu_seqlens_q_, // b+1 std::optional cu_seqlens_k_, // b+1 std::optional cu_seqlens_k_new_, // b+1 std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. std::optional leftpad_k_, // b std::optional page_size, int64_t max_seqlen_k_new, // 0 means we're not appending new KV bool is_causal, int64_t window_size_left, int64_t window_size_right, int64_t attention_chunk, bool has_softcap, int64_t num_splits, std::optional pack_gqa_, int64_t sm_margin) { TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); // Reset the parameters Flash_fwd_params params{}; params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16; params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn; params.b = batch_size; params.seqlen_q = max_seqlen_q; params.seqlen_k = max_seqlen_k; params.h = num_heads; params.h_k = num_heads_k; params.d = headdim; params.dv = headdim_v; params.d_rounded = round_up_headdim(headdim); params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v); params.seqlen_knew = max_seqlen_k_new; bool const is_varlen_q = cu_seqlens_q_.has_value(); params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr() : nullptr; bool const is_varlen_k = cu_seqlens_k_.has_value(); params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr() : nullptr; params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? 
cu_seqlens_k_new_.value().data_ptr() : nullptr; params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr; params.seqused_k = seqused_k.data_ptr(); params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr() : nullptr; params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast(1) : nullptr; if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; } if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; } // causal=true is the same as causal=false in this case if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) { is_causal = false; } } if (is_causal) { window_size_right = 0; } params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0; params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal; if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; } if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; } if (attention_chunk > 0) { window_size_left = std::min(window_size_left, attention_chunk - 1); window_size_right = std::min(window_size_right, attention_chunk - 1); } params.window_size_left = window_size_left; params.window_size_right = window_size_right; params.attention_chunk = attention_chunk; params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin; params.softcap = has_softcap ? 1.0f : 0.0f; params.page_size = page_size.has_value() ? page_size.value() : 1; params.page_table = !page_size.has_value() ? 
nullptr : reinterpret_cast(1); bool const use_prepare_varlen = true; params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); bool is_varlen = true; // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing auto device_guard = make_cuda_guard_from_tensor(seqused_k); auto opts = seqused_k.options(); // This needs to be set after get_num_splits at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } if(params.head_swizzle) { num_prepare_batch_vectors += 1; } int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 
3 : 2); int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); tile_count_semaphore = torch::empty( {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, opts.dtype(torch::kInt32)); // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; if (scheduler_needs_semaphore) { if (!use_prepare_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing params.tile_count_semaphore = tile_count_semaphore.data_ptr() + tile_count_semaphore_offset; } else { params.tile_count_semaphore = nullptr; } } if (use_prepare_varlen) { auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? 
std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); auto stream = at::cuda::getCurrentCUDAStream().stream(); prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } return tile_count_semaphore; } // b: batch_size // b_k: batch_size_k // s_q: seqlen_q // s_k: seqlen_k // s_k_new: seqlen_k_new // h: num_heads // h_k: num_heads_k // d: head_size std::tuple mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. std::optional k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new std::optional v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new std::optional q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q std::optional out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q std::optional cu_seqlens_q_, // b+1 std::optional cu_seqlens_k_, // b+1 std::optional cu_seqlens_k_new_, // b+1 std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used. std::optional max_seqlen_q_, // TODO: check if we need max_seqlen_k std::optional max_seqlen_k_, std::optional page_table_, // (b_k, max_num_pages_per_seq) std::optional kv_batch_idx_, // b. 
indices to index into the KV cache std::optional leftpad_k_, // b std::optional rotary_cos_, // seqlen_ro x (rotary_dim / 2) std::optional rotary_sin_, // seqlen_ro x (rotary_dim / 2) std::optional seqlens_rotary_, // b std::optional q_descale_, // (b, h_k), not (b, h) std::optional k_descale_, // (b, h_k) std::optional v_descale_, // (b, h_k) std::optional softmax_scale_, bool is_causal, int64_t window_size_left, int64_t window_size_right, int64_t attention_chunk, double softcap, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 std::optional scheduler_metadata_, // (b + 1) int64_t num_splits, std::optional pack_gqa_, int64_t sm_margin ) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major >= 8; TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); auto q_type = q.scalar_type(); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn, "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); if (dprops->major < 9) { TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type"); } TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); at::Tensor page_table; const bool paged_KV = page_table_.has_value(); if (paged_KV) { page_table = page_table_.value(); CHECK_DEVICE(page_table); TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32"); 
TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); } at::Tensor cu_seqlens_q; bool const is_varlen_q = cu_seqlens_q_.has_value(); if (is_varlen_q) { cu_seqlens_q = cu_seqlens_q_.value(); CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); } at::Tensor cu_seqlens_k; bool const is_varlen_k = cu_seqlens_k_.has_value(); if (is_varlen_k) { cu_seqlens_k = cu_seqlens_k_.value(); CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported"); TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then kv_batch_idx is not supported"); } auto const sizes = q.sizes(); const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value(); int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; int num_heads = q.size(-2); int const head_size = q.size(-1); int const head_size_v = v.size(-1); int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); int const num_pages = !paged_KV ? 0 : k.size(0); int const page_size = !paged_KV ? 1 : k.size(1); int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value(); int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0); int const num_heads_k = k.size(-2); int const batch_size_k = !paged_KV ? (!is_varlen_k ?
k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0); double softmax_scale = 1.0 / sqrt(double(head_size)); if (softmax_scale_.has_value()) { softmax_scale = softmax_scale_.value(); } if (!kv_batch_idx_.has_value()) { TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); } int const max_headdim = get_max_headdim(); TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); if (head_size_v != head_size) { TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) || (head_size <= 64 && head_size_v <= 512), "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], " "or (Q/K <= 64 and V <= 512)."); TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); if (head_size_v > 256) { TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "HeaddimV > 256 requires fp16 and bf16 data type"); } } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM // TODO: check this if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } // causal=true is the same as causal=false in this case if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) { // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA if ((head_size <= 64 || head_size > 128) || !paged_KV) { is_causal = false; } } if (is_causal) { window_size_right = 0; } if (!is_varlen_q) { CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); } else { CHECK_SHAPE(q, total_q, num_heads, head_size); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); } if (!paged_KV) { if 
(!is_varlen_k) { CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } } else { CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); } if (seqused_q_.has_value()){ auto seqused_q = seqused_q_.value(); TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); CHECK_SHAPE(seqused_q, batch_size); } if (seqused_k_.has_value()) { auto seqused_k = seqused_k_.value(); TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32"); CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); CHECK_SHAPE(seqused_k, batch_size); } if (leftpad_k_.has_value()) { auto leftpad_k = leftpad_k_.value(); TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k); CHECK_SHAPE(leftpad_k, batch_size); } // This is what we will template on bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value(); #ifdef FLASHATTENTION_DISABLE_VARLEN TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); #endif int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); auto opts = q.options(); auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? 
at::ScalarType::BFloat16 : q_type; at::Tensor out; if (out_.has_value()) { out = out_.value(); TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); if (!is_varlen_q) { CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); } else { CHECK_SHAPE(out, total_q, num_heads, head_size_v); } } else { out = !is_varlen_q ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v); int const seqlen_q_rounded = round_multiple(seqlen_q, 128); int const seqlen_k_rounded = round_multiple(seqlen_k, 128); // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing auto device_guard = make_cuda_guard_from_tensor(q); at::Tensor softmax_lse; if (!is_varlen_q) { softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); } else { softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); } Flash_fwd_params params; set_params_fprop(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, q, k, v, out, !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, seqused_k_.has_value() ? 
seqused_k_.value().data_ptr() : nullptr, softmax_lse.data_ptr(), /*p_dropout=*/0.f, softmax_scale, window_size_left, window_size_right, attention_chunk, softcap, sm_margin); params.total_q = total_q; params.total_k = total_k; params.b_k = batch_size_k; params.dv = head_size_v; params.dv_rounded = head_size_v_rounded; if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma params.leftpad_k = static_cast(leftpad_k_.value().data_ptr()); } if (paged_KV) { params.page_table = page_table.data_ptr(); params.page_table_batch_stride = page_table.stride(0); } params.page_size = page_size; params.num_pages = num_pages; if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma at::Tensor k_new, v_new; TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in"); TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqused_k must also be passed in"); TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache"); at::Tensor cu_seqlens_k_new; bool const is_varlen_k_new = cu_seqlens_k_new_.has_value(); if (is_varlen_k_new) { cu_seqlens_k_new = cu_seqlens_k_new_.value(); CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new); TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32"); } k_new = k_new_.value(); v_new = v_new_.value(); TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query"); TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query"); CHECK_DEVICE(k_new); CHECK_DEVICE(v_new); TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension"); TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension"); // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0; int total_k_new = !is_varlen_k_new ?
batch_size * k_new.size(1): k_new.size(0); if (!is_varlen_k_new) { CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); } else { CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); } params.seqlen_knew = seqlen_k_new; params.total_knew = total_k_new; params.knew_ptr = k_new.data_ptr(); params.vnew_ptr = v_new.data_ptr(); // All strides are in elements, not bytes. params.knew_row_stride = k_new.stride(-3); params.vnew_row_stride = v_new.stride(-3); params.knew_head_stride = k_new.stride(-2); params.vnew_head_stride = v_new.stride(-2); if (!is_varlen_k_new) { params.knew_batch_stride = k_new.stride(0); params.vnew_batch_stride = v_new.stride(0); } if (is_varlen_k_new) { params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); } } bool const use_prepare_varlen = is_varlen; params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); // This needs to be set after get_num_splits at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic // We don't use the persistent scheduler if Split and not Varlen bool const scheduler_needs_semaphore = params.arch >= 90 ?
(((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template if (scheduler_needs_semaphore || use_prepare_varlen) { int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } if(params.head_swizzle) { num_prepare_batch_vectors += 1; } int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); if (scheduler_metadata_.has_value()) { at::Tensor scheduler_metadata = scheduler_metadata_.value(); CHECK_DEVICE(scheduler_metadata); CHECK_SHAPE(scheduler_metadata, metadata_size); CHECK_CONTIGUOUS(scheduler_metadata); TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32"); tile_count_semaphore = scheduler_metadata; } else { tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); } if (scheduler_needs_semaphore && !use_prepare_varlen) { tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing } // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; params.num_m_blocks_ptr = use_prepare_varlen ? 
tile_count_semaphore.data_ptr() + b_rounded : nullptr; params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() + tile_count_semaphore_offset : nullptr; params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later } if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "q_v is only supported for fp16 and bf16 data type"); TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); at::Tensor q_v = q_v_.value(); TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); CHECK_DEVICE(q_v); TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); if (!is_varlen_q) { CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); } else { CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); } params.qv_ptr = q_v.data_ptr(); // All strides are in elements, not bytes.
params.qv_row_stride = q_v.stride(-3); params.qv_head_stride = q_v.stride(-2); if (!is_varlen_q) { params.qv_batch_stride = q_v.stride(0); } } if (rotary_cos_.has_value()) { TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); auto rotary_cos = rotary_cos_.value(); CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos); params.rotary_dim = rotary_cos.size(1) * 2; TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); const int seqlen_ro = rotary_cos.size(0); if (paged_KV) { TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); } CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); auto rotary_sin = rotary_sin_.value(); CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin); CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_sin must have the same dtype as query"); params.rotary_cos_ptr = rotary_cos.data_ptr(); params.rotary_sin_ptr = rotary_sin.data_ptr(); params.is_rotary_interleaved = is_rotary_interleaved; if (seqlens_rotary_.has_value()) { at::Tensor seqlens_rotary = seqlens_rotary_.value(); CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary); TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); CHECK_SHAPE(seqlens_rotary, batch_size); params.seqlens_rotary = seqlens_rotary.data_ptr(); } } else { params.rotary_dim = 0; } if (kv_batch_idx_.has_value()) { auto kv_batch_idx = kv_batch_idx_.value(); CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx); TORCH_CHECK(kv_batch_idx.scalar_type() ==
torch::kInt32, "kv_batch_idx must have dtype int32"); params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); } at::Tensor out_accum, softmax_lse_accum; auto outaccum_type = at::ScalarType::Float; if (params.num_splits > 1) { TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); if (!is_varlen_q) { out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); params.oaccum_batch_stride = out_accum.stride(1); params.lseaccum_batch_stride = softmax_lse_accum.stride(1); } else { out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); } params.is_fp32 = false; params.oaccum_ptr = out_accum.data_ptr(); params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_split_stride = out_accum.stride(0); params.oaccum_row_stride = out_accum.stride(-2); params.oaccum_head_stride = out_accum.stride(-3); params.lseaccum_split_stride = softmax_lse_accum.stride(0); params.lseaccum_head_stride = softmax_lse_accum.stride(-2); } if (q_type == at::ScalarType::Float8_e4m3fn) { if (q_descale_.has_value()) { auto q_descale = q_descale_.value(); CHECK_DEVICE(q_descale); CHECK_SHAPE(q_descale, batch_size, num_heads_k); params.q_descale_ptr = q_descale.data_ptr(); params.q_descale_batch_stride = q_descale.stride(0); params.q_descale_head_stride = q_descale.stride(1); } else { params.q_descale_ptr = nullptr; } if (k_descale_.has_value()) { auto k_descale = k_descale_.value(); CHECK_DEVICE(k_descale); CHECK_SHAPE(k_descale, batch_size, num_heads_k); params.k_descale_ptr = k_descale.data_ptr(); params.k_descale_batch_stride = k_descale.stride(0); params.k_descale_head_stride = k_descale.stride(1); } else { params.k_descale_ptr = nullptr; } 
if (v_descale_.has_value()) { auto v_descale = v_descale_.value(); CHECK_DEVICE(v_descale); CHECK_SHAPE(v_descale, batch_size, num_heads_k); params.v_descale_ptr = v_descale.data_ptr(); params.v_descale_batch_stride = v_descale.stride(0); params.v_descale_head_stride = v_descale.stride(1); } else { params.v_descale_ptr = nullptr; } } #ifdef FLASHATTENTION_DISABLE_LOCAL TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); #endif #ifdef FLASHATTENTION_DISABLE_SOFTCAP TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); #endif #ifdef FLASHATTENTION_DISABLE_SPLIT TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); #endif #ifdef FLASHATTENTION_DISABLE_PACKGQA TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); #endif #ifdef FLASHATTENTION_DISABLE_PAGEDKV TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); #endif #ifdef FLASHATTENTION_DISABLE_APPENDKV TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); #endif if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream); if (params.num_splits > 1) { if (out_type == at::ScalarType::BFloat16) { // Since we want output in BF16. Otherwise fwd_combine will output to FP16 params.is_bf16 = true; } // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 // and seqlen = total_q, and don't need to dispatch to Varlen there. // However, with dynamic split, each row needs to know which batch it belongs to // to read the number of splits, so we just use the varlen version of combine kernel. 
// if (is_varlen_q && !seqused_q_.has_value()) { // if (is_varlen_q) { // params.b = 1; // params.seqlen_q = total_q; // } // This will zero out the semaphore if needed run_mha_fwd_combine(params, stream, true /*enable_pdl*/); } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { // need to zero out the semaphore in this case tile_count_semaphore.index({torch::indexing::Slice(params.tile_count_semaphore_offset, params.tile_count_semaphore_offset + 1)}).zero_(); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. out.zero_(); softmax_lse.fill_(std::numeric_limits::infinity()); } // return {out, softmax_lse}; return {out, softmax_lse, out_accum, softmax_lse_accum}; } #ifdef FLASHATTENTION_DISABLE_BACKWARD void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) { TORCH_CHECK(false, "Flash-Attention was built with backward disabled"); } #else template void run_mha_bwd_constexpr(Flash_bwd_params &params, cudaStream_t stream) { if (!params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP16."); #endif } else { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d_rounded == 64) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 if (params.d_rounded == 96) { return run_mha_bwd_(params, stream); }
#endif #ifndef FLASHATTENTION_DISABLE_HDIM128 if (params.d_rounded == 128) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d_rounded == 192) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 if (params.d_rounded == 256) { return run_mha_bwd_(params, stream); } #endif } } void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) { // FP16_SWITCH(!params.is_bf16, [&] { // HEADDIM_SWITCH(params.d, [&] { // run_mha_bwd_(params, stream); // }); // }); ARCH_SWITCH(params.arch, Arch, [&] { SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] { run_mha_bwd_constexpr(params, stream); }); }); } #endif // b: batch_size // s_q: seqlen_q // s_k: seqlen_k // h: num_heads // h_k: num_heads_k // d: head_size std::tuple mha_bwd( at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k at::Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k at::Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q at::Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q std::optional dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q std::optional dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k std::optional dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k std::optional cu_seqlens_q_, // b+1 std::optional cu_seqlens_k_, // b+1 std::optional seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. std::optional seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
std::optional max_seqlen_q_, std::optional max_seqlen_k_, std::optional softmax_scale_, bool is_causal, int64_t window_size_left, int64_t window_size_right, double softcap, bool deterministic, int64_t sm_margin ) { #ifdef FLASHATTENTION_DISABLE_BACKWARD TORCH_CHECK(false, "This flash attention build does not support backward."); #endif auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major >= 8; TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); auto q_type = q.dtype(); TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16, "FlashAttention only supports fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype"); TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype"); TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); at::Tensor cu_seqlens_q; bool const is_varlen_q = cu_seqlens_q_.has_value(); if (is_varlen_q) { cu_seqlens_q = cu_seqlens_q_.value(); CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided"); } at::Tensor cu_seqlens_k; bool const is_varlen_k = cu_seqlens_k_.has_value(); if (is_varlen_k) {
cu_seqlens_k = cu_seqlens_k_.value(); CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); } // This is what we will template on bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value(); #ifdef FLASHATTENTION_DISABLE_VARLEN TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); #endif auto const sizes = q.sizes(); int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1; int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value(); int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; int const num_heads = q.size(-2); int const head_size = q.size(-1); int const head_size_v = v.size(-1); int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); int const total_k = !is_varlen_k ? 
batch_size * k.size(1) : k.size(0); int const num_heads_k = k.size(-2); TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8"); int const max_headdim = get_max_headdim(); TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention backward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); double softmax_scale = 1.0 / sqrt(double(head_size)); if (softmax_scale_.has_value()) { softmax_scale = softmax_scale_.value(); } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } if (is_causal) { window_size_right = 0; } // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true. // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA). is_causal = window_size_left < 0 && window_size_right == 0; int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor; int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); int const head_size_v_rounded = head_size_rounded; TORCH_CHECK(!deterministic || head_size_rounded < 256, "Deterministic backward not supported for hdim 256."); // Very important that these match the kernel configs bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) : (head_size_rounded <= 96 ? 64 : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80) : 64)); int const kBlockM_sm80 = head_size_rounded <= 64 ?
128 : 64; int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32; int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80); int const kBlockN_sm90 = head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 96 : 80); int const kBlockN_sm80 = head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64); int const kBlockN_sm86 = head_size_rounded <= 64 ? 128 : (head_size_rounded <= 96 ? 128 : (head_size_rounded <= 128 ? 96 : (head_size_rounded <= 192 ? 64 : 64))); int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM); int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN); int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM); int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN); if (!is_varlen_q) { CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v); } else { CHECK_SHAPE(q, total_q, num_heads, head_size); CHECK_SHAPE(out, total_q, num_heads, head_size_v); CHECK_SHAPE(dout, total_q, num_heads, head_size_v); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); } if (!is_varlen_k) { CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } if (seqused_q_.has_value()){ auto seqused_q = seqused_q_.value(); TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32"); CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); CHECK_SHAPE(seqused_q, batch_size); } if 
(seqused_k_.has_value()){ auto seqused_k = seqused_k_.value(); TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32"); CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); CHECK_SHAPE(seqused_k, batch_size); } at::Tensor dq, dk, dv; if (dq_.has_value()) { dq = dq_.value(); TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q"); CHECK_DEVICE(dq); TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); if (!is_varlen_q) { CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); } else { CHECK_SHAPE(dq, total_q, num_heads, head_size); } } else { dq = torch::empty_like(q); } if (dk_.has_value()) { dk = dk_.value(); TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q"); CHECK_DEVICE(dk); TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); if (!is_varlen_k) { CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); } else { CHECK_SHAPE(dk, total_k, num_heads_k, head_size); } } else { dk = torch::empty_like(k); } if (dv_.has_value()) { dv = dv_.value(); TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q"); CHECK_DEVICE(dv); TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); if (!is_varlen_k) { CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v); } } else { dv = torch::empty_like(v); } // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing auto device_guard = make_cuda_guard_from_tensor(q); auto opts = q.options(); // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 at::Tensor softmax_d, softmax_lse_log2; if (!is_varlen) { // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, 
opts.dtype(at::kFloat)); softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); } else { softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat)); softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat)); } at::Tensor dq_accum, dk_accum, dv_accum; if (!is_varlen) { dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat)); } else { dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat)); } if (num_heads_k != num_heads) { // MQA / GQA if (!is_varlen) { dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat)); dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, opts.dtype(at::kFloat)); } else { dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat)); dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_v_rounded}, opts.dtype(at::kFloat)); } } Flash_bwd_params params; set_params_dgrad(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, q, k, v, out, dout, dq, dk, dv, !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, dq_accum.data_ptr(), num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr, num_heads_k != num_heads ? 
dv_accum.data_ptr() : nullptr, softmax_lse.data_ptr(), softmax_d.data_ptr(), /*p_dropout=*/0.f, softmax_scale, window_size_left, window_size_right, 0, // attention_chunk softcap, deterministic, sm_margin); params.total_q = total_q; params.total_k = total_k; params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); params.dv = head_size_v; params.dv_rounded = head_size_v_rounded; // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); // params.tile_count_semaphore = tile_count_semaphore.data_ptr(); // Will be zero'ed out in the backward preprocess kernel at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32)); params.dq_semaphore = dq_semaphore.data_ptr(); at::Tensor dk_semaphore, dv_semaphore; if (num_heads_k != num_heads && params.deterministic) { // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); params.dk_semaphore = dk_semaphore.data_ptr(); params.dv_semaphore = dv_semaphore.data_ptr(); } #ifdef FLASHATTENTION_DISABLE_LOCAL TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); #endif #ifdef FLASHATTENTION_DISABLE_SOFTCAP TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); #endif if (total_q > 0 && total_k > 0 && num_heads_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_bwd(params, stream); } else if (total_k > 0 && num_heads_k > 0) { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. 
        dk.zero_();
        dv.zero_();
        softmax_d.zero_();
    } else if (total_q > 0 && num_heads_k > 0) {
        dq.zero_();
        softmax_d.zero_();
    }

    return { softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
}

std::tuple mha_combine(at::Tensor out_partial,         // num_splits x batch_size x seqlen x num_heads x head_size
                       at::Tensor lse_partial,         // num_splits x batch_size x seqlen x num_heads
                       std::optional out_,             // batch_size x seqlen x num_heads x head_size
                       std::optional out_dtype_
                       ) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major >= 8;
    TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer.");

    auto out_partial_type = out_partial.scalar_type();
    TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only supports fp32 data type");
    TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only supports fp32 data type");

    CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);

    TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");

    const auto sizes = out_partial.sizes();

    const int num_splits = sizes[0];
    const int batch_size = sizes[1];
    const int seqlen = sizes[2];
    const int num_heads = sizes[3];
    const int head_size_og = sizes[4];
    TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");

    CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
    CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);

    int const alignment = 4;
    at::Tensor out_partial_padded;
    auto pad = [](at::Tensor x, int alignment) {
        return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
    };
    out_partial_padded = pad(out_partial, alignment);

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, alignment);

    auto opts = out_partial.options();
    at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
    TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.scalar_type() == out_type);
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
        if (head_size_og % alignment != 0) {
            out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
        }
    } else {
        out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);

    Flash_fwd_params params {};  // Need to reset the params to set everything to zero
    params.is_fp32 = out_type == at::ScalarType::Float;
    params.is_bf16 = out_type == at::ScalarType::BFloat16;
    params.oaccum_ptr = out_partial_padded.data_ptr();
    params.softmax_lseaccum_ptr = lse_partial.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = softmax_lse.data_ptr();
    params.b = batch_size;
    params.h = num_heads;
    params.seqlen_q = seqlen;
    params.dv = head_size;
    params.num_splits = num_splits;
    params.oaccum_split_stride = out_partial_padded.stride(0);
    params.oaccum_row_stride = out_partial_padded.stride(2);
    params.oaccum_head_stride = out_partial_padded.stride(3);
    params.oaccum_batch_stride = out_partial_padded.stride(1);
    params.lseaccum_split_stride = lse_partial.stride(0);
    params.lseaccum_head_stride = lse_partial.stride(3);
    params.lseaccum_batch_stride = lse_partial.stride(1);
    params.o_row_stride = out.stride(1);
    params.o_head_stride = out.stride(2);
    params.o_batch_stride = out.stride(0);
    params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;

    if (seqlen > 0 && batch_size > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd_combine(params, stream, false /*enable_pdl*/);
    }

    at::Tensor out_padded = out;
    if (head_size_og % alignment != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        // if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, softmax_lse};
}

TORCH_LIBRARY(flash_attn_3, m) {
    m.def("fwd("
        "Tensor q,"
        "Tensor k,"
        "Tensor v,"
        "Tensor(k_new!)? k_new = None,"
        "Tensor(v_new!)? v_new = None,"
        "Tensor? q_v = None,"
        "Tensor(out!)? out = None,"
        "Tensor? cu_seqlens_q = None,"
        "Tensor? cu_seqlens_k = None,"
        "Tensor? cu_seqlens_k_new = None,"
        "Tensor? seqused_q = None,"
        "Tensor? seqused_k = None,"
        "int? max_seqlen_q = None,"
        "int? max_seqlen_k = None,"
        "Tensor? page_table = None,"
        "Tensor? kv_batch_idx = None,"
        "Tensor? leftpad_k = None,"
        "Tensor? rotary_cos = None,"
        "Tensor? rotary_sin = None,"
        "Tensor? seqlens_rotary = None,"
        "Tensor? q_descale = None,"
        "Tensor? k_descale = None,"
        "Tensor? v_descale = None,"
        "float? softmax_scale = None,"
        "bool is_causal = False,"
        "int window_size_left = -1,"
        "int window_size_right = -1,"
        "int attention_chunk = 0,"
        "float softcap = 0.0,"
        "bool is_rotary_interleaved = False,"
        "Tensor? scheduler_metadata = None,"
        "int num_splits = 0,"
        "bool? pack_gqa = None,"
        "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
    m.def("bwd("
        "Tensor dout,"
        "Tensor q,"
        "Tensor k,"
        "Tensor v,"
        "Tensor out,"
        "Tensor softmax_lse,"
        "Tensor(dq!)? dq = None,"
        "Tensor(dk!)? dk = None,"
        "Tensor(dv!)? dv = None,"
        "Tensor? cu_seqlens_q = None,"
        "Tensor? cu_seqlens_k = None,"
        "Tensor? seqused_q = None,"
        "Tensor? seqused_k = None,"
        "int? max_seqlen_q = None,"
        "int? max_seqlen_k = None,"
        "float? softmax_scale = None,"
        "bool is_causal = False,"
        "int window_size_left = -1,"
        "int window_size_right = -1,"
        "float softcap = 0.0,"
        "bool deterministic = False,"
        "int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
    m.def("fwd_combine("
        "Tensor out_partial,"
        "Tensor lse_partial,"
        "Tensor(out!)? out = None,"
        "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
    m.def("get_scheduler_metadata("
        "int batch_size,"
        "int max_seqlen_q,"
        "int max_seqlen_k,"
        "int num_heads,"
        "int num_heads_k,"
        "int headdim,"
        "int headdim_v,"
        "ScalarType qkv_dtype,"
        "Tensor seqused_k,"
        "Tensor? cu_seqlens_q = None,"
        "Tensor? cu_seqlens_k = None,"
        "Tensor? cu_seqlens_k_new = None,"
        "Tensor? seqused_q = None,"
        "Tensor? leftpad_k = None,"
        "int? page_size = None,"
        "int max_seqlen_k_new = 0,"
        "bool is_causal = False,"
        "int window_size_left = -1,"
        "int window_size_right = -1,"
        "int attention_chunk = 0,"
        "bool has_softcap = False,"
        "int num_splits = 0,"
        "bool? pack_gqa = None,"
        "int sm_margin = 0) -> Tensor");
}

TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {
    m.impl("fwd", &mha_fwd);
    m.impl("bwd", &mha_bwd);
    m.impl("fwd_combine", &mha_combine);
    m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
}

================================================
FILE: hopper/flash_api_stable.cpp
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#include
#include

#include "flash.h"
#include "static_switch.h"
#include "tile_size.h"
#include "heuristics.h"
#include "cuda_check.h"

#include
#include
#include
#include
#include

// Declare the CUDA stream function that's behind #ifdef USE_CUDA in shim.h
extern "C" AOTITorchError aoti_torch_get_current_cuda_stream(int32_t device_index, void** ret_stream);

#include
#include
#include
#include
#include
#include

using torch::stable::Tensor;
namespace tsa = torch::stable::accelerator;

namespace {
inline tsa::DeviceGuard make_device_guard(const Tensor& t) {
    return tsa::DeviceGuard(static_cast(t.get_device()));
}

std::deque device_flags;
std::vector device_properties;

void initVectors() {
    static bool init_flag [[maybe_unused]] = []() {
        int device_count;
        cudaError_t err = cudaGetDeviceCount(&device_count);
        if (err != cudaSuccess) {
            STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " + std::string(cudaGetErrorString(err)));
        }
        device_flags.resize(device_count);
        device_properties.resize(device_count);
        return true;
    }();
}

void initDeviceProperty(int device_index) {
    cudaDeviceProp device_prop{};
    cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index);
    if (err != cudaSuccess) {
        STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " + std::string(cudaGetErrorString(err)));
    }
    device_properties[device_index] = device_prop;
}

// Helper function to get device properties using raw CUDA APIs
cudaDeviceProp* get_device_prop() {
    initVectors();
    int device_index;
    cudaError_t err = cudaGetDevice(&device_index);
    if (err != cudaSuccess) {
        STD_TORCH_CHECK(false, "cudaGetDevice failed: " + std::string(cudaGetErrorString(err)));
    }

    std::call_once(device_flags[device_index], initDeviceProperty, device_index);
    return &device_properties[device_index];
}
} // anonymous namespace

extern "C" {
/* Creates a dummy empty _C module that can be imported from Python.
   The import from Python will load the .so consisting of this file
   in this extension, so that the STABLE_TORCH_LIBRARY static initializers
   below are run. */
PyObject* PyInit__C(void)
{
    static struct PyModuleDef module_def = {
        PyModuleDef_HEAD_INIT,
        "_C",   /* name of module */
        NULL,   /* module documentation, may be NULL */
        -1,     /* size of per-interpreter state of the module,
                   or -1 if the module keeps state in global variables. */
        NULL,   /* methods */
    };
    return PyModule_Create(&module_def);
}
}

#define CHECK_DEVICE(x) STD_TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) \
    do { \
        auto expected_dims = std::vector{__VA_ARGS__}; \
        STD_TORCH_CHECK(x.dim() == static_cast(expected_dims.size()), #x " must have " + std::to_string(expected_dims.size()) + " dimensions, got " + std::to_string(x.dim())); \
        for (size_t i = 0; i < expected_dims.size(); ++i) { \
            STD_TORCH_CHECK(x.size(i) == expected_dims[i], #x " dimension " + std::to_string(i) + " must have size " + std::to_string(expected_dims[i]) + ", got " + std::to_string(x.size(i))); \
        } \
    } while (0)
#define CHECK_CONTIGUOUS(x) STD_TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992

void set_params_fprop(Flash_fwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const Tensor q,
                      const Tensor k,
                      const Tensor v,
                      Tensor out,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *seqused_q,
                      void *seqused_k,
                      void *softmax_lse_d,
                      float p_dropout,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
                      int attention_chunk,
                      const float softcap=0.f,
                      const int sm_margin=0) {

    // Reset the parameters
    params = {};

    params.is_bf16 = q.scalar_type() == torch::headeronly::ScalarType::BFloat16;
    params.is_e4m3 = q.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn;

    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = k.data_ptr();
    params.v_ptr = v.data_ptr();
    // All strides are in elements, not bytes.
    params.q_row_stride = q.stride(-3);
    params.k_row_stride = k.stride(-3);
    params.v_row_stride = v.stride(-3);
    params.q_head_stride = q.stride(-2);
    params.k_head_stride = k.stride(-2);
    params.v_head_stride = v.stride(-2);
    params.v_dim_stride = v.stride(-1);
    params.o_ptr = out.data_ptr();
    params.o_row_stride = out.stride(-3);
    params.o_head_stride = out.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.q_batch_stride = q.stride(0);
        params.o_batch_stride = out.stride(0);
    }
    if (cu_seqlens_k_d == nullptr) {
        params.k_batch_stride = k.stride(0);
        params.v_batch_stride = v.stride(0);
    }

    params.cu_seqlens_q = static_cast(cu_seqlens_q_d);
    params.cu_seqlens_k = static_cast(cu_seqlens_k_d);
    params.seqused_q = static_cast(seqused_q);
    params.seqused_k = static_cast(seqused_k);

    // Softmax sum
    params.softmax_lse_ptr = softmax_lse_d;

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;
    params.d = d;
    params.d_rounded = d_rounded;

    // Set the different scale values.
    params.scale_softmax = softmax_scale;
    params.softcap = softcap;

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
    // Convert p from float to int so we don't have to convert the random uint to float to compare.
    // [Minor] We want to round down since when we do the comparison we use <= instead of <
    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
    params.rp_dropout = 1.f / params.p_dropout;
    STD_TORCH_CHECK(p_dropout < 1.f);
#ifdef FLASHATTENTION_DISABLE_DROPOUT
    STD_TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
#endif

    // Causal is the special case where window_size_right == 0 and window_size_left < 0.
    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
    params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
    params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;

    // TODO: check this
    if (window_size_left < 0) { window_size_left = seqlen_k - 1; }
    if (window_size_right < 0) { window_size_right = seqlen_q - 1; }
    if (attention_chunk > 0) {
        window_size_left = std::min(window_size_left, attention_chunk - 1);
        window_size_right = std::min(window_size_right, attention_chunk - 1);
    }
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;
    params.attention_chunk = attention_chunk;

    auto dprops = get_device_prop();
    params.arch = dprops->major * 10 + dprops->minor;
    params.num_sm = dprops->multiProcessorCount - sm_margin;

#ifdef FLASHATTENTION_DISABLE_LOCAL
    STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
#endif
}

void set_params_dgrad(Flash_bwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const Tensor q,
                      const Tensor k,
                      const Tensor v,
                      const Tensor out,
                      const Tensor dout,
                      Tensor dq,
                      Tensor dk,
                      Tensor dv,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *seqused_q,
                      void *seqused_k,
                      void *dq_accum_d,
                      void *dk_accum_d,
                      void *dv_accum_d,
                      void *softmax_lse_d,
                      void *dsoftmax_sum_d,
                      float p_dropout,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
                      int attention_chunk,
                      const float softcap=0.f,
                      bool deterministic=false,
                      int const sm_margin=0) {

    set_params_fprop(params,
                     b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k_d,
                     seqused_q,
                     seqused_k,
                     softmax_lse_d,
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     attention_chunk,
                     softcap,
                     sm_margin);

    // Set the pointers and strides.
    params.do_ptr = dout.data_ptr();
    params.do_row_stride = dout.stride(-3);
    params.do_head_stride = dout.stride(-2);
    params.dq_ptr = dq.data_ptr();
    params.dk_ptr = dk.data_ptr();
    params.dv_ptr = dv.data_ptr();
    params.dq_row_stride = dq.stride(-3);
    params.dk_row_stride = dk.stride(-3);
    params.dv_row_stride = dv.stride(-3);
    params.dq_head_stride = dq.stride(-2);
    params.dk_head_stride = dk.stride(-2);
    params.dv_head_stride = dv.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.do_batch_stride = dout.stride(0);
        params.dq_batch_stride = dq.stride(0);
        params.dk_batch_stride = dk.stride(0);
        params.dv_batch_stride = dv.stride(0);
    }

    params.dq_accum_ptr = dq_accum_d;
    params.dk_accum_ptr = dk_accum_d;
    params.dv_accum_ptr = dv_accum_d;

    // Softmax sum
    params.dsoftmax_sum = dsoftmax_sum_d;

    params.deterministic = deterministic;
}

template
void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {
    if (!params.is_e4m3) {
        if (params.is_bf16) {
#ifndef FLASHATTENTION_DISABLE_HDIM64
            if (params.d <= 64) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF64
                if constexpr (Arch == 90) {
                    if (params.dv > 256) {
                        return run_mha_fwd_(params, stream);
                    } else if (params.dv > 64) {
                        return run_mha_fwd_(params, stream);
                    }
                }
#endif
                return run_mha_fwd_(params, stream);
            }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM96
            if (params.d <= 96) { return run_mha_fwd_(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM128
            if (params.d <= 128) { return run_mha_fwd_(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
            if (params.d <= 192) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF192
                if constexpr (Arch == 90) {
                    if (params.dv <= 128) {
                        return run_mha_fwd_(params, stream);
                    }
                }
#endif
                return run_mha_fwd_(params, stream);
            }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM256
            if (params.d <= 256) { return run_mha_fwd_(params, stream); }
#endif
        } else {
#ifndef FLASHATTENTION_DISABLE_FP16
#ifndef FLASHATTENTION_DISABLE_HDIM64
            if (params.d <= 64) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF64
                if constexpr (Arch == 90) {
                    if (params.dv > 256) {
                        return run_mha_fwd_(params, stream);
                    } else if (params.dv > 64) {
                        return run_mha_fwd_(params, stream);
                    }
                }
#endif
                return run_mha_fwd_(params, stream);
            }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM96
            if (params.d <= 96) { return run_mha_fwd_(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM128
            if (params.d <= 128) { return run_mha_fwd_(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
            if (params.d <= 192) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF192
                if constexpr (Arch == 90) {
                    if (params.dv <= 128) {
                        return run_mha_fwd_(params, stream);
                    }
                }
#endif
                return run_mha_fwd_(params, stream);
            }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM256
            if (params.d <= 256) { return run_mha_fwd_(params, stream); }
#endif
#else
            STD_TORCH_CHECK(false, "This flash attention build does not support FP16.");
#endif
        }
    } else {
#ifndef FLASHATTENTION_DISABLE_FP8
#ifndef FLASHATTENTION_DISABLE_HDIM64
        if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM96
        if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM128
        if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
        if (params.d <= 192) {
#ifndef FLASHATTENTION_DISABLE_HDIMDIFF192
            if constexpr (Arch == 90) {
                if (params.dv <= 128) {
                    return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
                }
            }
#endif
            return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
        }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM256
        if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
#endif
#else
        STD_TORCH_CHECK(false, "This flash attention build does not support FP8.");
#endif
    }
}

void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    // HEADDIM_SWITCH(params.d, [&] {
    //     run_mha_fwd_(params, stream);
    // });
    STD_TORCH_CHECK(params.num_splits >= 1);
    ARCH_SWITCH(params.arch, Arch, [&] {
        SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
            PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] {
                PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
                    // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation
                    static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split;
                    SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
                        run_mha_fwd_constexpr(params, stream);
                    });
                });
            });
        });
    });
}

void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl=false) {
#ifndef FLASHATTENTION_DISABLE_SPLIT
    // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
    // so that kBlockM is smaller and we have more parallelism.
    if (params.is_fp32) {
        if (params.dv <= 64) {
            run_mha_fwd_combine_<float, float, 64>(params, stream, enable_pdl);
        } else {
            run_mha_fwd_combine_<float, float, 128>(params, stream, enable_pdl);
        }
    } else if (params.is_bf16) {
        if (params.dv <= 64) {
            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);
        } else {
            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);
        }
    } else {
        if (params.dv <= 64) {
            run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream, enable_pdl);
        } else {
            run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream, enable_pdl);
        }
    }
    #else
    STD_TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
    #endif
}

inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
    if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }
    // This needs to match the kernel configs
    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f);
    int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
    int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);
    // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,
    // at least for MLA.
    return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;
}

inline bool get_pack_gqa(Flash_fwd_params const& params) {
    // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.
    // Has little effect on speed.
    if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
    #ifdef FLASHATTENTION_DISABLE_PACKGQA
    return false;
    #else
    // params.page_table must already be set
    if (params.h == params.h_k) { return false; }
    // This needs to match the kernel configs
    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
    int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
    return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
    #endif
}

inline int get_num_splits(Flash_fwd_params const& params) {
    #ifdef FLASHATTENTION_DISABLE_SPLIT
    return 1;
    #else
    // Always enable PackGQA for Split
    // params.page_table must already be set
    // This needs to match the kernel configs
    bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
    // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
    // has not been set here. It's OK though because we might just underestimate kBlockN a bit
    auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
    int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
    int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
    int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
    // If is_local, we're not going to load all of seqlen_k
    int const seqlen_k_loaded = !params.is_local
        ? params.seqlen_k
        : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
    int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
    int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
    int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);
    // Always enable PackGQA for Split
    // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.
    // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending
    // that batch = 1.
    int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks;
    return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);
    #endif
}

inline int get_max_headdim() {
    #ifndef FLASHATTENTION_DISABLE_HDIM256
    return 256;
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM192
    return 192;
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM128
    return 128;
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM96
    return 96;
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM64
    return 64;
    #endif
    return 0;
}

inline int round_up_headdim(int head_size) {
    #ifndef FLASHATTENTION_DISABLE_HDIM64
    if (head_size <= 64) { return 64; }
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM96
    if (head_size <= 96) { return 96; }
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM128
    if (head_size <= 128) { return 128; }
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM192
    if (head_size <= 192) { return 192; }
    #endif
    #ifndef FLASHATTENTION_DISABLE_HDIM256
    if (head_size <= 256) { return 256; }
    #endif
    return 256;
}

inline int round_up_headdimv(int head_size) {
    if (head_size <= 64) { return 64; }
    if (head_size <= 96) { return 96; }
    if (head_size <= 128) { return 128; }
    if (head_size <= 192) { return 192; }
    if (head_size <= 256) { return 256; }
    return 512;
}

// Only applicable to the case where seqused_k (i.e.
// cache_seqlens) is available
Tensor mha_fwd_get_scheduler_metadata(
        int64_t batch_size,
        int64_t max_seqlen_q,
        int64_t max_seqlen_k,
        int64_t num_heads,
        int64_t num_heads_k,
        int64_t headdim,
        int64_t headdim_v,
        torch::headeronly::ScalarType qkv_dtype,
        Tensor seqused_k, // b
        std::optional<Tensor> cu_seqlens_q_,  // b+1
        std::optional<Tensor> cu_seqlens_k_,  // b+1
        std::optional<Tensor> cu_seqlens_k_new_,  // b+1
        std::optional<Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
        std::optional<Tensor> leftpad_k_, // b
        std::optional<int64_t> page_size,
        int64_t max_seqlen_k_new,  // 0 means we're not appending new KV
        bool is_causal,
        int64_t window_size_left,
        int64_t window_size_right,
        int64_t attention_chunk,
        bool has_softcap,
        int64_t num_splits,
        std::optional<bool> pack_gqa_,
        int64_t sm_margin) {

    STD_TORCH_CHECK(qkv_dtype == torch::headeronly::ScalarType::Half || qkv_dtype == torch::headeronly::ScalarType::BFloat16 || qkv_dtype == torch::headeronly::ScalarType::Float8_e4m3fn,
                    "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
    STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // Reset the parameters
    Flash_fwd_params params{};
    params.is_bf16 = qkv_dtype == torch::headeronly::ScalarType::BFloat16;
    params.is_e4m3 = qkv_dtype == torch::headeronly::ScalarType::Float8_e4m3fn;
    params.b = batch_size;
    params.seqlen_q = max_seqlen_q;
    params.seqlen_k = max_seqlen_k;
    params.h = num_heads;
    params.h_k = num_heads_k;
    params.d = headdim;
    params.dv = headdim_v;
    params.d_rounded = round_up_headdim(headdim);
    params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v);
    params.seqlen_knew = max_seqlen_k_new;

    bool const is_varlen_q = cu_seqlens_q_.has_value();
    params.cu_seqlens_q = is_varlen_q ? static_cast<int*>(cu_seqlens_q_.value().data_ptr()) : nullptr;
    bool const is_varlen_k = cu_seqlens_k_.has_value();
    params.cu_seqlens_k = is_varlen_k ? static_cast<int*>(cu_seqlens_k_.value().data_ptr()) : nullptr;
    params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? static_cast<int*>(cu_seqlens_k_new_.value().data_ptr()) : nullptr;
    params.seqused_q = seqused_q_.has_value() ? static_cast<int*>(seqused_q_.value().data_ptr()) : nullptr;
    params.seqused_k = static_cast<int*>(seqused_k.data_ptr());
    params.leftpad_k = leftpad_k_.has_value() ? static_cast<int*>(leftpad_k_.value().data_ptr()) : nullptr;
    params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;
    if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
    // causal=true is the same as causal=false in this case
    if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {
        // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
        if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {
            is_causal = false;
        }
    }
    if (is_causal) {
        window_size_right = 0;
    }

    params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
    params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;
    if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; }
    if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; }
    if (attention_chunk > 0) {
        window_size_left = std::min(window_size_left, attention_chunk - 1);
        window_size_right = std::min(window_size_right, attention_chunk - 1);
    }
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;
    params.attention_chunk = attention_chunk;
    auto dprops = get_device_prop();
    params.arch = dprops->major * 10 + dprops->minor;
    params.num_sm = dprops->multiProcessorCount - sm_margin;
    params.softcap = has_softcap ? 1.0f : 0.0f;

    params.page_size = page_size.has_value() ? page_size.value() : 1;
    params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);

    bool const use_prepare_varlen = true;
    params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA;
    params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast<int*>(1);

    params.pagedkv_tma = get_pagedkv_tma(params);
    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
    params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
    bool is_varlen = true;

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    auto device_guard = make_device_guard(seqused_k);

    // This needs to be set after get_num_splits
    Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic
    bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    params.varlen_sort_batches = !params.is_local;  // Use this value for Sort in scheduler template
    params.head_swizzle = params.is_causal || params.is_local;  // Use this value for LPT in scheduler template
    if (scheduler_needs_semaphore || use_prepare_varlen) {
        int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers
        int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0;
        if (params.varlen_sort_batches) { num_prepare_batch_vectors += 1; }
        if (params.head_swizzle) { num_prepare_batch_vectors += 1; }
        int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2);
        int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;
        // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors);
        tile_count_semaphore = torch::stable::new_empty(
            seqused_k,
            {int(scheduler_needs_semaphore) + tile_count_semaphore_offset},
            std::make_optional(torch::headeronly::ScalarType::Int));
        // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2}
        params.num_splits_dynamic_ptr = use_prepare_varlen ? static_cast<int*>(tile_count_semaphore.data_ptr()) : nullptr;
        params.num_m_blocks_ptr = use_prepare_varlen ? static_cast<int*>(tile_count_semaphore.data_ptr()) + b_rounded : nullptr;
        params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? static_cast<int*>(tile_count_semaphore.data_ptr()) + b_rounded * 2 : nullptr;
        // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? static_cast<int*>(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr;
        params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? static_cast<int*>(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr;
        if (scheduler_needs_semaphore) {
            if (!use_prepare_varlen) { torch::stable::zero_(tile_count_semaphore); }  // If varlen we'll manually do the zero-ing
            params.tile_count_semaphore = static_cast<int*>(tile_count_semaphore.data_ptr()) + tile_count_semaphore_offset;
        } else {
            params.tile_count_semaphore = nullptr;
        }
    }

    if (use_prepare_varlen) {
        auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
        auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
        int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
        int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
        auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex();
        void* stream_ptr = nullptr;
        TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr));
        cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr);
        prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
        CHECK_CUDA_KERNEL_LAUNCH();
    }
    return tile_count_semaphore;
}

// b: batch_size
// b_k: batch_size_k
// s_q: seqlen_q
// s_k: seqlen_k
// s_k_new: seqlen_k_new
// h: num_heads
// h_k: num_heads_k
// d: head_size
std::tuple<Tensor, Tensor, Tensor, Tensor>
mha_fwd(Tensor q,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
        Tensor k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
        Tensor v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
        std::optional<Tensor> k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
        std::optional<Tensor> v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
        std::optional<Tensor> q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
        std::optional<Tensor> out_,  // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
        std::optional<Tensor> cu_seqlens_q_,  // b+1
        std::optional<Tensor> cu_seqlens_k_,  // b+1
        std::optional<Tensor> cu_seqlens_k_new_,  // b+1
        std::optional<Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
        std::optional<Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
        std::optional<int64_t> max_seqlen_q_,
        // TODO: check if we need max_seqlen_k
        std::optional<int64_t> max_seqlen_k_,
        std::optional<Tensor> page_table_, // (b_k, max_num_pages_per_seq)
        std::optional<Tensor> kv_batch_idx_, // b. indices to index into the KV cache
        std::optional<Tensor> leftpad_k_, // b
        std::optional<Tensor> rotary_cos_, // seqlen_ro x (rotary_dim / 2)
        std::optional<Tensor> rotary_sin_, // seqlen_ro x (rotary_dim / 2)
        std::optional<Tensor> seqlens_rotary_, // b
        std::optional<Tensor> q_descale_,  // (b, h_k), not (b, h)
        std::optional<Tensor> k_descale_,  // (b, h_k)
        std::optional<Tensor> v_descale_,  // (b, h_k)
        std::optional<double> softmax_scale_,
        bool is_causal,
        int64_t window_size_left,
        int64_t window_size_right,
        int64_t attention_chunk,
        double softcap,
        bool is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
        std::optional<Tensor> scheduler_metadata_,  // (b + 1)
        int64_t num_splits,
        std::optional<bool> pack_gqa_,
        int64_t sm_margin) {

    auto dprops = get_device_prop();
    bool is_sm8x = dprops->major >= 8;
    STD_TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_type = q.scalar_type();
    STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16 || q_type == torch::headeronly::ScalarType::Float8_e4m3fn,
                    "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
    if (dprops->major < 9) {
        STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16,
                        "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
    }
    STD_TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
    STD_TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

    STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    Tensor page_table;
    const bool paged_KV = page_table_.has_value();
    if (paged_KV) {
        page_table = page_table_.value();
        CHECK_DEVICE(page_table);
        STD_TORCH_CHECK(page_table.scalar_type() == torch::headeronly::ScalarType::Int, "page_table must have dtype torch.int32");
        STD_TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
    }

    Tensor cu_seqlens_q;
    bool const is_varlen_q = cu_seqlens_q_.has_value();
    if (is_varlen_q) {
        cu_seqlens_q = cu_seqlens_q_.value();
        CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
        STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_q must have dtype torch.int32");
        STD_TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
    }
    Tensor cu_seqlens_k;
    bool const is_varlen_k = cu_seqlens_k_.has_value();
    if (is_varlen_k) {
        cu_seqlens_k = cu_seqlens_k_.value();
        CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
        STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k must have dtype torch.int32");
        STD_TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
        STD_TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
        STD_TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then kv_batch_idx is not supported");
    }

    const int batch_size = !is_varlen_q ? q.size(0) : cu_seqlens_q.size(0) - 1;
    int seqlen_q = !is_varlen_q ? q.size(1) : max_seqlen_q_.value();
    int total_q = !is_varlen_q ? batch_size * q.size(1) : q.size(0);
    int num_heads = q.size(-2);
    int const head_size = q.size(-1);
    int const head_size_v = v.size(-1);
    int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
    int const num_pages = !paged_KV ? 0 : k.size(0);
    int const page_size = !paged_KV ? 1 : k.size(1);
    int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
    int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
    int const num_heads_k = k.size(-2);
    int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
    double softmax_scale = 1.0 / sqrt(double(head_size));
    if (softmax_scale_.has_value()) {
        softmax_scale = softmax_scale_.value();
    }
    if (!kv_batch_idx_.has_value()) {
        STD_TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
    }
    int const max_headdim = get_max_headdim();
    STD_TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
    STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
    if (head_size_v != head_size) {
        STD_TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||
                        (head_size <= 64 && head_size_v <= 512),
                        "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], "
                        "or (Q/K <= 64 and V <= 512).");
        STD_TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
        if (head_size_v > 256) {
            STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16,
                            "HeaddimV > 256 requires fp16 and bf16 data type");
        }
    }

    // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
    // TODO: check this
    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
    // causal=true is the same as causal=false in this case
    if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {
        // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
        if ((head_size <= 64 || head_size > 128) || !paged_KV) {
            is_causal = false;
        }
    }
    if (is_causal) { window_size_right = 0; }

    if (!is_varlen_q) {
        CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    } else {
        CHECK_SHAPE(q, total_q, num_heads, head_size);
        CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    }
    if (!paged_KV) {
        if (!is_varlen_k) {
            CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
            CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);
        } else {
            CHECK_SHAPE(k, total_k, num_heads_k, head_size);
            CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
            CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
        }
    } else {
        CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
        CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);
        CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
    }

    if (seqused_q_.has_value()) {
        auto seqused_q = seqused_q_.value();
        STD_TORCH_CHECK(seqused_q.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_q must have dtype int32");
        CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
        CHECK_SHAPE(seqused_q, batch_size);
    }
    if (seqused_k_.has_value()) {
        auto seqused_k = seqused_k_.value();
        STD_TORCH_CHECK(seqused_k.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_k must have dtype int32");
        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
        CHECK_SHAPE(seqused_k, batch_size);
    }

    if (leftpad_k_.has_value()) {
        auto leftpad_k = leftpad_k_.value();
        STD_TORCH_CHECK(leftpad_k.scalar_type() == torch::headeronly::ScalarType::Int, "leftpad_k must have dtype int32");
        CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
        CHECK_SHAPE(leftpad_k, batch_size);
    }

    // This is what we will template on
    bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
    #ifdef FLASHATTENTION_DISABLE_VARLEN
    STD_TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
    #endif

    int const alignment = q_type == torch::headeronly::ScalarType::Float8_e4m3fn ? 16 : 8;
    STD_TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
    STD_TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));

    auto out_type = q_type == torch::headeronly::ScalarType::Float8_e4m3fn ? torch::headeronly::ScalarType::BFloat16 : q_type;
    Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        STD_TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
        CHECK_DEVICE(out);
        STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        if (!is_varlen_q) {
            CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
        } else {
            CHECK_SHAPE(out, total_q, num_heads, head_size_v);
        }
    } else {
        out = !is_varlen_q
            ? torch::stable::new_empty(q, {batch_size, seqlen_q, num_heads, head_size_v}, std::make_optional(out_type))
            : torch::stable::new_empty(q, {total_q, num_heads, head_size_v}, std::make_optional(out_type));
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    int const head_size_rounded = round_up_headdim(head_size);
    int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v);
    int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
    int const seqlen_k_rounded = round_multiple(seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    auto device_guard = make_device_guard(q);

    Tensor softmax_lse;
    if (!is_varlen_q) {
        softmax_lse = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q}, std::make_optional(torch::headeronly::ScalarType::Float));
    } else {
        softmax_lse = torch::stable::new_empty(q, {num_heads, total_q}, std::make_optional(torch::headeronly::ScalarType::Float));
    }

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
                     !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     attention_chunk,
                     softcap,
                     sm_margin);
    params.total_q = total_q;
    params.total_k = total_k;
    params.b_k = batch_size_k;
    params.dv = head_size_v;
    params.dv_rounded = head_size_v_rounded;
    if (leftpad_k_.has_value()) {  // This needs to be set before get_pagedkv_tma
        params.leftpad_k = static_cast<int*>(leftpad_k_.value().data_ptr());
    }
    if (paged_KV) {
        params.page_table = static_cast<int*>(page_table.data_ptr());
        params.page_table_batch_stride = page_table.stride(0);
    }
    params.page_size = page_size;
    params.num_pages = num_pages;

    if (k_new_.has_value()) {  // This needs to be set before get_pagedkv_tma
        Tensor k_new, v_new;
        STD_TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
        STD_TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in");
        STD_TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache");
        Tensor cu_seqlens_k_new;
        bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
        if (is_varlen_k_new) {
            cu_seqlens_k_new = cu_seqlens_k_new_.value();
            CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
            STD_TORCH_CHECK(cu_seqlens_k_new.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k_new must have dtype torch.int32");
        }
        k_new = k_new_.value();
        v_new = v_new_.value();
        STD_TORCH_CHECK(k_new.scalar_type() == q_type, "k_new must have the same dtype as query");
        STD_TORCH_CHECK(v_new.scalar_type() == q_type, "v_new must have the same dtype as query");
        CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
        STD_TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
        STD_TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
        // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new
        int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
        int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1) : k_new.size(0);
        if (!is_varlen_k_new) {
            CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
            CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);
        } else {
            CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
            CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);
            CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
        }
        params.seqlen_knew = seqlen_k_new;
        params.total_knew = total_k_new;
        params.knew_ptr = k_new.data_ptr();
        params.vnew_ptr = v_new.data_ptr();
        // All strides are in elements, not bytes.
        params.knew_row_stride = k_new.stride(-3);
        params.vnew_row_stride = v_new.stride(-3);
        params.knew_head_stride = k_new.stride(-2);
        params.vnew_head_stride = v_new.stride(-2);
        if (!is_varlen_k_new) {
            params.knew_batch_stride = k_new.stride(0);
            params.vnew_batch_stride = v_new.stride(0);
        }
        if (is_varlen_k_new) {
            params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
        }
    }

    bool const use_prepare_varlen = is_varlen;
    params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA;
    // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
    params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast<int*>(1);

    params.pagedkv_tma = get_pagedkv_tma(params);
    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
    // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
    params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);

    // This needs to be set after get_num_splits
    Tensor tile_count_semaphore;  // Contains the semaphore and optionally num_splits_dynamic
    // We don't use the persistent scheduler if Split and not Varlen
    bool const scheduler_needs_semaphore = params.arch >= 90
        ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
        : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
    params.varlen_sort_batches = !params.is_local;  // Use this value for Sort in scheduler template
    params.head_swizzle = params.is_causal || params.is_local;  // Use this value for LPT in scheduler template
    if (scheduler_needs_semaphore || use_prepare_varlen) {
        int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers
        int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0;
        if (params.varlen_sort_batches) { num_prepare_batch_vectors += 1; }
        if (params.head_swizzle) { num_prepare_batch_vectors += 1; }
        int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2);
        int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors;
        int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset;
        // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size);
        params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
        if (scheduler_metadata_.has_value()) {
            Tensor scheduler_metadata = scheduler_metadata_.value();
            CHECK_DEVICE(scheduler_metadata);
            CHECK_SHAPE(scheduler_metadata, metadata_size);
            CHECK_CONTIGUOUS(scheduler_metadata);
            STD_TORCH_CHECK(scheduler_metadata.scalar_type() == torch::headeronly::ScalarType::Int, "scheduler_metadata must have dtype int32");
            tile_count_semaphore = scheduler_metadata;
        } else {
            tile_count_semaphore = torch::stable::new_empty(q, {metadata_size}, torch::headeronly::ScalarType::Int);
        }
        if (scheduler_needs_semaphore && !use_prepare_varlen) {
            torch::stable::zero_(tile_count_semaphore);  // If varlen we'll manually do the zero-ing
        }
        // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2}
        params.num_splits_dynamic_ptr = use_prepare_varlen ? static_cast<int*>(tile_count_semaphore.data_ptr()) : nullptr;
        params.num_m_blocks_ptr = use_prepare_varlen ? static_cast<int*>(tile_count_semaphore.data_ptr()) + b_rounded : nullptr;
        params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? static_cast<int*>(tile_count_semaphore.data_ptr()) + b_rounded * 2 : nullptr;
        // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? static_cast<int*>(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr;
        params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? static_cast<int*>(tile_count_semaphore.data_ptr()) + head_swizzle_offset : nullptr;
        params.tile_count_semaphore = scheduler_needs_semaphore ? static_cast<int*>(tile_count_semaphore.data_ptr()) + tile_count_semaphore_offset : nullptr;
        params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later
    }

    if (q_v_.has_value()) {
        STD_TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
        STD_TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256.");
        STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16,
                        "q_v is only supported for fp16 and bf16 data type");
        STD_TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
        Tensor q_v = q_v_.value();
        STD_TORCH_CHECK(q_v.scalar_type() == q_type, "q_v must have the same dtype as query");
        CHECK_DEVICE(q_v);
        STD_TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension");
        if (!is_varlen_q) {
            CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
        } else {
            CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
        }
        params.qv_ptr = q_v.data_ptr();
        // All strides are in elements, not bytes.
        params.qv_row_stride = q_v.stride(-3);
        params.qv_head_stride = q_v.stride(-2);
        if (!is_varlen_q) {
            params.qv_batch_stride = q_v.stride(0);
        }
    }

    if (rotary_cos_.has_value()) {
        STD_TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
        auto rotary_cos = rotary_cos_.value();
        CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
        params.rotary_dim = rotary_cos.size(1) * 2;
        STD_TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
        STD_TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
        const int seqlen_ro = rotary_cos.size(0);
        if (paged_KV) {
            STD_TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
        }
        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
        STD_TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");

        STD_TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
        auto rotary_sin = rotary_sin_.value();
        CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
        STD_TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_sin must have the same dtype as query");
        params.rotary_cos_ptr = rotary_cos.data_ptr();
        params.rotary_sin_ptr = rotary_sin.data_ptr();
        params.is_rotary_interleaved = is_rotary_interleaved;
        if (seqlens_rotary_.has_value()) {
            Tensor seqlens_rotary = seqlens_rotary_.value();
            CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);
            STD_TORCH_CHECK(seqlens_rotary.scalar_type() == torch::headeronly::ScalarType::Int, "seqlens_rotary must have dtype torch.int32");
            CHECK_SHAPE(seqlens_rotary, batch_size);
            params.seqlens_rotary = static_cast<int*>(seqlens_rotary.data_ptr());
        }
    } else {
        params.rotary_dim = 0;
    }

    if (kv_batch_idx_.has_value()) {
        auto kv_batch_idx = kv_batch_idx_.value();
        CHECK_DEVICE(kv_batch_idx);
CHECK_CONTIGUOUS(kv_batch_idx); STD_TORCH_CHECK(kv_batch_idx.scalar_type() == torch::headeronly::ScalarType::Int, "kv_batch_idx must have dtype int32"); params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); } Tensor out_accum, softmax_lse_accum; auto outaccum_type = torch::headeronly::ScalarType::Float; if (params.num_splits > 1) { STD_TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); if (!is_varlen_q) { out_accum = torch::stable::new_empty(q, {params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, std::make_optional(outaccum_type)); softmax_lse_accum = torch::stable::new_empty(q, {params.num_splits, batch_size, num_heads, seqlen_q}, std::make_optional(torch::headeronly::ScalarType::Float)); params.oaccum_batch_stride = out_accum.stride(1); params.lseaccum_batch_stride = softmax_lse_accum.stride(1); } else { out_accum = torch::stable::new_empty(q, {params.num_splits, num_heads, total_q, head_size_v}, std::make_optional(outaccum_type)); softmax_lse_accum = torch::stable::new_empty(q, {params.num_splits, num_heads, total_q}, std::make_optional(torch::headeronly::ScalarType::Float)); } params.is_fp32 = false; params.oaccum_ptr = out_accum.data_ptr(); params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); params.oaccum_split_stride = out_accum.stride(0); params.oaccum_row_stride = out_accum.stride(-2); params.oaccum_head_stride = out_accum.stride(-3); params.lseaccum_split_stride = softmax_lse_accum.stride(0); params.lseaccum_head_stride = softmax_lse_accum.stride(-2); } if (q_type == torch::headeronly::ScalarType::Float8_e4m3fn) { if (q_descale_.has_value()) { auto q_descale = q_descale_.value(); CHECK_DEVICE(q_descale); CHECK_SHAPE(q_descale, batch_size, num_heads_k); params.q_descale_ptr = static_cast(q_descale.data_ptr()); params.q_descale_batch_stride = q_descale.stride(0); params.q_descale_head_stride = q_descale.stride(1); } else { params.q_descale_ptr = nullptr; } if (k_descale_.has_value()) { auto k_descale = 
k_descale_.value(); CHECK_DEVICE(k_descale); CHECK_SHAPE(k_descale, batch_size, num_heads_k); params.k_descale_ptr = static_cast<float*>(k_descale.data_ptr()); params.k_descale_batch_stride = k_descale.stride(0); params.k_descale_head_stride = k_descale.stride(1); } else { params.k_descale_ptr = nullptr; } if (v_descale_.has_value()) { auto v_descale = v_descale_.value(); CHECK_DEVICE(v_descale); CHECK_SHAPE(v_descale, batch_size, num_heads_k); params.v_descale_ptr = static_cast<float*>(v_descale.data_ptr()); params.v_descale_batch_stride = v_descale.stride(0); params.v_descale_head_stride = v_descale.stride(1); } else { params.v_descale_ptr = nullptr; } } #ifdef FLASHATTENTION_DISABLE_LOCAL STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); #endif #ifdef FLASHATTENTION_DISABLE_SOFTCAP STD_TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); #endif #ifdef FLASHATTENTION_DISABLE_SPLIT STD_TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits."); #endif #ifdef FLASHATTENTION_DISABLE_PACKGQA STD_TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa."); #endif #ifdef FLASHATTENTION_DISABLE_PAGEDKV STD_TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV."); #endif #ifdef FLASHATTENTION_DISABLE_APPENDKV STD_TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV."); #endif if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) { auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); void* stream_ptr = nullptr; TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr); run_mha_fwd(params, stream); if (params.num_splits > 1) { if (out_type == 
torch::headeronly::ScalarType::BFloat16) { // Since we want output in BF16. Otherwise fwd_combine will output to FP16 params.is_bf16 = true; } // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1 // and seqlen = total_q, and don't need to dispatch to Varlen there. // However, with dynamic split, each row needs to know which batch it belongs to // to read the number of splits, so we just use the varlen version of combine kernel. // if (is_varlen_q && !seqused_q_.has_value()) { // if (is_varlen_q) { // params.b = 1; // params.seqlen_q = total_q; // } // This will zero out the semaphore if needed run_mha_fwd_combine(params, stream, true /*enable_pdl*/); } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { // need to zero out the semaphore in this case auto slice = torch::stable::narrow(tile_count_semaphore, 0, params.tile_count_semaphore_offset, 1); torch::stable::zero_(slice); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
torch::stable::zero_(out); torch::stable::fill_(softmax_lse, std::numeric_limits<float>::infinity()); } // return {out, softmax_lse}; return {out, softmax_lse, out_accum, softmax_lse_accum}; } #ifdef FLASHATTENTION_DISABLE_BACKWARD void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) { STD_TORCH_CHECK(false, "Flash-Attention was built with backward disabled"); } #else template <int Arch, bool Has_softcap> void run_mha_bwd_constexpr(Flash_bwd_params &params, cudaStream_t stream) { if (!params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); } #endif #else STD_TORCH_CHECK(false, "This flash attention build does not support FP16."); #endif } else { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); } #endif } } void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) { // FP16_SWITCH(!params.is_bf16, [&] { // HEADDIM_SWITCH(params.d, [&] { // run_mha_bwd_(params, stream); // }); // }); ARCH_SWITCH(params.arch, Arch, [&] { 
SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] { run_mha_bwd_constexpr<Arch, Has_softcap>(params, stream); }); }); } #endif // b: batch_size // s_q: seqlen_q // s_k: seqlen_k // h: num_heads // h_k: num_heads_k // d: head_size std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> mha_bwd( Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q std::optional<Tensor> dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q std::optional<Tensor> dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k std::optional<Tensor> dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k std::optional<Tensor> cu_seqlens_q_, // b+1 std::optional<Tensor> cu_seqlens_k_, // b+1 std::optional<Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. std::optional<Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used. 
std::optional<int64_t> max_seqlen_q_, std::optional<int64_t> max_seqlen_k_, std::optional<double> softmax_scale_, bool is_causal, int64_t window_size_left, int64_t window_size_right, double softcap, bool deterministic, int64_t sm_margin ) { #ifdef FLASHATTENTION_DISABLE_BACKWARD STD_TORCH_CHECK(false, "This flash attention build does not support backward."); #endif auto dprops = get_device_prop(); bool is_sm8x = dprops->major >= 8; STD_TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); auto q_type = q.scalar_type(); STD_TORCH_CHECK(q_type == torch::headeronly::ScalarType::Half || q_type == torch::headeronly::ScalarType::BFloat16, "FlashAttention only supports fp16 and bf16 data types"); STD_TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); STD_TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); STD_TORCH_CHECK(out.scalar_type() == q_type, "query and out must have the same dtype"); STD_TORCH_CHECK(dout.scalar_type() == q_type, "query and dout must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); STD_TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); STD_TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); STD_TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); STD_TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); STD_TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); Tensor cu_seqlens_q; bool const is_varlen_q = cu_seqlens_q_.has_value(); if (is_varlen_q) { cu_seqlens_q = cu_seqlens_q_.value(); CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q); STD_TORCH_CHECK(cu_seqlens_q.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_q must have dtype torch.int32"); STD_TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if 
cu_seqlens_q is provided"); } Tensor cu_seqlens_k; bool const is_varlen_k = cu_seqlens_k_.has_value(); if (is_varlen_k) { cu_seqlens_k = cu_seqlens_k_.value(); CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k); STD_TORCH_CHECK(cu_seqlens_k.scalar_type() == torch::headeronly::ScalarType::Int, "cu_seqlens_k must have dtype torch.int32"); STD_TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided"); } // This is what we will template on bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value(); #ifdef FLASHATTENTION_DISABLE_VARLEN STD_TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen."); #endif // auto const sizes = q.sizes(); int const batch_size = !is_varlen_q ? q.size(0) : cu_seqlens_q.size(0) - 1; int const seqlen_q = !is_varlen_q ? q.size(1) : max_seqlen_q_.value(); int const total_q = !is_varlen_q ? batch_size * q.size(1) : q.size(0); int const num_heads = q.size(-2); int const head_size = q.size(-1); int const head_size_v = v.size(-1); int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value(); int const total_k = !is_varlen_k ? 
batch_size * k.size(1) : k.size(0); int const num_heads_k = k.size(-2); STD_TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); STD_TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8"); int const max_headdim = get_max_headdim(); STD_TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention backward only supports head dimension at most " + std::to_string(max_headdim)); STD_TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); double softmax_scale = 1.0 / sqrt(double(head_size)); if (softmax_scale_.has_value()) { softmax_scale = softmax_scale_.value(); } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM if (window_size_left >= seqlen_k - 1) { window_size_left = -1; } if (window_size_right >= seqlen_q - 1) { window_size_right = -1; } if (is_causal) { window_size_right = 0; } // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_dgrad will set params.is_causal=true. // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA). is_causal = window_size_left < 0 && window_size_right == 0; int const arch = dprops->major * 10 + dprops->minor; int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v)); int const head_size_v_rounded = head_size_rounded; STD_TORCH_CHECK(!deterministic || head_size_rounded < 256, "Deterministic backward not supported for hdim 256."); // Very important that these match the kernel configs bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128) : (head_size_rounded <= 96 ? 64 : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80) : 64)); int const kBlockM_sm80 = head_size_rounded <= 64 ? 
128 : 64; int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32; int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80); int const kBlockN_sm90 = head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 96 : 80); int const kBlockN_sm80 = head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 80 : 64); int const kBlockN_sm86 = head_size_rounded <= 64 ? 128 : (head_size_rounded <= 96 ? 128 : (head_size_rounded <= 128 ? 96 : (head_size_rounded <= 192 ? 64 : 64))); int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM); int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN); int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM); int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN); if (!is_varlen_q) { CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v); } else { CHECK_SHAPE(q, total_q, num_heads, head_size); CHECK_SHAPE(out, total_q, num_heads, head_size_v); CHECK_SHAPE(dout, total_q, num_heads, head_size_v); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); } if (!is_varlen_k) { CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } if (seqused_q_.has_value()){ auto seqused_q = seqused_q_.value(); STD_TORCH_CHECK(seqused_q.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_q must have dtype int32"); CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q); CHECK_SHAPE(seqused_q, 
batch_size); } if (seqused_k_.has_value()){ auto seqused_k = seqused_k_.value(); STD_TORCH_CHECK(seqused_k.scalar_type() == torch::headeronly::ScalarType::Int, "seqused_k must have dtype int32"); CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k); CHECK_SHAPE(seqused_k, batch_size); } Tensor dq, dk, dv; if (dq_.has_value()) { dq = dq_.value(); STD_TORCH_CHECK(dq.scalar_type() == q_type, "dq must have the same dtype as q"); CHECK_DEVICE(dq); STD_TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); if (!is_varlen_q) { CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); } else { CHECK_SHAPE(dq, total_q, num_heads, head_size); } } else { dq = torch::stable::empty_like(q); } if (dk_.has_value()) { dk = dk_.value(); STD_TORCH_CHECK(dk.scalar_type() == q_type, "dk must have the same dtype as q"); CHECK_DEVICE(dk); STD_TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); if (!is_varlen_k) { CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); } else { CHECK_SHAPE(dk, total_k, num_heads_k, head_size); } } else { dk = torch::stable::empty_like(k); } if (dv_.has_value()) { dv = dv_.value(); STD_TORCH_CHECK(dv.scalar_type() == q_type, "dv must have the same dtype as q"); CHECK_DEVICE(dv); STD_TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); if (!is_varlen_k) { CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v); } } else { dv = torch::stable::empty_like(v); } // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing auto device_guard = make_device_guard(q); // auto opts = q.options(); // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64 Tensor softmax_d, softmax_lse_log2; if (!is_varlen) { // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for 
TMA / LDG.64 softmax_d = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); softmax_lse_log2 = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); } else { softmax_d = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); softmax_lse_log2 = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); } Tensor dq_accum, dk_accum, dv_accum; if (!is_varlen) { dq_accum = torch::stable::new_empty(q, {batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); } else { dq_accum = torch::stable::new_empty(q, {num_heads, total_q_padded_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); } if (num_heads_k != num_heads) { // MQA / GQA if (!is_varlen) { dk_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); dk_accum = torch::stable::fill_(dk_accum, 0.0); dv_accum = torch::stable::new_empty(q, {batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); dv_accum = torch::stable::fill_(dv_accum, 0.0); } else { dk_accum = torch::stable::new_empty(q, {num_heads_k, total_k_padded_rounded, head_size_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); dk_accum = torch::stable::fill_(dk_accum, 0.0); dv_accum = torch::stable::new_empty(q, {num_heads_k, total_k_padded_rounded, head_size_v_rounded}, std::make_optional(torch::headeronly::ScalarType::Float)); dv_accum = torch::stable::fill_(dv_accum, 0.0); } } Flash_bwd_params params; set_params_dgrad(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded, 
seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, q, k, v, out, dout, dq, dk, dv, !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(), !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(), seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr, seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr, dq_accum.data_ptr(), num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr, num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr, softmax_lse.data_ptr(), softmax_d.data_ptr(), /*p_dropout=*/0.f, softmax_scale, window_size_left, window_size_right, 0, // attention_chunk softcap, deterministic, sm_margin); params.total_q = total_q; params.total_k = total_k; params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr(); params.dv = head_size_v; params.dv_rounded = head_size_v_rounded; // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::headeronly::ScalarType::Int)) : torch::empty({1}, opts.dtype(torch::headeronly::ScalarType::Int)); // params.tile_count_semaphore = static_cast<int*>(tile_count_semaphore.data_ptr()); // Will be zeroed out in the backward preprocess kernel Tensor dq_semaphore = torch::stable::new_empty(q, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, std::make_optional(torch::headeronly::ScalarType::Int)); params.dq_semaphore = static_cast<int*>(dq_semaphore.data_ptr()); Tensor dk_semaphore, dv_semaphore; if (num_heads_k != num_heads && params.deterministic) { // TODO: maybe also zero out dk_semaphore and dv_semaphore in the backward preprocess kernel dk_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); dv_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); params.dk_semaphore = static_cast<int*>(dk_semaphore.data_ptr()); params.dv_semaphore = 
static_cast<int*>(dv_semaphore.data_ptr()); } #ifdef FLASHATTENTION_DISABLE_LOCAL STD_TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention."); #endif #ifdef FLASHATTENTION_DISABLE_SOFTCAP STD_TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping."); #endif if (total_q > 0 && total_k > 0 && num_heads_k > 0) { auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); void* stream_ptr = nullptr; TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr); run_mha_bwd(params, stream); } else if (total_k > 0 && num_heads_k > 0) { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. torch::stable::zero_(dk); torch::stable::zero_(dv); torch::stable::zero_(softmax_d); } else if (total_q > 0 && num_heads_k > 0) { torch::stable::zero_(dq); torch::stable::zero_(softmax_d); } return { softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum }; } std::tuple<Tensor, Tensor> mha_combine(Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads std::optional<Tensor> out_, // batch_size x seqlen x num_heads x head_size std::optional<torch::headeronly::ScalarType> out_dtype_ ) { auto dprops = get_device_prop(); bool is_sm8x = dprops->major >= 8; STD_TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer."); auto out_partial_type = out_partial.scalar_type(); STD_TORCH_CHECK(out_partial_type == torch::headeronly::ScalarType::Float, "Attention combine function only supports fp32 data type"); STD_TORCH_CHECK(lse_partial.scalar_type() == torch::headeronly::ScalarType::Float, "Attention combine function only supports fp32 data type"); CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial); STD_TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension"); STD_TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor 
must be contiguous in the seqlen dimension"); // const auto sizes = out_partial.sizes(); const int num_splits = out_partial.size(0); const int batch_size = out_partial.size(1); const int seqlen = out_partial.size(2); const int num_heads = out_partial.size(3); const int head_size_og = out_partial.size(4); STD_TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads); int const alignment = 4; Tensor out_partial_padded; auto pad = [](Tensor x, int alignment) { return x.size(-1) % alignment == 0 ? x : torch::stable::pad(x, {0, alignment - x.size(-1) % alignment}); }; out_partial_padded = pad(out_partial, alignment); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size = round_multiple(head_size_og, alignment); // auto opts = out_partial.options(); torch::headeronly::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type()); STD_TORCH_CHECK(out_type == torch::headeronly::ScalarType::Float || out_type == torch::headeronly::ScalarType::BFloat16 || out_type == torch::headeronly::ScalarType::Half, "Output type must be FP32, FP16 or BF16"); Tensor out; if (out_.has_value()) { out = out_.value(); STD_TORCH_CHECK(out.scalar_type() == out_type); CHECK_DEVICE(out); STD_TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og); if (head_size_og % alignment != 0) { out = torch::stable::new_empty(out_partial, {batch_size, seqlen, num_heads, head_size}, std::make_optional(out_type)); } } else { out = torch::stable::new_empty(out_partial, {batch_size, seqlen, num_heads, head_size}, std::make_optional(out_type)); } // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing auto device_guard = 
make_device_guard(out_partial); auto softmax_lse = torch::stable::new_empty(out_partial, {batch_size, num_heads, seqlen}, std::make_optional(torch::headeronly::ScalarType::Float)); softmax_lse = torch::stable::transpose(softmax_lse, 1, 2); Flash_fwd_params params {}; // Need to reset the params to set everything to zero params.is_fp32 = out_type == torch::headeronly::ScalarType::Float; params.is_bf16 = out_type == torch::headeronly::ScalarType::BFloat16; params.oaccum_ptr = out_partial_padded.data_ptr(); params.softmax_lseaccum_ptr = lse_partial.data_ptr(); params.o_ptr = out.data_ptr(); params.softmax_lse_ptr = softmax_lse.data_ptr(); params.b = batch_size; params.h = num_heads; params.seqlen_q = seqlen; params.dv = head_size; params.num_splits = num_splits; params.oaccum_split_stride = out_partial_padded.stride(0); params.oaccum_row_stride = out_partial_padded.stride(2); params.oaccum_head_stride = out_partial_padded.stride(3); params.oaccum_batch_stride = out_partial_padded.stride(1); params.lseaccum_split_stride = lse_partial.stride(0); params.lseaccum_head_stride = lse_partial.stride(3); params.lseaccum_batch_stride = lse_partial.stride(1); params.o_row_stride = out.stride(1); params.o_head_stride = out.stride(2); params.o_batch_stride = out.stride(0); params.arch = dprops->major * 10 + dprops->minor; if (seqlen > 0 && batch_size > 0) { auto device_idx = torch::stable::accelerator::getCurrentDeviceIndex(); void* stream_ptr = nullptr; TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_idx, &stream_ptr)); cudaStream_t stream = static_cast<cudaStream_t>(stream_ptr); run_mha_fwd_combine(params, stream, false /*enable_pdl*/); } Tensor out_padded = out; if (head_size_og % alignment != 0) { out = torch::stable::narrow(out, -1, 0, head_size_og); // if (out_.has_value()) { out_.value().copy_(out); } } return {out, softmax_lse}; } void boxed_mha_fwd( StableIValue* stack, uint64_t num_args, uint64_t num_outputs ) { auto q = to<Tensor>(stack[0]); auto k = to<Tensor>(stack[1]); auto v = 
to<Tensor>(stack[2]); auto k_new = to<std::optional<Tensor>>(stack[3]); auto v_new = to<std::optional<Tensor>>(stack[4]); auto q_v = to<std::optional<Tensor>>(stack[5]); auto out = to<std::optional<Tensor>>(stack[6]); auto cu_seqlens_q = to<std::optional<Tensor>>(stack[7]); auto cu_seqlens_k = to<std::optional<Tensor>>(stack[8]); auto cu_seqlens_k_new = to<std::optional<Tensor>>(stack[9]); auto seqused_q = to<std::optional<Tensor>>(stack[10]); auto seqused_k = to<std::optional<Tensor>>(stack[11]); auto max_seqlen_q = to<std::optional<int64_t>>(stack[12]); auto max_seqlen_k = to<std::optional<int64_t>>(stack[13]); auto page_table = to<std::optional<Tensor>>(stack[14]); auto kv_batch_idx = to<std::optional<Tensor>>(stack[15]); auto leftpad_k = to<std::optional<Tensor>>(stack[16]); auto rotary_cos = to<std::optional<Tensor>>(stack[17]); auto rotary_sin = to<std::optional<Tensor>>(stack[18]); auto seqlens_rotary = to<std::optional<Tensor>>(stack[19]); auto q_descale = to<std::optional<Tensor>>(stack[20]); auto k_descale = to<std::optional<Tensor>>(stack[21]); auto v_descale = to<std::optional<Tensor>>(stack[22]); auto softmax_scale = to<std::optional<double>>(stack[23]); auto is_causal = to<bool>(stack[24]); auto window_size_left = to<int64_t>(stack[25]); auto window_size_right = to<int64_t>(stack[26]); auto attention_chunk = to<int64_t>(stack[27]); auto softcap = to<double>(stack[28]); auto is_rotary_interleaved = to<bool>(stack[29]); auto scheduler_metadata = to<std::optional<Tensor>>(stack[30]); auto num_splits = to<int64_t>(stack[31]); auto pack_gqa = to<std::optional<bool>>(stack[32]); auto sm_margin = to<int64_t>(stack[33]); auto [out_, softmax_lse, out_accum, softmax_lse_accum] = mha_fwd(q, k, v, k_new, v_new, q_v, out, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale, softmax_scale, is_causal, window_size_left, window_size_right, attention_chunk, softcap, is_rotary_interleaved, scheduler_metadata, num_splits, pack_gqa, sm_margin); stack[0] = from(out_); stack[1] = from(softmax_lse); stack[2] = from(out_accum); stack[3] = from(softmax_lse_accum); } void boxed_mha_bwd( StableIValue* stack, uint64_t num_args, uint64_t num_outputs ) { auto dout = to<Tensor>(stack[0]); auto q = to<Tensor>(stack[1]); auto k = to<Tensor>(stack[2]); auto v = to<Tensor>(stack[3]); auto out = to<Tensor>(stack[4]); auto softmax_lse = to<Tensor>(stack[5]); auto dq = to<std::optional<Tensor>>(stack[6]); auto dk = to<std::optional<Tensor>>(stack[7]); auto dv = to<std::optional<Tensor>>(stack[8]); auto cu_seqlens_q = 
to<std::optional<Tensor>>(stack[9]); auto cu_seqlens_k = to<std::optional<Tensor>>(stack[10]); auto seqused_q = to<std::optional<Tensor>>(stack[11]); auto seqused_k = to<std::optional<Tensor>>(stack[12]); auto max_seqlen_q = to<std::optional<int64_t>>(stack[13]); auto max_seqlen_k = to<std::optional<int64_t>>(stack[14]); auto softmax_scale = to<std::optional<double>>(stack[15]); auto is_causal = to<bool>(stack[16]); auto window_size_left = to<int64_t>(stack[17]); auto window_size_right = to<int64_t>(stack[18]); auto softcap = to<double>(stack[19]); auto deterministic = to<bool>(stack[20]); auto sm_margin = to<int64_t>(stack[21]); auto [softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum] = mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin); stack[0] = from(softmax_d); stack[1] = from(softmax_lse_log2); stack[2] = from(dq_accum); stack[3] = from(dk_accum); stack[4] = from(dv_accum); } void boxed_mha_combine( StableIValue* stack, uint64_t num_args, uint64_t num_outputs ) { auto out_partial = to<Tensor>(stack[0]); auto lse_partial = to<Tensor>(stack[1]); auto out = to<std::optional<Tensor>>(stack[2]); auto out_dtype = to<std::optional<torch::headeronly::ScalarType>>(stack[3]); auto [out_, softmax_lse] = mha_combine(out_partial, lse_partial, out, out_dtype); stack[0] = from(out_); stack[1] = from(softmax_lse); } void boxed_mha_fwd_get_scheduler_metadata( StableIValue* stack, uint64_t num_args, uint64_t num_outputs ) { auto batch_size = to<int64_t>(stack[0]); auto max_seqlen_q = to<int64_t>(stack[1]); auto max_seqlen_k = to<int64_t>(stack[2]); auto num_heads = to<int64_t>(stack[3]); auto num_heads_k = to<int64_t>(stack[4]); auto headdim = to<int64_t>(stack[5]); auto headdim_v = to<int64_t>(stack[6]); auto qkv_dtype = to<torch::headeronly::ScalarType>(stack[7]); auto seqused_k = to<Tensor>(stack[8]); auto cu_seqlens_q = to<std::optional<Tensor>>(stack[9]); auto cu_seqlens_k = to<std::optional<Tensor>>(stack[10]); auto cu_seqlens_k_new = to<std::optional<Tensor>>(stack[11]); auto seqused_q = to<std::optional<Tensor>>(stack[12]); auto leftpad_k = to<std::optional<Tensor>>(stack[13]); auto page_size = to<std::optional<int64_t>>(stack[14]); auto max_seqlen_k_new = to<int64_t>(stack[15]); auto is_causal = to<bool>(stack[16]); auto window_size_left = to<int64_t>(stack[17]); auto window_size_right = to<int64_t>(stack[18]); auto 
attention_chunk = to<int64_t>(stack[19]); auto has_softcap = to<bool>(stack[20]); auto num_splits = to<int64_t>(stack[21]); auto pack_gqa = to<std::optional<bool>>(stack[22]); auto sm_margin = to<int64_t>(stack[23]); auto scheduler_metadata = mha_fwd_get_scheduler_metadata(batch_size, max_seqlen_q, max_seqlen_k, num_heads, num_heads_k, headdim, headdim_v, qkv_dtype, seqused_k, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, leftpad_k, page_size, max_seqlen_k_new, is_causal, window_size_left, window_size_right, attention_chunk, has_softcap, num_splits, pack_gqa, sm_margin); stack[0] = from(scheduler_metadata); } STABLE_TORCH_LIBRARY(flash_attn_3, m) { m.def("fwd(" "Tensor q," "Tensor k," "Tensor v," "Tensor(k_new!)? k_new = None," "Tensor(v_new!)? v_new = None," "Tensor? q_v = None," "Tensor(out!)? out = None," "Tensor? cu_seqlens_q = None," "Tensor? cu_seqlens_k = None," "Tensor? cu_seqlens_k_new = None," "Tensor? seqused_q = None," "Tensor? seqused_k = None," "int? max_seqlen_q = None," "int? max_seqlen_k = None," "Tensor? page_table = None," "Tensor? kv_batch_idx = None," "Tensor? leftpad_k = None," "Tensor? rotary_cos = None," "Tensor? rotary_sin = None," "Tensor? seqlens_rotary = None," "Tensor? q_descale = None," "Tensor? k_descale = None," "Tensor? v_descale = None," "float? softmax_scale = None," "bool is_causal = False," "int window_size_left = -1," "int window_size_right = -1," "int attention_chunk = 0," "float softcap = 0.0," "bool is_rotary_interleaved = False," "Tensor? scheduler_metadata = None," "int num_splits = 0," "bool? pack_gqa = None," "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)"); m.def("bwd(" "Tensor dout," "Tensor q," "Tensor k," "Tensor v," "Tensor out," "Tensor softmax_lse," "Tensor(dq!)? dq = None," "Tensor(dk!)? dk = None," "Tensor(dv!)? dv = None," "Tensor? cu_seqlens_q = None," "Tensor? cu_seqlens_k = None," "Tensor? seqused_q = None," "Tensor? seqused_k = None," "int? max_seqlen_q = None," "int? max_seqlen_k = None," "float? 
softmax_scale = None," "bool is_causal = False," "int window_size_left = -1," "int window_size_right = -1," "float softcap = 0.0," "bool deterministic = False," "int sm_margin = 0) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); m.def("fwd_combine(" "Tensor out_partial," "Tensor lse_partial," "Tensor(out!)? out = None," "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)"); m.def("get_scheduler_metadata(" "int batch_size," "int max_seqlen_q," "int max_seqlen_k," "int num_heads," "int num_heads_k," "int headdim," "int headdim_v," "ScalarType qkv_dtype," "Tensor seqused_k," "Tensor? cu_seqlens_q = None," "Tensor? cu_seqlens_k = None," "Tensor? cu_seqlens_k_new = None," "Tensor? seqused_q = None," "Tensor? leftpad_k = None," "int? page_size = None," "int max_seqlen_k_new = 0," "bool is_causal = False," "int window_size_left = -1," "int window_size_right = -1," "int attention_chunk = 0," "bool has_softcap = False," "int num_splits = 0," "bool? pack_gqa = None," "int sm_margin = 0) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { m.impl("fwd", &boxed_mha_fwd); m.impl("bwd", &boxed_mha_bwd); m.impl("fwd_combine", &boxed_mha_combine); m.impl("get_scheduler_metadata", &boxed_mha_fwd_get_scheduler_metadata); } ================================================ FILE: hopper/flash_attn_interface.py ================================================ # Copyright (c) 2023, Tri Dao. 
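# Hypothetical reference helper, for illustration only. It sketches the math
# behind the split-KV path of the C++ extension. When num_splits > 1, mha_fwd
# writes per-split partial outputs (out_accum) and per-split log-sum-exp values
# (softmax_lse_accum), and the fwd_combine op merges them. This pure-Python
# version rests on three assumptions: the LSE values use the natural log, only
# one (batch, row, head) position is handled, and at least one split saw a key.
# The real CUDA kernel works on fp32 tensors, and its numerics may differ.
import math
from typing import List, Sequence, Tuple


def _reference_combine_splits(
    out_partial: Sequence[Sequence[float]],  # [num_splits][head_size]
    lse_partial: Sequence[float],            # [num_splits]
) -> Tuple[List[float], float]:
    """Merge per-split attention outputs using their log-sum-exp weights."""
    lse_max = max(lse_partial)
    if lse_max == -math.inf:
        # Assumption of this sketch: an all-empty row is treated as an error.
        raise ValueError("all splits are empty (every LSE is -inf)")
    # Numerically stable log(sum_i exp(lse_i)).
    lse = lse_max + math.log(sum(math.exp(l - lse_max) for l in lse_partial))
    # Each partial output is already normalized by its own softmax
    # denominator. Re-weight it by exp(lse_i - lse), its share of the total
    # softmax mass. Empty splits (lse_i = -inf) get weight 0.
    out = [0.0] * len(out_partial[0])
    for o_i, lse_i in zip(out_partial, lse_partial):
        weight = math.exp(lse_i - lse)
        for d, x in enumerate(o_i):
            out[d] += weight * x
    return out, lse

# Example: split one softmax row's keys into two halves, run attention on each
# half, and combine the halves with _reference_combine_splits. The result
# reproduces the unsplit output and LSE up to floating-point rounding.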
from typing import Optional, Union, List, Tuple import os import torch import torch.nn as nn import warnings USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" if not USE_TRITON_ROCM and getattr(torch.version, 'hip', None) is not None: try: import flash_attn_3._C except ImportError: warnings.warn("flash_attn_3._C (which has ROCm/HIP kernels) not found, falling back to Triton implementation") USE_TRITON_ROCM = True if USE_TRITON_ROCM: from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_3 as flash_attn_3_gpu else: # isort: off # We need to import the CUDA kernels after importing torch import flash_attn_3._C # Registers operators with PyTorch # isort: on flash_attn_3_gpu = torch.ops.flash_attn_3 def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x def round_multiple(x, m): return (x + m - 1) // m * m def round_up_headdim(head_size: int) -> int: from flash_attn_config import CONFIG if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]: if head_size <= 64: return 64 if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]: if head_size <= 96: return 96 if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]: if head_size <= 128: return 128 if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]: if head_size <= 192: return 192 if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]: if head_size <= 256: return 256 return 256 @torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") def _flash_attn_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, k_new: Optional[torch.Tensor] = None, v_new: Optional[torch.Tensor] = None, qv: Optional[torch.Tensor] = None, out_: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, cu_seqlens_k_new: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: 
Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, kv_batch_idx: Optional[torch.Tensor] = None, leftpad_k: Optional[torch.Tensor] = None, rotary_cos: Optional[torch.Tensor] = None, rotary_sin: Optional[torch.Tensor] = None, seqlens_rotary: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size_left: int = -1, window_size_right: int = -1, attention_chunk: int = 0, softcap: float = 0.0, rotary_interleaved: bool = True, scheduler_metadata: Optional[torch.Tensor] = None, num_splits: int = 1, pack_gqa: Optional[bool] = None, sm_margin: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new) ] seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)] page_table, kv_batch_idx, leftpad_k = [ maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k) ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] seqlens_rotary = maybe_contiguous(seqlens_rotary) out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_gpu.fwd( q, k, v, k_new, v_new, qv, out_, cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale, softmax_scale, causal, window_size_left, window_size_right, attention_chunk, softcap, rotary_interleaved, scheduler_metadata, num_splits, pack_gqa, sm_margin, ) if out_accum is None: out_accum = torch.tensor([], 
device=out.device) if softmax_lse_accum is None: softmax_lse_accum = torch.tensor([], device=out.device) return out, softmax_lse, out_accum, softmax_lse_accum @torch.library.register_fake("flash_attn_3::_flash_attn_forward") def _flash_attn_forward_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, k_new: Optional[torch.Tensor] = None, v_new: Optional[torch.Tensor] = None, qv: Optional[torch.Tensor] = None, out_: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, cu_seqlens_k_new: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, kv_batch_idx: Optional[torch.Tensor] = None, leftpad_k: Optional[torch.Tensor] = None, rotary_cos: Optional[torch.Tensor] = None, rotary_sin: Optional[torch.Tensor] = None, seqlens_rotary: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size_left: int = -1, window_size_right: int = -1, attention_chunk: int = 0, softcap: float = 0.0, rotary_interleaved: bool = True, scheduler_metadata: Optional[torch.Tensor] = None, num_splits: int = 1, pack_gqa: Optional[bool] = None, sm_margin: int = 0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Symbolic fake implementation of flash attention forward. Returns tensors with the correct shapes and dtypes without actual computation. 
""" # Determine if we're in varlen mode is_varlen_q = cu_seqlens_q is not None # Get dimensions from query tensor if is_varlen_q: # varlen mode: q is (total_q, num_heads, head_size) total_q, num_heads, head_size = q.shape batch_size = cu_seqlens_q.shape[0] - 1 if max_seqlen_q is None: raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided") seqlen_q = max_seqlen_q else: # batch mode: q is (batch_size, seqlen_q, num_heads, head_size) batch_size, seqlen_q, num_heads, head_size = q.shape total_q = batch_size * q.shape[1] # Get value head dimension head_size_v = v.shape[-1] # Determine output dtype (FP8 inputs produce BF16 outputs) q_type = q.dtype if q_type == torch.float8_e4m3fn: out_dtype = torch.bfloat16 else: out_dtype = q_type # Create output tensor if out_ is not None: # If out_ is provided, _flash_attn_forward becomes non-functional raise TypeError("Tracing (torch.compile/torch.export) with pre-allocated output tensor is not supported.") if is_varlen_q: out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) else: out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device) # Create softmax_lse tensor if is_varlen_q: softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device) else: softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) # TODO(guilhermeleobas): Implement "get_num_splits" # There's an heuristic to compute num_splits when "num_splits <= 0" # assert that num_splits is > 0 for now if num_splits <= 0: raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. 
Got {num_splits=}") if num_splits > 1: if is_varlen_q: out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device) softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device) else: out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device) softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device) else: # Accumulator tensors are not allocated when num_splits == 1 (num_splits <= 0 was rejected above) out_accum = torch.tensor([], device=out.device) softmax_lse_accum = torch.tensor([], device=out.device) return out, softmax_lse, out_accum, softmax_lse_accum @torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") def _flash_attn_backward( dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor, softmax_lse: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, sequed_q: Optional[torch.Tensor] = None, sequed_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, dq: Optional[torch.Tensor] = None, dk: Optional[torch.Tensor] = None, dv: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, is_causal: bool = False, window_size_left: int = -1, window_size_right: int = -1, softcap: float = 0.0, deterministic: bool = False, sm_margin: int = 0, ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] softmax_d, *rest = flash_attn_3_gpu.bwd( dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, sequed_q, sequed_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal, window_size_left, window_size_right, softcap, deterministic, sm_margin, ) return softmax_d
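When `num_splits > 1`, the fake forward above allocates float32 `out_accum` of shape `(num_splits, batch_size, num_heads, seqlen_q, head_size_v)` and `softmax_lse_accum` of shape `(num_splits, batch_size, num_heads, seqlen_q)`. `flash_attn_combine` (the `fwd_combine` op) then merges partial results of this kind. The sketch below shows the standard log-sum-exp merge for a single query row in pure Python. It assumes that each partial output is already normalized by its own split's softmax denominator. `attend_split` and `combine_splits` are hypothetical names, and the sketch illustrates the math only. It is not the CUDA kernel.

```python
import math


def attend_split(scores, values):
    """Softmax-attend one query over a chunk of keys.

    scores: already-scaled QK^T logits for the chunk; values: the chunk's V rows.
    Returns the chunk-normalized output and the chunk's log-sum-exp.
    """
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    headdim = len(values[0])
    out = [sum(e * v[d] for e, v in zip(exps, values)) / denom for d in range(headdim)]
    return out, m + math.log(denom)


def combine_splits(out_partial, lse_partial):
    """Merge per-split outputs using their log-sum-exp values."""
    headdim = len(out_partial[0])
    m = max(lse_partial)
    if m == -math.inf:
        # Every split was fully masked: zero output, as for an all-masked row.
        return [0.0] * headdim, -math.inf
    # Global log-sum-exp, stabilized by subtracting the largest split LSE.
    lse = m + math.log(sum(math.exp(l - m) for l in lse_partial))
    # Split s contributes with weight exp(lse_s - lse), its share of the global denominator.
    weights = [math.exp(l - lse) for l in lse_partial]
    out = [sum(w * o[d] for w, o in zip(weights, out_partial)) for d in range(headdim)]
    return out, lse


# Usage: splitting the keys in two and merging should match attending over all keys.
scores = [0.5, -1.0, 2.0, 0.3]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
full_out, full_lse = attend_split(scores, values)
parts = [attend_split(scores[:2], values[:2]), attend_split(scores[2:], values[2:])]
merged_out, merged_lse = combine_splits([p[0] for p in parts], [p[1] for p in parts])
```

The merge needs only one float per split and row (the LSE) in addition to the partial outputs. This is consistent with the fake forward allocating `softmax_lse_accum` with one fewer dimension than `out_accum`.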
@torch.library.register_fake("flash_attn_3::_flash_attn_backward") def _flash_attn_backward_fake( dout: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor, softmax_lse: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, sequed_q: Optional[torch.Tensor] = None, sequed_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, dq: Optional[torch.Tensor] = None, dk: Optional[torch.Tensor] = None, dv: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, is_causal: bool = False, window_size_left: int = -1, window_size_right: int = -1, softcap: float = 0.0, deterministic: bool = False, sm_margin: int = 0, ) -> torch.Tensor: is_varlen_q = cu_seqlens_q is not None is_varlen_k = cu_seqlens_k is not None is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None if not is_varlen_q: batch_size = q.size(0) seqlen_q = q.size(1) seqlen_k = k.size(1) total_q = batch_size * q.size(1) else: batch_size = cu_seqlens_q.size(0) - 1 total_q = q.size(0) seqlen_q = max_seqlen_q seqlen_k = max_seqlen_k if window_size_left >= seqlen_k - 1: window_size_left = -1 if window_size_right >= seqlen_q - 1: window_size_right = -1 if is_causal: window_size_right = 0 is_causal = window_size_left < 0 and window_size_right == 0 head_size = q.size(-1) head_size_v = v.size(-1) head_size_rounded = round_up_headdim(max(head_size, head_size_v)) # Hopper GPUs use CUDA compute capability 9.0 cap = torch.cuda.get_device_capability(q.device) arch = cap[0] * 10 + cap[1] is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal if head_size_rounded <= 64: kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128 elif head_size_rounded <= 96: kBlockM_sm90 = 64 elif head_size_rounded <= 128: kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80 else: kBlockM_sm90 = 64 kBlockM_sm80 = 128 if
head_size_rounded <= 64 else 64 kBlockM_sm86 = 64 if head_size_rounded <= 192 else 32 if arch >= 90: kBlockM = kBlockM_sm90 elif arch == 86 or arch == 89: kBlockM = kBlockM_sm86 else: kBlockM = kBlockM_sm80 num_heads = q.shape[-2] seqlen_q_rounded = round_multiple(seqlen_q, kBlockM) total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM) dq = torch.empty_like(q) if dq is None else dq dk = torch.empty_like(k) if dk is None else dk dv = torch.empty_like(v) if dv is None else dv if not is_varlen: softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device) else: softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device) return softmax_d def setup_context(ctx, inputs, output): q, k, v = inputs[:3] out, softmax_lse, _, _ = output ctx.save_for_backward(q, k, v, out, softmax_lse) ctx.softmax_scale = inputs[-11] ctx.causal = inputs[-10] ctx.window_size = [inputs[-9], inputs[-8]] ctx.attention_chunk = inputs[-7] ctx.softcap = inputs[-6] ctx.sm_margin = inputs[-1] def _backward(ctx, dout, *grads): q, k, v, out, softmax_lse = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_backward( dout, q, k, v, out, softmax_lse, None, None, # cu_seqlens_q, cu_seqlens_k, None, None, # sequed_q, sequed_k, None, None, # max_seqlen_q, max_seqlen_k, dq, dk, dv, ctx.softmax_scale, ctx.causal, ctx.window_size[0], ctx.window_size[1], ctx.softcap, False, # deterministic ctx.sm_margin, ) return dq, dk, dv, *((None,) * 21) _flash_attn_forward.register_autograd(_backward, setup_context=setup_context) class FlashAttnQKVPackedFunc(torch.autograd.Function): @staticmethod def forward( ctx, qkv, softmax_scale, causal, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), attention_chunk=0, softcap=0.0, deterministic=False, num_heads_q=None, sm_margin=0, return_softmax=False, ): if softmax_scale is None: softmax_scale = 
qkv.shape[-1] ** (-0.5) if qkv.dim() == 5: assert qkv.shape[-3] == 3 q, k, v = qkv.unbind(dim=-3) else: assert qkv.dim() == 4 assert num_heads_q is not None num_heads_k = (qkv.shape[2] - num_heads_q) // 2 assert num_heads_k * 2 + num_heads_q == qkv.shape[2] q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2) out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new None, # qv None, # out None, None, None, # cu_seqlens_q/k/k_new None, None, # seqused_q/k None, None, # max_seqlen_q/k None, None, None, # page_table, kv_batch_idx, leftpad_k, None, None, None, # rotary_cos/sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, sm_margin=sm_margin, ) # ctx.save_for_backward(q, k, v, out_padded, softmax_lse) ctx.save_for_backward(q, k, v, out, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.attention_chunk = attention_chunk ctx.softcap = softcap ctx.deterministic = deterministic ctx.ndim = qkv.dim() ctx.sm_margin = sm_margin return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse = ctx.saved_tensors assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" if ctx.ndim == 5: qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) dq, dk, dv = dqkv.unbind(dim=-3) else: num_heads_q = q.shape[2] num_heads_k = k.shape[2] qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2) _flash_attn_backward( dout, q, k, v, out, softmax_lse, None, None, # cu_seqlens_q, cu_seqlens_k, None, None, # sequed_q, sequed_k, None, None, # max_seqlen_q, max_seqlen_k, dq, 
dk, dv, ctx.softmax_scale, ctx.causal, ctx.window_size[0], ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension return dqkv, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnFunc(torch.autograd.Function): @staticmethod def forward( ctx, q, k, v, softmax_scale, causal, qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, return_softmax=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward( out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new qv, # qv None, # out None, None, None, # cu_seqlens_q/k/k_new None, None, # seqused_q/k None, None, # max_seqlen_q/k None, None, None, # page_table, kv_batch_idx, leftpad_k, None, None, None, # rotary_cos/sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, sm_margin=sm_margin, ) # ctx.save_for_backward(q, k, v, out_padded, softmax_lse) ctx.save_for_backward(q, k, v, out, softmax_lse) ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.attention_chunk = attention_chunk ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse = ctx.saved_tensors assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_backward( dout, q, k, v, out, softmax_lse, None, 
None, # cu_seqlens_q, cu_seqlens_k, None, None, # sequed_q, sequed_k, None, None, # max_seqlen_q, max_seqlen_k, dq, dk, dv, ctx.softmax_scale, ctx.causal, ctx.window_size[0], ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, ) dq = dq[..., : q.shape[-1]] # We could have padded the head dimension dk = dk[..., : k.shape[-1]] dv = dv[..., : v.shape[-1]] return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @staticmethod def forward( ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, causal, qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, return_softmax=False, ): if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward( out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new qv, # qv None, # out cu_seqlens_q, cu_seqlens_k, None, # cu_seqlens_k_new seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, None, None, None, # page_table, kv_batch_idx, leftpad_k, None, None, None, # rotary_cos/sin, seqlens_rotary q_descale, k_descale, v_descale, softmax_scale, causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, sm_margin=sm_margin, ) # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size ctx.attention_chunk = attention_chunk 
ctx.softcap = softcap ctx.deterministic = deterministic ctx.sm_margin = sm_margin return (out, softmax_lse) if return_softmax else out @staticmethod def backward(ctx, dout, *args): q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk" dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_backward( dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, ctx.max_seqlen_q, ctx.max_seqlen_k, dq, dk, dv, ctx.softmax_scale, ctx.causal, ctx.window_size[0], ctx.window_size[1], ctx.softcap, ctx.deterministic, ctx.sm_margin, ) dq = dq[..., : q.shape[-1]] # We could have padded the head dimension dk = dk[..., : k.shape[-1]] dv = dv[..., : v.shape[-1]] return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( qkv, softmax_scale=None, causal=False, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), attention_chunk=0, softcap=0.0, deterministic=False, num_heads_q=None, sm_margin=0, return_attn_probs=False, ): """If Q, K, V are already stacked into one tensor, this function will be faster than calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation of the gradients of Q, K, V. For multi-query and grouped-query attention (MQA/GQA), pass qkv as (batch_size, seqlen, nheads + 2 * nheads_k, headdim) together with num_heads_q, or use flash_attn_func. If window_size != (-1, -1), implements sliding window local attention. Query at position i will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. Arguments: qkv: (batch_size, seqlen, 3, nheads, headdim), or (batch_size, seqlen, nheads + 2 * nheads_k, headdim) when num_heads_q is given. softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim). causal: bool.
Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. softcap: float. Anything > 0 activates softcapping attention. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to also return softmax_lse. Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). """ return FlashAttnQKVPackedFunc.apply( qkv, softmax_scale, causal, q_descale, k_descale, v_descale, window_size, attention_chunk, softcap, deterministic, num_heads_q, sm_margin, return_attn_probs, ) def flash_attn_func( q, k, v, softmax_scale=None, causal=False, qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, return_attn_probs=False, ): """Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 1 1 1 1 0 1 1 1 1 1 If seqlen_q = 5 and seqlen_k = 2, the causal mask is: 0 0 0 0 0 0 1 0 1 1 If a row of the mask is all zero, the output for that row will be zero. If window_size != (-1, -1), implements sliding window local attention. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Arguments: q: (batch_size, seqlen, nheads, headdim) k: (batch_size, seqlen, nheads_k, headdim) v: (batch_size, seqlen, nheads_k, headdim) softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. deterministic: bool. Whether to use the deterministic implementation of the backward pass, which is slightly slower and uses more memory. The forward pass is always deterministic. return_attn_probs: bool. Whether to also return softmax_lse. Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
""" return FlashAttnFunc.apply( q, k, v, softmax_scale, causal, qv, q_descale, k_descale, v_descale, window_size, attention_chunk, softcap, num_splits, pack_gqa, deterministic, sm_margin, return_attn_probs, ) def flash_attn_varlen_func( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, seqused_q=None, seqused_k=None, softmax_scale=None, causal=False, qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, deterministic=False, sm_margin=0, return_attn_probs=False, ): return FlashAttnVarlenFunc.apply( q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, causal, qv, q_descale, k_descale, v_descale, window_size, attention_chunk, softcap, num_splits, pack_gqa, deterministic, sm_margin, return_attn_probs, ) def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None): return flash_attn_3_gpu.fwd_combine(out_partial, lse_partial, out, out_dtype) def flash_attn_with_kvcache( q, k_cache, v_cache, k=None, v=None, qv=None, rotary_cos=None, rotary_sin=None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, cache_batch_idx: Optional[torch.Tensor] = None, cache_leftpad: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k_new: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, rotary_seqlens: Optional[torch.Tensor] = None, q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window attention_chunk=0, softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, scheduler_metadata=None, num_splits=0, # Can be tuned for speed pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication return_softmax_lse=False, ): """ 
If k and v are not None, k_cache and v_cache will be updated *in place* with the new values from k and v. This is useful for incremental decoding: you can pass in the cached keys/values from the previous step, update them with the new keys/values from the current step, and do attention with the updated cache, all in a single kernel. If you pass in k / v, you must make sure that the cache is large enough to hold the new values. For example, the KV cache could be pre-allocated with the max sequence length, and you can use cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at indices cache_seqlens only (i.e., we consider all tokens in @q to be at position cache_seqlens). See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head 0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V. If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: 1 1 1 1 0 1 1 1 1 1 If seqlen_q = 5 and seqlen_k = 2, the causal mask is: 0 0 0 0 0 0 1 0 1 1 If a row of the mask is all zero, the output for that row will be zero. If window_size != (-1, -1), implements sliding window local attention.
Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Note: Does not support backward pass. Arguments: q: (batch_size, seqlen, nheads, headdim) k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e., a paged KV cache). page_block_size can be arbitrary (e.g., 1, 2, 3, 64). v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e., a paged KV cache). k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens. v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. qv [optional]: (batch_size, seqlen, nheads, headdim_v) rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the KV cache. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. If the indices are not distinct, and k and v are provided, the values updated in the cache might come from any of the duplicate indices. cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0. page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention. softcap: float. Anything > 0 activates softcapping attention. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 (i.e. GPT-NeoX style). num_splits: int. If > 1, split the key/value into this many chunks along the sequence. If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic to automatically determine the number of splits. Don't change this unless you know what you are doing. return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. Return: out: (batch_size, seqlen, nheads, headdim). softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). 
""" assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( (q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) cache_seqlens = maybe_contiguous(cache_seqlens) out, softmax_lse, *rest = _flash_attn_forward( q, k_cache, v_cache, k, v, qv, None, # out cu_seqlens_q, None, # cu_seqlens_k cu_seqlens_k_new, None, # seqused_q cache_seqlens, max_seqlen_q, None, # max_seqlen_k page_table, cache_batch_idx, cache_leftpad, rotary_cos, rotary_sin, rotary_seqlens, q_descale, k_descale, v_descale, softmax_scale, causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], attention_chunk=attention_chunk, softcap=softcap, rotary_interleaved=rotary_interleaved, scheduler_metadata=scheduler_metadata, num_splits=num_splits, pack_gqa=pack_gqa, sm_margin=sm_margin, ) # return (out, softmax_lse) if return_softmax_lse else out return (out, softmax_lse, *rest) if return_softmax_lse else out def get_scheduler_metadata( batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, cache_seqlens: torch.Tensor, qkv_dtype=torch.bfloat16, headdim_v=None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k_new: Optional[torch.Tensor] = None, cache_leftpad: Optional[torch.Tensor] = None, page_size: Optional[int] = None, max_seqlen_k_new=0, causal=False, window_size=(-1, -1), # -1 means infinite context window attention_chunk=0, has_softcap=False, num_splits=0, # Can be tuned for speed pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication ): cache_seqlens = maybe_contiguous(cache_seqlens) if headdim_v is None: headdim_v = headdim scheduler_metadata = 
flash_attn_3_gpu.get_scheduler_metadata( batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, qkv_dtype, cache_seqlens, cu_seqlens_q, None, # cu_seqlens_k cu_seqlens_k_new, None, # seqused_q cache_leftpad, page_size, max_seqlen_k_new, causal, window_size[0], window_size[1], attention_chunk, has_softcap, num_splits, pack_gqa, sm_margin, ) return scheduler_metadata ================================================ FILE: hopper/flash_bwd_kernel_sm80.h ================================================ /****************************************************************************** * Copyright (c) 2024, Tri Dao. ******************************************************************************/ #pragma once #include "cute/tensor.hpp" #include #include #include #include #include "utils.h" namespace flash { using namespace cute; template class FlashAttnBwdSm80 { public: // Type Aliases static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; static constexpr bool Is_local = CollectiveMainloop_::Is_local; static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); static constexpr bool Varlen = CollectiveMainloop_::Varlen; // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP; using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV; using ArchTag = typename CollectiveMainloop::ArchTag; using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB; // Epilogue derived types using CollectiveEpilogue = CollectiveEpilogue_; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; static_assert(ArchTag::kMinComputeCapability >= 80); using TileScheduler = TileScheduler_; using 
TileSchedulerArguments = typename flash::TileSchedulerArguments; using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{})); static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; // Kernel level shared memory storage struct SharedStorage { struct TensorStorage : cute::aligned_struct<128> { union { typename CollectiveMainloop::TensorStorage mainloop; typename CollectiveEpilogue::TensorStorage epilogue; }; } tensors; alignas(16) typename TileScheduler::SharedStorage smem_scheduler; }; static constexpr int SharedStorageSize = sizeof(SharedStorage); // Device side arguments struct Arguments { MainloopArguments mainloop{}; EpilogueArguments epilogue{}; cutlass::KernelHardwareInfo hw_info{}; TileSchedulerArguments scheduler{}; }; // Kernel entry point API struct Params { MainloopParams mainloop{}; EpilogueParams epilogue{}; cutlass::KernelHardwareInfo hw_info{}; TileSchedulerParams scheduler{}; }; // // Methods // // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
static Params to_underlying_arguments(Arguments const& args) { CUTLASS_TRACE_HOST("to_underlying_arguments():"); // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; if (sm_count <= 0) { CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); } CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; return { CollectiveMainloop::to_underlying_arguments(args.mainloop), CollectiveEpilogue::to_underlying_arguments(args.epilogue), hw_info, TileScheduler::to_underlying_arguments(args.scheduler) }; } // Computes the kernel launch grid shape based on runtime parameters static dim3 get_grid_shape(Params const& params) { return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); } static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) { static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); SharedStorage& shared_storage = *reinterpret_cast(smem_buf); CollectiveMainloop mainloop; CollectiveEpilogue epilogue; TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler)); // Initialize matmul objects. TiledMmadKV tiled_mma_dKV; scheduler.init_consumer(); int warp_idx = cutlass::canonical_warp_idx_sync(); CUTLASS_PRAGMA_NO_UNROLL for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = warp_idx == 0 ? 
scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; cute::tuple block_coord = {n_block, bidh, bidb}; // dK and dV output accumulator. Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x, block_coord, shared_storage); scheduler.prefetch_next_work(params.scheduler, work_tile_info); if (tile_valid) { epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, threadIdx.x, block_coord); } else { epilogue.store_zero(params.epilogue, threadIdx.x, block_coord); } } } }; } // namespace flash ================================================ FILE: hopper/flash_bwd_kernel_sm90.h ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
******************************************************************************/ #pragma once #include "cute/tensor.hpp" #include #include #include #include #include #include #include "cutlass/pipeline/pipeline.hpp" #include "utils.h" namespace flash { using namespace cute; template class FlashAttnBwdSm90 { public: // Type Aliases static constexpr bool Is_causal = CollectiveMainloop_::Is_causal; static constexpr bool Is_local = CollectiveMainloop_::Is_local; static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen); static constexpr bool Varlen = CollectiveMainloop_::Varlen; // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP; using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV; using ArchTag = typename CollectiveMainloop::ArchTag; using ClusterShape = typename CollectiveMainloop::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB; // Epilogue derived types using CollectiveEpilogue = CollectiveEpilogue_; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; static_assert(ArchTag::kMinComputeCapability >= 90); using TileScheduler = TileScheduler_; using TileSchedulerArguments = typename flash::TileSchedulerArguments; using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup; static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static_assert(NumMmaWarpGroups == 2 || 
NumMmaWarpGroups == 3); /// Register requirement for Load and Math WGs static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32; static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160; // If you want to print from the producer warp, you'd need to increase the number of registers. // Otherwise you'll get a CUDA error. // static constexpr uint32_t LoadRegisterRequirement = 40; // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152; // Kernel level shared memory storage struct SharedStorage { struct TensorStorage : cute::aligned_struct<128> { union { typename CollectiveMainloop::TensorStorage mainloop; typename CollectiveEpilogue::TensorStorage epilogue; }; } tensors; struct PipelineStorage : cute::aligned_struct<16> { alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV; alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q; alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do; alignas(16) typename TileScheduler::SharedStorage smem_scheduler; } pipelines; }; static constexpr int SharedStorageSize = sizeof(SharedStorage); // Device side arguments struct Arguments { MainloopArguments mainloop{}; EpilogueArguments epilogue{}; cutlass::KernelHardwareInfo hw_info{}; TileSchedulerArguments scheduler{}; }; // Kernel entry point API struct Params { MainloopParams mainloop{}; EpilogueParams epilogue{}; cutlass::KernelHardwareInfo hw_info{}; TileSchedulerParams scheduler{}; }; // // Methods // // Convert to underlying arguments. In this case, a simple copy for the aliased type.
static Params to_underlying_arguments(Arguments const& args) { CUTLASS_TRACE_HOST("to_underlying_arguments():"); // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; if (sm_count <= 0) { CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); } CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; return { CollectiveMainloop::to_underlying_arguments(args.mainloop), CollectiveEpilogue::to_underlying_arguments(args.epilogue), hw_info, TileScheduler::to_underlying_arguments(args.scheduler) }; } // Computes the kernel launch grid shape based on runtime parameters static dim3 get_grid_shape(Params const& params) { return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count); } static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) { static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename MainloopPipeline::PipelineState; using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO; using PipelineParams_dO = typename MainloopPipeline_dO::Params; using PipelineState_dO = typename MainloopPipeline_dO::PipelineState; static constexpr bool Q_dO_same_stages = std::is_same_v; SharedStorage& 
shared_storage = *reinterpret_cast(smem_buf); int const lane_predicate = cute::elect_one_sync(); int const warp_idx = cutlass::canonical_warp_idx_sync(); // Issue Tma Descriptor Prefetch from a single thread if (warp_idx == 0 && lane_predicate) { CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); } // Obtain warp index int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; PipelineParams pipeline_params; pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE; int warp_group_idx = cutlass::canonical_warp_group_idx(); pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer : MainloopPipeline::ThreadCategory::Consumer; pipeline_params.is_leader = warp_group_thread_idx == 0; pipeline_params.num_consumers = NumMmaThreads; if (warp_idx == 0 && lane_predicate) { shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/); } // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init(); MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{}); auto role_dO = warp_group_idx == 0 ? 
MainloopPipeline_dO::ThreadCategory::Producer : MainloopPipeline_dO::ThreadCategory::Consumer; PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers}; MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return(pipeline_params, pipeline_params_dO), ClusterShape{}); CollectiveMainloop mainloop; CollectiveEpilogue epilogue; // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { cute::cluster_arrive_relaxed(); cute::cluster_wait(); } else { __syncthreads(); } TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO PipelineState smem_pipe_write = cutlass::make_producer_start_state(); PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state(); for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; cute::tuple block_coord = {n_block, bidh, bidb}; auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() { scheduler.prefetch_next_work(params.scheduler, work_tile_info); }; mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord); } mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do); } else if (warp_idx_in_warpgroup == 1) { for (auto work_tile_info = scheduler.template
get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; cute::tuple block_coord = {n_block, bidh, bidb}; mainloop.store_dq(params.mainloop, shared_storage, block_coord); } } } else { // Consumer cutlass::arch::warpgroup_reg_alloc(); // Initialize matmul objects. TiledMmadKV tiled_mma_dKV; PipelineState smem_pipe_read; PipelineState_dO smem_pipe_read_do; mainloop.mma_init(); scheduler.init_consumer(); int work_idx = 0; CUTLASS_PRAGMA_NO_UNROLL for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { auto block_coord_ = work_tile_info.get_block_coord(params.scheduler); auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_; cute::tuple block_coord = {n_block, bidh, bidb}; // dK and dV output accumulator. 
Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select(TileShape_MNK{})); bool tile_valid = mainloop.mma( params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do, tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage); if (tile_valid) { epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV, threadIdx.x - NumCopyThreads, block_coord); } else { epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord); } } epilogue.store_tail(); } } }; } // namespace flash ================================================ FILE: hopper/flash_bwd_launch_template.h ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include "cute/tensor.hpp" #include "cutlass/device_kernel.h" // For device_kernel #include "cutlass/kernel_launch.h" // For kernel_launch #include "cutlass/cluster_launch.hpp" // For ClusterLauncher #include "cuda_check.h" #include "static_switch.h" #include "flash.h" #include "flash_bwd_preprocess_kernel.h" #include "flash_bwd_postprocess_kernel.h" #include "tile_scheduler.hpp" #include "mainloop_bwd_sm90_tma_gmma_ws.hpp" #include "mainloop_bwd_sm80.hpp" #include "epilogue_bwd.hpp" #include "flash_bwd_kernel_sm90.h" #include "flash_bwd_kernel_sm80.h" using namespace cute; template void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time."); using ElementAccum = float; using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; int const total_q_padded_rounded = cute::round_up(params.total_q + params.b *
kBlockM, kBlockM); int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN); bool const is_varlen_q = params.cu_seqlens_q; bool const is_varlen_k = params.cu_seqlens_k; int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q; int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k; int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded; int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded; int batch_q = !is_varlen_q ? params.b : 1; int batch_k = !is_varlen_k ? params.b : 1; using TileShape_MK = cute::Shape, Int>; using PreprocessKernel = flash::FlashAttnBwdPreprocess; typename PreprocessKernel::Arguments preprocess_args { static_cast(params.o_ptr), {seqlen_q, params.dv, params.h, batch_q}, // shape_O {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O static_cast(params.do_ptr), {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO static_cast(params.dsoftmax_sum), {seqlen_q_rounded, params.h, batch_q}, // shape_dPsum {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum static_cast(params.softmax_lse_ptr), {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE static_cast(params.softmax_lse_log2_ptr), {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 static_cast(params.dq_accum_ptr), {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? 
params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum params.b, params.dq_semaphore, params.cu_seqlens_q, params.seqused_q }; typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args); int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM); dim3 grid_m(num_m_block, params.h, params.b); CHECK_CUTLASS(cutlass::kernel_launch(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/)); using TileShape_MNK = cute::Shape, Int, Int>; using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Currently does not support clusters // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80 static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80; static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1; using CollectiveMainloop = std::conditional_t< Arch >= 90, flash::CollectiveMainloopBwdSm90, flash::CollectiveMainloopBwdSm80 >; using CollectiveEpilogue = std::conditional_t< !GQA, flash::CollectiveEpilogueBwd= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>, flash::CollectiveEpilogueBwdGQA >; using Scheduler = std::conditional_t< Is_causal, flash::SingleTileBwdLPTScheduler, flash::SingleTileScheduler >; using AttnKernel = std::conditional_t< Arch >= 90, flash::enable_sm90>, flash::enable_sm80_to_sm89> >; typename CollectiveMainloop::Arguments mainloop_args { static_cast(params.q_ptr), {seqlen_q, params.d, params.h, batch_q}, // shape_Q {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q static_cast(params.k_ptr), {seqlen_k, params.d, params.h_k, batch_k}, // shape_K {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K static_cast(params.v_ptr), {seqlen_k, params.dv, params.h_k, batch_k}, // shape_V {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ?
params.v_batch_stride : 0}, // stride_V static_cast(params.do_ptr), {seqlen_q, params.dv, params.h, batch_q}, // shape_dO {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO static_cast(params.dq_accum_ptr), {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum static_cast(params.softmax_lse_log2_ptr), {seqlen_q_rounded, params.h, batch_q}, // shape_LSE {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 static_cast(params.dsoftmax_sum), {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum params.scale_softmax, params.window_size_left, params.window_size_right, 0 /*attention_chunk*/, params.softcap, params.b, params.dq_semaphore, params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; // The case work with GQA is ugly but idk how to fix it. typename CollectiveEpilogue::Arguments epilogue_args { static_cast(!GQA ? params.dk_ptr : params.dk_accum_ptr), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK } else { return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum } }(), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK } else { return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum } }(), static_cast(!GQA ? 
params.dv_ptr : params.dv_accum_ptr), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV } else { return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum } }(), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV } else { return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum } }(), params.b, params.h, params.dk_semaphore, params.dv_semaphore, params.cu_seqlens_k, params.seqused_k, }; int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{})); num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{})); typename flash::TileSchedulerArguments scheduler_args { num_blocks_n, params.h, params.b, 1 /*num_splits*/, params.h / params.h_k, params.seqlen_k, params.seqlen_q, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k }; int device; cudaGetDevice(&device); typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args }); dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params); dim3 block_dims = AttnKernel::get_block_shape(); int smem_size = AttnKernel::SharedStorageSize; // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q)); // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do)); // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds)); // int smem_size_dqacc = [&] { // if constexpr (Arch >= 90) { // return sizeof(decltype((typename 
CollectiveMainloop::TensorStorage{}).smem_dqacc)); // } else { // return 0; // } // }(); // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)); // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)); // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse)); // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum)); // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum); if constexpr (size(ClusterShape{}) > 1) { void const* kernel = (void const*) cutlass::device_kernel; if (smem_size >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); CHECK_CUTLASS(cutlass::ClusterLauncher::launch( grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/)); } else { if (smem_size >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } CHECK_CUTLASS(cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/)); } using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ; typename PostprocessKernel::Arguments postprocess_args { static_cast(params.dq_accum_ptr), {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? 
params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum static_cast(params.dq_ptr), {seqlen_q, params.d, params.h, batch_q}, // shape_dQ {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ params.scale_softmax, params.cu_seqlens_q, params.seqused_q }; typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args); int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{})); dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b); int smem_size_postprocess = PostprocessKernel::SharedStorageSize; if (smem_size_postprocess >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess)); } CHECK_CUTLASS(cutlass::kernel_launch(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/)); if constexpr (GQA) { using TileShape_NK = cute::Shape, Int>; using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ; typename PostprocessKerneldKV::Arguments postprocess_dK_args { static_cast(params.dk_accum_ptr), {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? 
params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum static_cast(params.dk_ptr), {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK 1.f, params.cu_seqlens_k, params.seqused_k }; typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args); typename PostprocessKerneldKV::Arguments postprocess_dV_args { static_cast(params.dv_accum_ptr), {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum {_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum static_cast(params.dv_ptr), {seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV 1.f, params.cu_seqlens_k, params.seqused_k }; typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args); int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{})); dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b); int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize; if (smem_size_postprocess >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess)); } CHECK_CUTLASS(cutlass::kernel_launch(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/)); CHECK_CUTLASS(cutlass::kernel_launch(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/)); } } template void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) { VARLEN_SWITCH(params.cu_seqlens_q != nullptr ||
params.cu_seqlens_k != nullptr, Varlen, [&] {
        BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
            BOOL_SWITCH(params.deterministic, Deterministic_, [&] {
                static constexpr bool Deterministic = Deterministic_ && kHeadDim < 256;
                // run_flash_bwd(params, stream);
                run_flash_bwd(params, stream);
            });
        });
    });
}

template void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        if constexpr (Arch >= 90) {
            if constexpr (Is_causal && Has_softcap) {
                // register spill with 128 x 128
                run_mha_bwd_dispatch(params, stream);
            } else {
                // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
                run_mha_bwd_dispatch(params, stream);
            }
        } else if constexpr (Arch == 86 || Arch == 89) {
            run_mha_bwd_dispatch(params, stream);
            // run_mha_bwd_dispatch(params, stream);
            // run_mha_bwd_dispatch(params, stream);
            // run_mha_bwd_dispatch(params, stream);
        } else {
            run_mha_bwd_dispatch(params, stream);
        }
    });
}

template void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        if constexpr (Arch >= 90) {
            run_mha_bwd_dispatch(params, stream);
        } else if constexpr (Arch == 86 || Arch == 89) {
            run_mha_bwd_dispatch(params, stream);
        } else {
            run_mha_bwd_dispatch(params, stream);
        }
    });
}

template void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        if constexpr (Arch >= 90) {
            if constexpr (Is_causal || Is_local || Has_softcap) {
                run_mha_bwd_dispatch(params, stream);
            } else {
                run_mha_bwd_dispatch(params, stream);
            }
        } else if constexpr (Arch == 86 || Arch == 89) {
            run_mha_bwd_dispatch(params, stream);
        } else {
            run_mha_bwd_dispatch(params, stream);
        }
    });
}

template void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local,
Is_causal, Is_local, [&] {
        if constexpr (Arch >= 90) {
            run_mha_bwd_dispatch(params, stream);
        } else if constexpr (Arch == 86 || Arch == 89) {
            run_mha_bwd_dispatch(params, stream);
        } else {
            run_mha_bwd_dispatch(params, stream);
        }
    });
}

template void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
        if constexpr (Arch >= 90) {
            run_mha_bwd_dispatch(params, stream);
        } else if constexpr (Arch == 86 || Arch == 89) {
            run_mha_bwd_dispatch(params, stream);
            // run_mha_bwd_dispatch(params, stream);
        } else {
            run_mha_bwd_dispatch(params, stream);
        }
    });
}

================================================
FILE: hopper/flash_bwd_postprocess_kernel.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include
#include
#include
#include

#include "cutlass/arch/barrier.h"

#include "seqlen.h"
#include "utils.h"

namespace flash {

using namespace cute;

template class FlashAttnBwdPostprocessConvertdQ {

public:

    // Type Aliases
    using TileShape_MK = TileShape_MK_;
    using ArchTag = ArchTag_;

    static_assert(ArchTag::kMinComputeCapability >= 75);
    static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;

    static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;

    static constexpr int kBlockM = get<0>(TileShape_MK{});
    static constexpr int kHeadDim = get<1>(TileShape_MK{});
    static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
    static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
    using R2SLayoutAtomdQaccum = std::conditional_t<
        IsSm90,
        Layout, Int>>,
        Layout>>
    >;
    using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, Layout>>{})); // Val layout, 1 or 4 vals per read
    using G2SLayoutAtomdQaccum = Layout>>;
    // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions
    using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, G2SLayoutAtomdQaccum{}, Layout>{})); // Val layout, 4 vals per read
    // We don't do bound checking for the gmem -> smem load so we just assert here.
    static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);
    static constexpr int SmemdQaccumSize = size(TileShape_MK{});
    using SmemLayoutdQaccumFlat = Layout>>;
    using SmemLayoutdQaccum = std::conditional_t<
        IsSm90,
        Layout, Int>>,
        Layout>>
    >;

    // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
    // then setting kBlockKSmem to 32 will cause "Static shape_div failure".
    // We want to treat it as 64 x 48, so kBlockKSmem should be 16.
    static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
    static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
    static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ?
2 : 1);
    using SmemLayoutAtomdQ = decltype(composition(Swizzle{}, Layout, Int>, Stride, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
    using SmemLayoutdQt = decltype(cute::composition(SmemLayoutdQ{}, make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})), make_stride(Int(TileShape_MK{})>{}, _1{}))));

    using SmemCopyAtomdQ = Copy_Atom<
        std::conditional_t<
            IsSm90,
            std::conditional_t, AutoVectorizingCopyWithAssumedAlignment<128> >,
        Element>;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout, Int>, Stride, _1>>;
    using GmemTiledCopy = decltype(make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per load

    struct SharedStorage : cute::aligned_struct<128> {
        cute::array_aligned> smem_dqacc;
        cute::array_aligned> smem_dq;
        alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    using ShapedQ = cute::Shape; // (seqlen_q, d, head, batch)
    using StridedQ = cute::Stride;
    using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch)
    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;

    // Device side arguments
    struct Arguments {
        ElementAccum const* ptr_dQaccum;
        ShapedQaccum const shape_dQaccum;
        StridedQaccum const stride_dQaccum;
        Element* ptr_dQ;
        ShapedQ const shape_dQ;
        StridedQ const stride_dQ;
        float const softmax_scale;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Kernel entry point API
    struct Params {
        ElementAccum const* ptr_dQaccum;
        ShapedQaccum const shape_dQaccum;
        StridedQaccum const stride_dQaccum;
        Element* ptr_dQ;
        ShapedQ const shape_dQ;
        StridedQ const stride_dQ;
        float const softmax_scale;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static Params to_underlying_arguments(Arguments const& args) {
        return {
            args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum,
            args.ptr_dQ, args.shape_dQ, args.stride_dQ,
            args.softmax_scale,
            args.cu_seqlens, args.seqused
        };
    }

    CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) {

        static constexpr int kBlockM = get<0>(TileShape_MK{});
        SharedStorage& shared_storage = *reinterpret_cast(smem_buf);

        Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
        Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
        Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
        Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});

        int const thread_idx = threadIdx.x;
        int const m_block = blockIdx.x;
        int const bidh = blockIdx.y;
        int const bidb = blockIdx.z;

        flash::SeqlenInfo seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
        bool const is_varlen = params.cu_seqlens;
        if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }

        // Step 1: load dQaccum from gmem to smem
        Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ?
bidb : 0);
        Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); // (M * K)
        if constexpr (IsSm90) { // Use BulkCopy
            static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v / 8);
            auto bulk_copy = Copy_Traits{};
            // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); }
            if (thread_idx == 0) {
                shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
                shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
                copy(bulk_copy.with(*reinterpret_cast(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
            }
            __syncthreads();
            shared_storage.barrier_dQaccum.wait(0);
        } else {
            G2STiledCopydQaccum g2s_tiled_copy_dQaccum;
            auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
            Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);
            Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);
            cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);
            __syncthreads();
        }

        // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }

        // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
        R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
        auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
        Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
        TiledMma tiled_mma_dQ;
        Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select(TileShape_MK{}));
        // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); }
        // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }
        // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }
        CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
        Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
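        // Explanatory note on the softmax_scale multiply applied to the fp32 accumulator:
        // the forward pass computes S = softmax_scale * Q K^T, so by the chain rule
        // dQ = softmax_scale * dS K. This kernel applies that factor once, after the
        // smem -> register load and before the fp32 -> Element conversion (the GQA
        // dK/dV postprocess launches pass 1.f as the scale instead).
        // Worked example (illustrative numbers only): with the common choice
        // softmax_scale = 1/sqrt(64) = 0.125 for head dim 64, an accumulated value of
        // 8.0f becomes 8.0f * 0.125f = 1.0f before it is rounded to fp16/bf16.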
        cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
        #pragma unroll
        for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
        // Convert taccdQrdQaccum from fp32 to Element (fp16/bf16)
        Tensor rdQ = make_tensor_like(taccdQrdQaccum);
        flash::convert_type_out(taccdQrdQaccum, rdQ);

        // Step 3: Copy dQ from register to smem
        auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
        auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
        Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_M, MMA_N)
        // if (cute::thread0()) { print(smem_tiled_copy_dQ); }
        // if (cute::thread0()) { print(smem_thr_copy_dQ); }
        // if (cute::thread0()) { print(sdQ); }
        Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
        cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
        __syncthreads();

        // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
        Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ?
bidb : 0);
        Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
        GmemTiledCopy gmem_tiled_copy_dQ;
        auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
        Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
        Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);

        Tensor tdQrdQ = make_fragment_like(tdQsdQ);
        Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
        Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ)));
        #pragma unroll
        for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
        // Need to check OOB when reading from smem if kBlockM isn't evenly tiled
        static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
        flash::copy(
            gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);

        // Step 5: Copy dQ from register to gmem
        // Clear_OOB_K must be false since we don't want to write zeros to gmem
        flash::copy(
            gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)
        );
    }

};

} // namespace flash

================================================
FILE: hopper/flash_bwd_preprocess_kernel.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include
#include
#include
#include

#include "seqlen.h"
#include "utils.h"

namespace flash {

using namespace cute;

template class FlashAttnBwdPreprocess {

public:

    // Type Aliases
    using TileShape_MK = TileShape_MK_;
    using ArchTag = ArchTag_;

    static_assert(std::is_same_v && ArchTag::kMinComputeCapability >= 75 ||
                  std::is_same_v && ArchTag::kMinComputeCapability >= 80 ||
                  std::is_same_v && ArchTag::kMinComputeCapability >= 89);

    static constexpr uint32_t MaxThreadsPerBlock = 256;
    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
    static constexpr int SharedStorageSize = 0;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    static constexpr int kBlockM = get<0>(TileShape_MK{});
    static constexpr int kHeadDim = get<1>(TileShape_MK{});
    // We want kBlockKGmem to be a power of 2 so that when we do the summing,
    // it's just between threads in the same warp
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ?
64 : 32);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout, Int>, Stride, _1>>;
    using GmemTiledCopy = decltype(make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per load

    static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
    static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum");
    using GmemLayoutAtomAccum = Layout>>;
    using GmemTiledCopyAccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomAccum{}, Layout>>{})); // Val layout, 4 vals per store

    using ShapeO = cute::Shape; // (seqlen_q, d, head, batch)
    using StrideO = cute::Stride;
    using ShapedPsum = cute::Shape; // (seqlen_q, head, batch)
    using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
    using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch)
    using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;

    // Device side arguments
    struct Arguments {
        Element const* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        Element const* ptr_dO;
        StrideO const stride_dO;
        float* ptr_dPsum;
        ShapedPsum const shape_dPsum;
        StridedPsum const stride_dPsum;
        float const* ptr_LSE;
        StridedPsum const stride_LSE;
        float *ptr_LSE_log2;
        StridedPsum const stride_LSE_log2;
        ElementAccum* ptr_dQaccum;
        ShapedQaccum const shape_dQaccum;
        StridedQaccum const stride_dQaccum;
        int num_batch; // We need this to know the size of dq_semaphore in case of varlen
        int* dq_semaphore;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Kernel entry point API
    struct Params {
        Element const* ptr_O;
        ShapeO const shape_O;
        StrideO const stride_O;
        Element const* ptr_dO;
        StrideO const stride_dO;
        float* ptr_dPsum;
        ShapedPsum const shape_dPsum;
        StridedPsum const stride_dPsum;
        float const* ptr_LSE;
        StridedPsum const stride_LSE;
        float* ptr_LSE_log2;
        StridedPsum const stride_LSE_log2;
        ElementAccum* ptr_dQaccum;
        ShapedQaccum const shape_dQaccum;
        StridedQaccum const stride_dQaccum;
        int num_batch;
        int* dq_semaphore;
        int const* cu_seqlens = nullptr;
        int const* seqused = nullptr;
    };

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static Params to_underlying_arguments(Arguments const& args) {
        return {
            args.ptr_O, args.shape_O, args.stride_O,
            args.ptr_dO, args.stride_dO,
            args.ptr_dPsum, args.shape_dPsum, args.stride_dPsum,
            args.ptr_LSE, args.stride_LSE,
            args.ptr_LSE_log2, args.stride_LSE_log2,
            args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum,
            args.num_batch,
            args.dq_semaphore,
            args.cu_seqlens,
            args.seqused
        };
    }

    CUTLASS_DEVICE void operator()(Params const& params, [[maybe_unused]] char* smem_buf) {

        static constexpr int kBlockM = get<0>(TileShape_MK{});

        int const thread_idx = threadIdx.x;
        int const m_block = blockIdx.x;
        int const bidh = blockIdx.y;
        int const bidb = blockIdx.z;

        flash::SeqlenInfo seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused);
        bool const is_varlen = Varlen && params.cu_seqlens;
        int const seqlen_o = seqlen_info.seqlen;
        if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }

        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
        Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
        Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ?
bidb : 0);
        Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)

        auto shape_LSE = select<0, 2, 3>(params.shape_O);
        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);
        Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape>{}, make_coord(m_block));
        static_assert(kBlockM <= MaxThreadsPerBlock);
        float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;

        GmemTiledCopy gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);

        Tensor tOgO = gmem_thr_copy_O.partition_S(gO);
        Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);
        // Construct identity layout for gO
        Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor(make_shape(size<2>(tOgO)));
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }

        // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)
        Tensor tOrO = make_fragment_like(tOgO);
        Tensor tOrdO = make_fragment_like(tOgdO);
        flash::copy(
            gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM
        );
        flash::copy(
            gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM
        );
        // if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));}

        // Reshape from e.g.
        // (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))
        Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));
        Tensor tOrO_l = make_tensor(tOrO.data(), l);
        Tensor o_fp32 = make_tensor_like(tOrO_l);
        flash::convert_type_out(tOrO_l, o_fp32);
        Tensor tOrdO_l = make_tensor(tOrdO.data(), l);
        Tensor do_fp32 = make_tensor_like(tOrdO_l);
        flash::convert_type_out(tOrdO_l, do_fp32);
        // Sum across the last dimension
        Tensor dP_sum = make_tensor(make_shape(size<0>(o_fp32)));
        #pragma unroll
        for (int mi = 0; mi < size<0>(o_fp32); ++mi) {
            float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
            #pragma unroll
            for (int ni = 1; ni < size<1>(o_fp32); ni++) {
                dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
            }
            flash::SumOp sum_op;
            dP_sum(mi) = flash::Allreduce::run(dP_sum_cur, sum_op);
        }

        Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);
        Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape>{}, make_coord(m_block));
        if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(dP_sum); ++mi) {
                int const row = get<0>(tOcO(_0{}, mi, _0{}));
                gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0;
            }
        }

        int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);
        Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);
        Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape>{}, make_coord(m_block));
        if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {
            gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);
        }

        if constexpr (Clear_dQaccum) {
            Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ?
bidb : 0);
            Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block));
            GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
            auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
            Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
            Tensor zero = make_fragment_like(tdQgdQaccum);
            clear(zero);
            cute::copy(Copy_Atom, ElementAccum>{}, zero, tdQgdQaccum);
        }

        if (params.dq_semaphore != nullptr && thread_idx == 0) {
            int const num_batch = params.num_batch;
            int const num_head = get<2>(params.shape_O);
            params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;
        }

    }

};

} // namespace flash

================================================
FILE: hopper/flash_fwd_combine.cu
================================================
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_combine_launch_template.h"

template void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
template void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
template void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
template void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
template void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
template void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);

================================================
FILE: hopper/flash_fwd_combine_kernel.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include
#include
#include
#include
#include

#include "cutlass/arch/grid_dependency_control.h"

#include "seqlen.h"
#include "utils.h"

namespace flash {

using namespace cute;

template class FlashAttnFwdCombine {

public:

    // Type Aliases
    using TileShape_MK = TileShape_MK_;
    using ArchTag = ArchTag_;
    static constexpr int kMaxSplits = 1 << kLogMaxSplits_;
    static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float)));
    static_assert(AlignmentLSE >= 1);
    static constexpr int kStages = 4;

    static_assert(ArchTag::kMinComputeCapability >= 75);
    static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;

    static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
    static constexpr uint32_t MinBlocksPerMultiprocessor = 2;

    static constexpr int kBlockM = get<0>(TileShape_MK{});
    static constexpr int kBlockK = get<1>(TileShape_MK{});

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial);
    static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad");
    static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ?
64 : 32);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
    using GmemCopyAtom = std::conditional_t<
        Has_cp_async,
        cute::Copy_Atom, ElementPartial>,
        cute::Copy_Atom, ElementPartial>
    >;
    using GmemLayoutAtom = Layout, Int>, Stride, _1>>;
    static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
    using GmemTiledCopyAccum = decltype(make_tiled_copy(GmemCopyAtom{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 4 vals per load
    using GmemTiledCopy = decltype(make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 4 vals per load

    using AlignmentTypeLSE = cute::uint_byte_t(sizeof(float)) * AlignmentLSE>;
    static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float);
    static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE");
    static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8");
    static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ?
16 : 8)));
    static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE;
    static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE");
    using GmemLayoutAtomLSE = Layout, Int>, Stride, _1>>;
    static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0);
    using GmemCopyAtomLSE = std::conditional_t<
        Has_cp_async,
        cute::Copy_Atom, float>,
        cute::Copy_Atom, float>
    >;
    using GmemTiledCopyLSE = decltype(make_tiled_copy(GmemCopyAtomLSE{}, GmemLayoutAtomLSE{}, Layout>>{})); // Val layout, 4 vals per load

    // Otherwise we get an IMA (illegal memory access) when some threads access sLSE, as we're not doing any masking
    static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE");
    // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
    using SmemLSESwizzle = std::conditional_t<
        kBlockMSmem == 8,
        Swizzle<5, 0, 5>,
        std::conditional_t, Swizzle<3, 2, 3>>
    >;
    using SmemLayoutAtomLSE = decltype(composition(SmemLSESwizzle{}, Layout, Int>, Stride, _1>>{}));
    using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape, Int>{}));

    using SmemLayoutO = Layout, Int, Int>, Stride, _1, Int>>;

    // We want each column (kMaxSplits) to be processed by threads in the same warp.
    // To reduce the number of shuffles, we want as few threads on the same column as possible.
    // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 3) per column
    // and have 64 such quads.
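    // Worked example (illustrative values only): with kNThreads = 256 and kBlockM = 64,
    // kBlockMSmem = 64 and kSmemThreadsPerColLSEt = 256 / 64 = 4. As described above,
    // threads 0..3 then share one column of sLSE, i.e. the kMaxSplits LSE values for one
    // row m of the output tile, and the 256 / 4 = 64 groups cover the 64 rows.
    // Because 4 divides NumThreadsPerWarp = 32, every group lies inside a single warp,
    // so a butterfly reduction across a group needs log2(4) = 2 shuffle steps.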
    static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem");
    static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem;
    static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp");
    using S2RLayoutAtomLSE = Layout, Int>>;
    using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom{}, S2RLayoutAtomLSE{}, Layout<_1>{}));

    using ShapeOPartial = cute::Shape; // (seqlen, d, num_splits, head, batch)
    using StrideOPartial = cute::Stride;
    using ShapeLSEPartial = cute::Shape; // (seqlen, num_splits, head, batch)
    using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch)
    using ShapeO = cute::Shape; // (seqlen, d, head, batch)
    using StrideO = cute::Stride;
    using ShapeLSE = cute::Shape; // (seqlen, head, batch)
    using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)

    struct SharedStorage : cute::aligned_struct<128> {
        cute::array_aligned> smem_lse_partial;
        cute::array_aligned smem_max_valid_split;
        cute::array_aligned> smem_o_partial;
    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    // Device side arguments
    struct Arguments {
        ElementPartial const* const ptr_O_partial;
        ShapeOPartial const shape_O_partial;
        StrideOPartial const stride_O_partial;
        float const* const ptr_LSE_partial;
        ShapeLSEPartial const shape_LSE_partial;
        StrideLSEPartial const stride_LSE_partial;
        Element* const ptr_O;
        StrideO const stride_O;
        float* const ptr_LSE;
        StrideLSE const stride_LSE;
        int const* const cu_seqlens = nullptr;
        int const* const seqused = nullptr;
        int const* const num_splits_dynamic_ptr = nullptr;
        int const* const varlen_batch_idx_ptr = nullptr;
        int* const semaphore_to_reset = nullptr;
    };

    // Kernel entry point API
    struct Params {
        ElementPartial const* const ptr_O_partial;
        ShapeOPartial const shape_O_partial;
        StrideOPartial const stride_O_partial;
        float const* const
ptr_LSE_partial;
        ShapeLSEPartial const shape_LSE_partial;
        StrideLSEPartial const stride_LSE_partial;
        Element* const ptr_O;
        StrideO const stride_O;
        float* const ptr_LSE;
        StrideLSE const stride_LSE;
        cutlass::FastDivmod seqlen_divmod, head_divmod;
        int const* const cu_seqlens = nullptr;
        int const* const seqused = nullptr;
        int const* const num_splits_dynamic_ptr = nullptr;
        int const* const varlen_batch_idx_ptr = nullptr;
        int* const semaphore_to_reset = nullptr;
    };

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static Params to_underlying_arguments(Arguments const& args) {
        assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
        return {
            args.ptr_O_partial, args.shape_O_partial, args.stride_O_partial,
            args.ptr_LSE_partial, args.shape_LSE_partial, args.stride_LSE_partial,
            args.ptr_O, args.stride_O,
            args.ptr_LSE, args.stride_LSE,
            cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)),
            args.cu_seqlens,
            args.seqused,
            args.num_splits_dynamic_ptr,
            args.varlen_batch_idx_ptr,
            args.semaphore_to_reset,
        };
    }

    CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) {

        SharedStorage& shared_storage = *reinterpret_cast(smem_buf);
        Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
        Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape>{});
        Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});

        int const thread_idx = threadIdx.x;
        int const m_block = blockIdx.x;
        int const k_block = blockIdx.y;
        int const maybe_virtual_batch = blockIdx.z;
        int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[maybe_virtual_batch] : maybe_virtual_batch;
        int const num_splits = params.num_splits_dynamic_ptr ?
params.num_splits_dynamic_ptr[maybe_virtual_batch] : get<1>(params.shape_LSE_partial);

        if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
            cutlass::arch::wait_on_dependent_grids();
            *params.semaphore_to_reset = 0;
        }
        if (num_splits <= 1) { return; }
        flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
        int const offset = seqlen_info.offset;
        int const seqlen = seqlen_info.seqlen;
        int max_idx = seqlen * get<2>(params.shape_LSE_partial);
        if constexpr (Varlen) {
            if (m_block * kBlockM >= max_idx) { return; }
        }

        cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);

        // Step 1: load LSE_partial from gmem -> smem
        Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)),
                                         select<1, 0, 2, 3>(params.shape_LSE_partial),
                                         select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head)
        Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int>{});
        GmemTiledCopyLSE gmem_tiled_copy_LSE;
        auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx);
        Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE);

        // Construct identity layout for sLSE
        Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m)
        // Repeat the partitioning with identity layouts
        Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE);

        cutlass::arch::wait_on_dependent_grids();

        #pragma unroll
        for (int m = 0; m < size<2>(tLSEcLSE); ++m) {
            int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m)));
            int idx = m_block * kBlockM + mi;
            if (idx < max_idx) {
                int m_idx, bidh;
                if constexpr (!Varlen) {
                    bidh = params.seqlen_divmod.divmod(m_idx, idx);
                } else {
                    bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
                }
                Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh);
                #pragma unroll
                for (int s = 0; s <
                    size<1>(tLSEcLSE); ++s) {
                    int si = get<0>(tLSEcLSE(_0{}, s, _0{}));
                    // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);}
                    if (si < num_splits) {
                        cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m));
                    } else {
                        cute::fill(tLSEsLSE(_, s, m), -INFINITY);
                    }
                }
            } else {
                // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem
                // cute::fill(tLSEsLSE(_, _, m), -INFINITY);
            }
        }
        if constexpr (Has_cp_async) { cute::cp_async_fence(); }

        // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2.
        // We want these async loads to be in flight as we compute the LSE.
        GmemTiledCopyAccum gmem_tiled_copy_O_partial;
        auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx);
        // Construct identity layout for gO
        Tensor cO = cute::make_identity_tensor(TileShape_MK{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO);
        Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)),
                                       params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ?
            batch : 0);  // (seqlen, d, num_splits, head)

        // Precompute these values to avoid recomputing them in the loop
        Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO)));
        Tensor tObidh = make_tensor(make_shape(size<1>(tOcO)));
        Tensor tOrOptr = make_tensor(make_shape(size<1>(tOcO)));
        #pragma unroll
        for (int m = 0; m < size<1>(tOcO); ++m) {
            int mi = get<0>(tOcO(_0{}, m, _0{}));
            int idx = m_block * kBlockM + mi;
            if constexpr (!Varlen) {
                tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx);
            } else {
                tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx);
            }
            tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m));
            if (idx >= max_idx) {
                tObidh[m] = -1;
            }
        }

        Tensor tOpO = make_tensor(make_shape(size<2>(tOcO)));
        if constexpr (!(Is_even_K)) {
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; }
        }

        Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO);

        auto load_O_partial = [&] (int split, int stage) {
            Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage);
            #pragma unroll
            for (int m = 0; m < size<1>(tOcO); ++m) {
                if (tObidh(m) >= 0) {
                    Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout());
                    Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape>{});
                    #pragma unroll
                    for (int k = 0; k < size<2>(tOcO); ++k) {
                        int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                        if (Is_even_K || tOpO(k)) {
                            cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k));
                        }
                    }
                }
            }
        };

        for (int s = 0; s < kStages - 1; ++s) {
            if (s < num_splits) { load_O_partial(s, s); }
            if constexpr (Has_cp_async) { cute::cp_async_fence(); }
        }

        // Step 3: load and transpose LSE_partial from smem -> rmem
        if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait(); }
        __syncthreads();

        S2RTiledCopyLSE s2r_tiled_copy_LSE;
        auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx);
        Tensor ts2rsLSE =
            s2r_thr_copy_LSE.partition_S(sLSE);
        Tensor ts2rrLSE = make_fragment_like(ts2rsLSE);
        cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE);

        // Step 4: compute the final LSE along the split dimension
        Tensor lse_sum = make_tensor(make_shape(size<2>(ts2rrLSE)));
        Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE);
        // We compute the max valid split for each row to short-circuit the computation later
        Tensor max_valid_split = make_tensor(make_shape(size<2>(ts2rrLSE)));
        static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1);
        #pragma unroll
        for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
            float lse_max = ts2rrLSE(_0{}, _0{}, m);
            #pragma unroll
            for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); }
            MaxOp max_op;
            lse_max = Allreduce::run(lse_max, max_op);
            int max_valid_idx = -1;
            #pragma unroll
            for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
                if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); }
            }
            MaxOp max_int_op;
            max_valid_split[m] = Allreduce::run(max_valid_idx, max_int_op);
            float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
            float lse_sum_cur = 0.f;
            #pragma unroll
            for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
                float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur);
                lse_sum_cur += scale;
                // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);}
                // ts2rsLSE(_0{}, m, s) = scale;
                ts2rrLSE(_0{}, s, m) = scale;
            }
            SumOp sum_op;
            lse_sum_cur = Allreduce::run(lse_sum_cur, sum_op);
            lse_sum(m) = logf(lse_sum_cur) + lse_max;
            float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ?
                0.f : 1.f / lse_sum_cur;
            #pragma unroll
            for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; }
        }
        // Store the scales exp(lse - lse_logsum) back to smem
        cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE);

        // Store max_valid_split to smem
        #pragma unroll
        for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
            if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) {  // Only the thread responsible for s=0 writes to smem
                int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
                if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; }
            }
        }

        // Step 5: store final LSE back to gmem
        if (k_block == 0) {
            auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial);
            Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0);
            #pragma unroll
            for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
                if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) {  // Only the thread responsible for s=0 writes to gmem
                    int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
                    int idx = m_block * kBlockM + mi;
                    if (idx < max_idx) {
                        int m_idx, bidh;
                        if constexpr (!Varlen) {
                            bidh = params.seqlen_divmod.divmod(m_idx, idx);
                        } else {
                            bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
                        }
                        // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m));
                        mLSE(m_idx, bidh) = lse_sum(m);
                    }
                }
            }
        }

        // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O
        __syncthreads();
        int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))];
        #pragma unroll
        for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); }
        Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor(TileShape_MK{})).layout();
        Tensor tOrOpartial = make_fragment_like(tOrOpartial_layout);
        Tensor tOrO = make_fragment_like(tOrOpartial);
        clear(tOrO);
        int stage_load = kStages - 1,
            stage_compute = 0;
        #pragma unroll 4  // Already tuned for speed
        for (int s = 0; s <= thr_max_valid_split; ++s) {
            Tensor scale = make_tensor(make_shape(size<1>(tOrOpartial)));
            #pragma unroll
            for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); }

            if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); }
            if constexpr (Has_cp_async) { cute::cp_async_fence(); }
            stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0;
            if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait(); }
            // We don't need __syncthreads() because each thread is just reading its own data from smem
            cute::copy(Copy_Atom, ElementPartial>{}, tOsOpartial(_, _, _, stage_compute), tOrOpartial);
            stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0;

            #pragma unroll
            for (int m = 0; m < size<1>(tOrOpartial); ++m) {
                if (tObidh(m) >= 0 && scale(m) > 0.f) {
                    #pragma unroll
                    for (int k = 0; k < size<2>(tOrOpartial); ++k) {
                        if (Is_even_K || tOpO(k)) {
                            Tensor rOpartial = make_tensor_like(tOrOpartial(_, m, k));
                            flash::convert_type_out(tOrOpartial(_, m, k), rOpartial);
                            #pragma unroll
                            for (int i = 0; i < size<0>(tOrOpartial); ++i) {
                                tOrO(i, m, k) += scale(m) * rOpartial[i];
                            }
                        }
                    }
                }
            }
        }

        // Step 7: Write the final O to gmem
        Tensor rO = make_tensor_like(tOrO);
        flash::convert_type_out(tOrO, rO);
        auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial));
        Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)),
                                shape_O, params.stride_O)(_, _, _, !Varlen ?
            batch : 0);
        Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{});
        GmemTiledCopy gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);

        #pragma unroll
        for (int m = 0; m < size<1>(tOcO); ++m) {
            if (tObidh(m) >= 0) {
                #pragma unroll
                for (int k = 0; k < size<2>(tOcO); ++k) {
                    int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    if (Is_even_K || tOpO(k)) {
                        cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m)));
                    }
                }
            }
        }

    }

};

} // namespace flash

================================================
FILE: hopper/flash_fwd_combine_launch_template.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"  // For cutlass::arch::Sm80
#include "cutlass/device_kernel.h"  // For device_kernel
#include "cutlass/kernel_launch.h"  // For kernel_launch

#include "cuda_check.h"
#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_combine_kernel.h"

using namespace cute;

template
void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
    using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
    using TileShape_MK = cute::Shape, Int>;
    using CombineKernel = flash::FlashAttnFwdCombine;

    typename CombineKernel::Arguments args {
        static_cast(params.oaccum_ptr),
        {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_O_partial
        {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0},  // stride_O_partial
        static_cast(params.softmax_lseaccum_ptr),
        {!Varlen ?
            params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1},  // shape_LSE_partial
        {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0},  // stride_LSE_partial
        static_cast(params.o_ptr),
        {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0},  // stride_O
        static_cast(params.softmax_lse_ptr),
        {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0},  // stride_LSE
        params.cu_seqlens_q,
        params.seqused_q,
        params.num_splits_dynamic_ptr,
        params.varlen_batch_idx_ptr,
        params.tile_count_semaphore
    };

    typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
    int num_blocks_k = cute::ceil_div(params.dv, kBlockK);
    int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM);
    dim3 grid_m(num_blocks_m, num_blocks_k, params.b);
    auto kernel = cutlass::device_kernel;
    int smem_size = CombineKernel::SharedStorageSize;
    if (smem_size >= 48 * 1024) {
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    // kernel<<>>(kernel_params);
    CHECK_CUTLASS(cutlass::kernel_launch(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params,
                                         Arch >= 90 && enable_pdl /*launch_with_pdl*/));
}

template
void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
    // We want kBlockM to be as small as possible to maximize parallelism.
    // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
    static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32");
    static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
    ARCH_SWITCH(params.arch, Arch, [&] {
        BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] {
            if constexpr (kBlockM >= 16) {  // If kBlockM == 8 then the minimum number of splits is 32.
                if (params.num_splits <= 16) {
                    run_flash_fwd_combine(params, stream, enable_pdl);
                    return;
                }
            }
            if (params.num_splits <= 32) {
                run_flash_fwd_combine(params, stream, enable_pdl);
            } else if (params.num_splits <= 64) {
                run_flash_fwd_combine(params, stream, enable_pdl);
            } else if (params.num_splits <= 128) {
                run_flash_fwd_combine(params, stream, enable_pdl);
            } else {
                run_flash_fwd_combine(params, stream, enable_pdl);
            }
        });
    });
}

================================================
FILE: hopper/flash_fwd_kernel_sm80.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include
#include
#include
#include

#include "seqlen.h"
#include "utils.h"
#include "softmax.h"

namespace flash {

using namespace cute;

template
class FlashAttnFwdSm80 {

public:

    // Type Aliases
    using CollectiveMainloop = CollectiveMainloop_;
    using CollectiveEpilogue = CollectiveEpilogue_;
    static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
    static constexpr bool Is_local = CollectiveMainloop::Is_local;
    static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
    static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
    static constexpr bool Varlen = CollectiveMainloop::Varlen;
    static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
    static constexpr bool Split = CollectiveMainloop::Split;
    static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
    static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
    static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
    static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
    static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
    using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;

    // Mainloop derived types
    using TileShape_MNK = typename
        CollectiveMainloop::TileShape_MNK;
    using TiledMma = typename CollectiveMainloop::TiledMma;
    using ArchTag = typename CollectiveMainloop::ArchTag;
    using MainloopArguments = typename CollectiveMainloop::Arguments;
    using MainloopParams = typename CollectiveMainloop::Params;

    // Epilogue derived types
    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
    using EpilogueParams = typename CollectiveEpilogue::Params;

    static_assert(ArchTag::kMinComputeCapability >= 80);

    using TileScheduler = TileScheduler_;
    using TileSchedulerArguments = typename flash::TileSchedulerArguments;
    using TileSchedulerParams = typename TileScheduler::Params;

    static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));
    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
    static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;

    // Kernel level shared memory storage
    // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k and not smem_q
    // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k).
    static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))
        - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))
        - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));
    static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ?
        0 : mainloop_smem_padding_;
    struct SharedStorage {
        struct TensorStorage : cute::aligned_struct<128> {
            union {
                struct {
                    cute::array padding_;
                    typename CollectiveMainloop::TensorStorage mainloop;
                };
                // We want smem_o to line up with the start of smem_v
                typename CollectiveEpilogue::TensorStorage epilogue;
            };
        } tensors;

        alignas(16) typename TileScheduler::SharedStorage smem_scheduler;

    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    // Device side arguments
    struct Arguments {
        MainloopArguments mainloop{};
        EpilogueArguments epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerArguments scheduler{};
    };

    // Kernel entry point API
    struct Params {
        MainloopParams mainloop{};
        EpilogueParams epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerParams scheduler{};
    };

    //
    // Methods
    //

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static Params to_underlying_arguments(Arguments const& args) {
        CUTLASS_TRACE_HOST("to_underlying_arguments():");

        // Get SM count if needed, otherwise use user supplied SM count
        int sm_count = args.hw_info.sm_count;
        if (sm_count <= 0) {
            CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
                " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
        }

        CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);

        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
        return {
            CollectiveMainloop::to_underlying_arguments(args.mainloop),
            CollectiveEpilogue::to_underlying_arguments(args.epilogue),
            hw_info,
            TileScheduler::to_underlying_arguments(args.scheduler)
        };
    }

    // Computes the kernel launch grid shape based on runtime parameters
    static dim3 get_grid_shape(Params const& params) {
        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count *
            MinBlocksPerMultiprocessor);
    }

    static dim3 get_block_shape() {
        return dim3(MaxThreadsPerBlock, 1, 1);
    }

    CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) {

        static constexpr int kBlockM = get<0>(TileShape_MNK{});

        SharedStorage& shared_storage = *reinterpret_cast(smem_buf);

        CollectiveMainloop mainloop;
        CollectiveEpilogue epilogue;

        TileScheduler scheduler(reinterpret_cast(&shared_storage.smem_scheduler));
        // Initialize matmul objects.
        TiledMma tiled_mma;

        scheduler.init_consumer();

        int warp_idx = cutlass::canonical_warp_idx_sync();
        CUTLASS_PRAGMA_NO_UNROLL
        for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler);
             work_tile_info.is_valid(params.scheduler);
             work_tile_info = warp_idx == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) {
            // Attention output (GEMM-II) accumulator.
            Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));
            float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
            // If there's tanh softcap, the scaling will be done before tanh.
            auto block_coord = work_tile_info.get_block_coord(params.scheduler);
            int const bidb = get<2>(block_coord);
            if constexpr (Is_FP8 && !Has_softcap) {
                int const bidh = get<1>(block_coord);
                int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
                float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
                float const k_descale = params.mainloop.ptr_k_descale == nullptr ?
                    1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
                softmax_scale_log2 *= q_descale * k_descale;
            }
            flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);

            SeqlenInfo_t seqlen_info{
                bidb,
                get<0>(params.mainloop.shape_Q),
                !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
                get<0>(params.mainloop.shape_K_new),
                params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
                params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
                params.mainloop.seqlens_rotary
            };
            if constexpr (AppendKV) {
                bool tile_new_valid = mainloop.store_kv_new(
                    params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);
                if (tile_new_valid) { __syncthreads(); }
            }
            bool tile_valid = mainloop.mma(
                params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,
                shared_storage);
            scheduler.prefetch_next_work(params.scheduler, work_tile_info);
            if (tile_valid) {
                // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
                epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,
                               threadIdx.x, block_coord);
            } else {
                // Write 0 to gO and -inf to gLSE.
                epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
            }
        }

    }

};

} // namespace flash

================================================
FILE: hopper/flash_fwd_kernel_sm90.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include
#include
#include
#include
#include
#include
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/arch/grid_dependency_control.h"

#include "seqlen.h"
#include "utils.h"
#include "softmax.h"

namespace flash {

using namespace cute;

template
class FlashAttnFwdSm90 {

public:

    // Type Aliases
    using CollectiveMainloop = CollectiveMainloop_;
    using CollectiveEpilogue = CollectiveEpilogue_;
    static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
    static constexpr bool Is_local = CollectiveMainloop::Is_local;
    static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
    static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
    static constexpr bool Varlen = CollectiveMainloop::Varlen;
    static constexpr bool Split = CollectiveMainloop::Split;
    static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
    static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
    static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
    static constexpr bool HasQv = CollectiveMainloop::HasQv;
    static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
    static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
    static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
    static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
    static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
    static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
    static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
    static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
    using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;

    // Mainloop derived types
    using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
    using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
    using ArchTag = typename CollectiveMainloop::ArchTag;
    using ClusterShape = typename CollectiveMainloop::ClusterShape;
    using MainloopArguments = typename CollectiveMainloop::Arguments;
    using MainloopParams = typename CollectiveMainloop::Params;
    using BarrierQ = std::conditional_t;

    // Epilogue derived types
    using EpilogueArguments = typename CollectiveEpilogue::Arguments;
    using EpilogueParams = typename CollectiveEpilogue::Params;

    static_assert(ArchTag::kMinComputeCapability >= 90);

    using TileScheduler = TileScheduler_;
    using TileSchedulerArguments = typename flash::TileSchedulerArguments;
    using TileSchedulerParams = typename TileScheduler::Params;

    static constexpr uint32_t NumLoadWarpGroups = 1;
    static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;
    static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
    static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
    static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);

    /// Register requirement for Load and Math WGs
    // If we use cp.async to load K and V, we need more registers for the producer WG.
    static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
    static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
    // If you want to print from the producer warp, you'd need to increase the number of registers
    // Otherwise you'll get CUDA error.
    // static constexpr uint32_t LoadRegisterRequirement = 40;
    // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;

    // Kernel level shared memory storage
    // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
    // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
    static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
    static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
    struct SharedStorage {
        struct TensorStorage : cute::aligned_struct<128, _1> {
            union {
                struct {
                    cute::array padding_;
                    typename CollectiveMainloop::TensorStorage mainloop;
                };
                // We want smem_o to line up with the start of smem_v
                typename CollectiveEpilogue::TensorStorage epilogue;
            };
        } tensors;
        struct PipelineStorage : cute::aligned_struct<16, _1> {
            alignas(16) BarrierQ barrier_Q;
            alignas(16) BarrierQ barrier_Qv;
            alignas(16) cutlass::arch::ClusterBarrier barrier_O;
            alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
            alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
            alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
            alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
            alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
            alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
        } pipelines;

    };

    static constexpr int SharedStorageSize = sizeof(SharedStorage);

    // Device side arguments
    struct Arguments {
        MainloopArguments mainloop{};
        EpilogueArguments epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerArguments scheduler{};
    };

    // Kernel entry point API
    struct Params {
        MainloopParams mainloop{};
        EpilogueParams epilogue{};
        cutlass::KernelHardwareInfo hw_info{};
        TileSchedulerParams scheduler{};
    };

    //
    // Methods
    //

    // Convert to underlying arguments. In this case, a simple copy for the aliased type.
    static Params to_underlying_arguments(Arguments const& args) {
        CUTLASS_TRACE_HOST("to_underlying_arguments():");

        // Get SM count if needed, otherwise use user supplied SM count
        int sm_count = args.hw_info.sm_count;
        if (sm_count <= 0) {
            CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
                " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
            sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
        }

        CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);

        cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
        return {
            CollectiveMainloop::to_underlying_arguments(args.mainloop),
            CollectiveEpilogue::to_underlying_arguments(args.epilogue),
            hw_info,
            TileScheduler::to_underlying_arguments(args.scheduler)
        };
    }

    // Computes the kernel launch grid shape based on runtime parameters
    static dim3 get_grid_shape(Params const& params) {
        return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
    }

    static dim3 get_block_shape() {
        return dim3(MaxThreadsPerBlock, 1, 1);
    }

    CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) {

        static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
        static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
        static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});

        using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
        using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
        using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
        using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
        using PipelineState = typename CollectiveMainloop::PipelineState;
        using PipelineParamsK = typename MainloopPipelineK::Params;
        using PipelineParamsV = typename MainloopPipelineV::Params;
        using PipelineParamsVt = typename
            MainloopPipelineVt::Params;
        using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;

        SharedStorage& shared_storage = *reinterpret_cast(smem_buf);

        int const lane_predicate = cute::elect_one_sync();
        int const warp_idx = cutlass::canonical_warp_idx_sync();

        // Issue Tma Descriptor Prefetch from a single thread
        if (warp_idx == 0 && lane_predicate) {
            CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
            CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
        }

        // Obtain warp index
        int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
        int warp_group_idx = cutlass::canonical_warp_group_idx();

        if (warp_idx == 0 && lane_predicate) {
            shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
            if constexpr (HasQv) {
                shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
            }
            shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
        }

        // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
        PipelineParamsK pipeline_params_k;
        pipeline_params_k.role = warp_group_idx == 0
            ? MainloopPipelineK::ThreadCategory::Producer
            : MainloopPipelineK::ThreadCategory::Consumer;
        if constexpr (Use_TMA_KV) {
            pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
            pipeline_params_k.is_leader = warp_group_thread_idx == 0;
            pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
        } else {
            pipeline_params_k.consumer_arv_count = !LargeHeadDimV ?
                NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
            pipeline_params_k.producer_arv_count = NumProducerThreads;
        }

        static_assert(is_same_v);
        PipelineParamsVt pipeline_params_vt = pipeline_params_k;
        if constexpr (Use_TMA_KV && !SameHeadDim) {
            pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
            if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }
        } else {
            if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }
        }

        MainloopPipelineK pipeline_k = [&] {
            if constexpr (Use_TMA_KV) {
                return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
            } else {
                return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
            }
        }();
        // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
        MainloopPipelineV pipeline_v = [&] {
            if constexpr (!Transpose_V) {
                static_assert(is_same_v);
                if constexpr (Use_TMA_KV) {
                    return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});
                } else {
                    return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);
                }
            } else {
                PipelineParamsV pipeline_params_v;
                pipeline_params_v.role = warp_group_idx == 0
                    ? MainloopPipelineV::ThreadCategory::Producer
                    : MainloopPipelineV::ThreadCategory::Consumer;
                pipeline_params_v.producer_arv_count = NumProducerThreads;
                pipeline_params_v.consumer_arv_count = NumMmaThreads;
                return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
            }
        }();
        // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
        // the producer WG will read from pipeline_vt and write to pipeline_v.
        // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
        // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
        // However, the thread role isn't used in the pipeline implementation.
MainloopPipelineVt pipeline_vt = [&] { if constexpr (Use_TMA_KV) { pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{}); } else { pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt); } }(); PipelineParamsKVNew pipeline_params_kv_new; pipeline_params_kv_new.role = warp_group_idx == 0 ? MainloopPipelineKVNew::ThreadCategory::Producer : MainloopPipelineKVNew::ThreadCategory::Consumer; pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0; pipeline_params_kv_new.num_consumers = NumMmaThreads; auto pipeline_k_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr); if constexpr (!SameHeadDim) { pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; } auto pipeline_v_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr); CollectiveMainloop mainloop; CollectiveEpilogue epilogue; // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { cute::cluster_arrive_relaxed(); cute::cluster_wait(); } else { __syncthreads(); } TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); if (warp_group_idx == 0) { // Producer cutlass::arch::warpgroup_reg_dealloc(); // The pipelines for AppendKV and main attention are different, since e.g. main attention // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load // KV_new. 
Since the pipeline states are different, we have to manually sync to make // sure the two pipelines don't race when accessing smem_k and smem_v. PipelineState smem_pipe_write = cutlass::make_producer_start_state(); PipelineState smem_pipe_write_new = cutlass::make_producer_start_state(); int work_idx = 0; int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; if constexpr (SingleProducerWarp) { if (warp_idx_in_warpgroup != 0) { return; } } if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); } cutlass::arch::wait_on_dependent_grids(); // Load Q, K, V for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work(params.scheduler) : scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work(params.scheduler, work_tile_info) : scheduler.template get_next_work(params.scheduler, work_tile_info)) { auto block_coord = work_tile_info.get_block_coord(params.scheduler); SeqlenInfo_t seqlen_info{ get<2>(block_coord) /*bidb*/, get<0>(params.mainloop.shape_Q), !params.mainloop.ptr_pagetable ? 
size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { bool tile_new_valid = mainloop.load_kv_new( params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx); if (tile_new_valid) { // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); } cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::AppendKV) /*id*/); // if (threadIdx.x == 0) { printf("Producer: After sync\n"); } } } auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() { scheduler.prefetch_next_work(params.scheduler, work_tile_info); }; // pipeline_vt won't be used if we don't need to transpose V. mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx); } mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx); } else { // Consumer cutlass::arch::warpgroup_reg_alloc(); // Initialize matmul objects. TiledMmaPV tiled_mma_pv; PipelineState smem_pipe_read; PipelineState smem_pipe_read_new; // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
scheduler.init_consumer(); mainloop.mma_init(); int work_idx = 0; CUTLASS_PRAGMA_NO_UNROLL for (auto work_tile_info = scheduler.template get_initial_work(params.scheduler); work_tile_info.is_valid(params.scheduler); // get_next_work will be called before the epilogue ) { auto block_coord = work_tile_info.get_block_coord(params.scheduler); int const bidb = get<2>(block_coord); SeqlenInfo_t seqlen_info{ bidb, get<0>(params.mainloop.shape_Q), !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable), get<0>(params.mainloop.shape_K_new), params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new, params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k, params.mainloop.seqlens_rotary }; if constexpr (AppendKV) { bool tile_new_valid = mainloop.store_kv_new( params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new, threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord); if (tile_new_valid) { // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); } // We need this sync so that the gmem write from the consumers is visible to the producer // that might do a TMA read after that. asm volatile ("fence.proxy.async.global;"); cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::AppendKV) /*id*/); // arrive is enough, we don't need sync. The producer will sync, which means // after that sync we're guaranteed that the AppendKV pipeline has finished // loading and consumed smem_k and smem_v. // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); } } } // If there's tanh softcap, the scaling will be done before tanh. float softmax_scale_log2 = params.mainloop.softmax_scale_log2; if constexpr (Is_FP8 && !Has_softcap) { int const bidh = get<1>(block_coord); int const bidh_kv = !PackGQA ?
params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh; float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)]; float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; softmax_scale_log2 *= q_descale * k_descale; } flash::Softmax softmax(softmax_scale_log2); // Attention output (GEMM-II) accumulator. Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); bool tile_valid; if constexpr (!LargeHeadDimV) { tile_valid = mainloop.mma( params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); } else { // mma_pv might not compile if !LargeHeadDimV if (warp_group_idx == 1) { tile_valid = mainloop.mma( params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); } else { tile_valid = mainloop.mma_pv( params.mainloop, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); } } // Do this here before the epilogue so that the next tile is ready to go. 
work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info); if constexpr (Split && Varlen) { if (!work_tile_info.is_valid(params.scheduler)) { // Last tile cutlass::arch::launch_dependent_grids(); } } if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, threadIdx.x - MmaThreadOffset, block_coord); } else { // Write 0 to gO and -inf to gLSE. epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord); } } epilogue.store_tail(); } } }; } // namespace flash ================================================ FILE: hopper/flash_fwd_launch_template.h ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
******************************************************************************/ #pragma once #include "cute/tensor.hpp" #include "cutlass/cutlass.h" #include "cutlass/device_kernel.h" // For device_kernel #include #include "cutlass/cluster_launch.hpp" #include "cutlass/kernel_launch.h" #include "cuda_check.h" #include "static_switch.h" #include "flash.h" #include "tile_size.h" #include "tile_scheduler.hpp" #include "flash_fwd_kernel_sm90.h" #include "flash_fwd_kernel_sm80.h" #include "mainloop_fwd_sm90_tma_gmma_ws.hpp" #include "mainloop_fwd_sm80.hpp" #include "epilogue_fwd.hpp" using namespace cute; template void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time"); static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen"); static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor; using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; // Can't use structured binding since it's not compatible with constexpr static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap); static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ?
std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); using TileShape_MNK = cute::Shape, Int, Int>; using TileShape_MNK_PV = cute::Shape, Int, Int>; using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, flash::CollectiveMainloopFwdSm90, flash::CollectiveMainloopFwdSm80 >; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; static constexpr bool LPT = Is_causal || Is_local; static constexpr bool Sort = !Is_local; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/, LPT, Sort, true /*Prepared*/>, std::conditional_t, flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> > >; using SchedulerSingleTile = flash::SingleTileScheduler; // If Split then we probably don't have enough work for PersistentScheduler to be useful. // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better // since we'll avoid launching a bunch of thread blocks that immediately exit. // On Sm80, noncausal persistent seems a bit slower. static constexpr bool UsePersistentScheduler = Arch >= 90 ? 
!(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split)); using Scheduler = std::conditional_t; using AttnKernel = std::conditional_t< Arch >= 90, flash::enable_sm90>, flash::enable_sm80_to_sm89> >; bool const is_varlen_q = params.cu_seqlens_q; bool const is_varlen_k = params.cu_seqlens_k; bool const is_varlen_k_new = params.cu_seqlens_knew; int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q; int batch_q = !is_varlen_q ? params.b : 1; int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1; typename CollectiveMainloop::StrideV v_strides = cute::conditional_return( make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0), make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0)); typename CollectiveMainloop::Arguments mainloop_args { static_cast(params.q_ptr), {seqlen_q, params.d, params.h, batch_q}, // shape_Q {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q static_cast(params.k_ptr), {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size, params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K static_cast(params.v_ptr), params.dv, // headdim_v v_strides, // stride_V static_cast(params.knew_ptr), {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new static_cast(params.vnew_ptr), {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new static_cast(params.qv_ptr), {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? 
params.qv_batch_stride : 0}, // stride_Qv static_cast(params.rotary_cos_ptr), {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter {params.rotary_dim / 2, _1{}}, // stride_rotary_cos static_cast(params.rotary_sin_ptr), {params.rotary_dim / 2, _1{}}, // stride_rotary_sin params.is_rotary_interleaved, params.page_table, // if page_size is not set, avoid dividing by zero {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table {params.page_table_batch_stride, _1{}}, // stride_page_table params.scale_softmax, params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr, {params.q_descale_batch_stride, params.q_descale_head_stride}, {params.k_descale_batch_stride, params.k_descale_head_stride}, {params.v_descale_batch_stride, params.v_descale_head_stride}, params.window_size_left, params.window_size_right, params.attention_chunk, params.softcap, params.num_splits, params.kv_batch_idx, params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, params.seqused_q, params.seqused_k, params.leftpad_k, params.seqlens_rotary }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(params.o_ptr), {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O static_cast(params.oaccum_ptr), {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial static_cast(params.softmax_lse_ptr), {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE static_cast(params.softmax_lseaccum_ptr), {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial params.h_k, params.cu_seqlens_q, params.seqused_q }; int qhead_per_khead = !PackGQA ? 
1 : cutlass::ceil_div(params.h, params.h_k); int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{})); num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{})); typename flash::TileSchedulerArguments scheduler_args { num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits, params.h / params.h_k, params.seqlen_q, params.seqlen_k, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.num_m_blocks_ptr, params.varlen_batch_idx_ptr, params.num_nheads_in_l2_ptr }; if (Varlen && !params.skip_scheduler_metadata_computation) { prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 && params.prepare_varlen_pdl /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } int device; CHECK_CUDA(cudaGetDevice(&device)); typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({ mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args }); dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params); dim3 block_dims = AttnKernel::get_block_shape(); int smem_size = AttnKernel::SharedStorageSize; // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q)); // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)); // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)); // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v); // Get the ptr to kernel function. 
if constexpr (size(ClusterShape{}) > 1) { void const* kernel = (void const*) cutlass::device_kernel; if (smem_size >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; CHECK_CUTLASS(cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params)); } else { auto kernel = cutlass::device_kernel; if (smem_size >= 48 * 1024) { CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); } // kernel<<>>(kernel_params); CHECK_CUTLASS(cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, Arch >= 90 && Varlen && !params.skip_scheduler_metadata_computation && params.prepare_varlen_pdl /*launch_with_pdl*/)); } } template void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) { static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; using T_out = std::conditional_t; CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] { VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ?
(kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; run_flash_fwd(params, stream); }); }); }); }); }); }); } ================================================ FILE: hopper/flash_prepare_scheduler.cu ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #include #include "cutlass/fast_math.h" #include "cutlass/barrier.h" #include "cutlass/arch/barrier.h" #include "cutlass/arch/grid_dependency_control.h" #include "flash.h" #include "static_switch.h" namespace flash { // Sort in descending order template struct PrepareSortOp { __device__ __forceinline__ bool operator()(T const & lhs, T const & rhs) { return lhs > rhs; } }; template <> struct PrepareSortOp { __device__ __forceinline__ bool operator()(int2 const & lhs, int2 const & rhs) const { return lhs.x > rhs.x; } }; template <> struct PrepareSortOp { __device__ __forceinline__ bool operator()(int4 const & lhs, int4 const & rhs) const { return lhs.x > rhs.x; } }; template __global__ void prepare_varlen_num_blocks_kernel( int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, int const* const seqused_q, int const* const seqused_k, int const* const 
leftpad_k_ptr, int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, int* const tile_count_semaphore, int* const num_m_blocks_ptr, int* const num_splits_dynamic_ptr, int* const varlen_batch_idx_ptr, // int* const num_n_blocks_ptr, int* const num_nheads_in_l2_ptr, bool enable_pdl, bool is_causal, bool packgqa, int max_kvblocks_in_l2) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; static constexpr int kSmemSize = 1; static constexpr int BLOCK_DIM_X = NumWarps * 32; static constexpr int ITEMS_PER_THREAD = 1; static_assert(BLOCK_DIM_X * ITEMS_PER_THREAD == NumWarps * 32); using BlockMergeSort = cub::BlockMergeSort; __shared__ int total_blocks_smem[kSmemSize]; // Allocate shared memory for BlockMergeSort operations __shared__ typename BlockMergeSort::TempStorage temp_storage; if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } __syncthreads(); if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; } int lane = threadIdx.x % cutlass::NumThreadsPerWarp; auto get_num_m_blocks = [&](int batch_idx) { int seqlen; if (seqused_q) { seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; } else if (cu_seqlens_q) { int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0; int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); seqlen = next_cu_seqlen - cur_cu_seqlen; } else { seqlen = seqlen_q_static; } if(packgqa) { seqlen *= qhead_per_khead; } return batch_idx < num_batch && lane < kNumBatchPerWarp ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; }; auto get_num_n_blocks = [&](int batch_idx) { int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0; int seqlen; if (seqused_k) { seqlen = batch_idx < num_batch ? 
seqused_k[batch_idx] : 0; } else if (cu_seqlens_k) { int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0; int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); seqlen = next_cu_seqlen - cur_cu_seqlen; } else { seqlen = seqlen_k_static; } int seqlen_new; if (cu_seqlens_k_new) { int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0; int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1); seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new; } else { seqlen_new = seqlen_k_new_static; } // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); } seqlen = seqlen - leftpad_k + seqlen_new; return batch_idx < num_batch && lane < kNumBatchPerWarp ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0; }; int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; int batch_cta_idx_offset = int(blockIdx.x) * 992; int bidb_start = batch_cta_idx_offset + kNumBatchPerWarp * warp_idx; int batch_idx = lane + bidb_start; int num_m_blocks = get_num_m_blocks(batch_idx); int num_n_blocks = get_num_n_blocks(batch_idx); auto get_nheads_in_l2 = [&](int n_blocks) { int nheads_in_l2 = n_blocks * 16 <= max_kvblocks_in_l2 ? 16 : n_blocks * 8 <= max_kvblocks_in_l2 ? 8 : n_blocks * 4 <= max_kvblocks_in_l2 ? 4 : n_blocks * 2 <= max_kvblocks_in_l2 ? 
2 : 1; if(!packgqa) { nheads_in_l2 *= qhead_per_khead; } return min(nheads_in_l2, num_head); }; int num_splits_dynamic; if (int(gridDim.x) > 1 || num_splits_static == 1) { // set num splits for all batches to 1 (note that user expects num_splits_static to mean upper bound on splits) // for batch size > 992, we expect GPU occupancy to not be an issue except in degenerate cases (e.g., most are zero-length) num_splits_dynamic = 1; } else { int total_blocks = num_m_blocks * num_n_blocks; // Warp sum #pragma unroll for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); } if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } __syncthreads(); total_blocks = total_blocks_smem[0]; // 10% margin int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); // num_n_blocks per work tile for the batch num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); } if constexpr (Sort) { if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { num_n_blocks = INT_MIN; // sort last } else if (is_causal) { // sort by shortest member to process num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; } int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); // if (threadIdx.x == 0) { // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); // } __syncthreads(); // Sort batches by num_n_blocks in descending order BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); // if (threadIdx.x == 0) { 
// printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); // } __syncthreads(); if (is_causal) { // reset value to num_n_blocks batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); } // When sorting, we re-index some metadata by 'virtual batch index' // and also store the vbidx -> bidx mapping. // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx] // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx batch_idx = batch_cta_idx_offset + threadIdx.x; if (batch_idx < num_batch && threadIdx.x < 992) { // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } num_m_blocks_ptr[batch_idx] = batch_coords[0].y; num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z; varlen_batch_idx_ptr[batch_idx] = batch_coords[0].w; } } else { if (batch_idx < num_batch && lane < kNumBatchPerWarp) { // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; num_m_blocks_ptr[batch_idx] = num_m_blocks; // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); } } } } // flash void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl) { int qhead_per_khead = cutlass::ceil_div(params.h, params.h_k); int num_warps = cutlass::ceil_div(params.b, 31); // warp switch
will cap this at 32 int num_ctas = cutlass::ceil_div(params.b, 31 * 32); // int const size_l2 = 50 * 1024 * 1024; // 50 MB int const size_l2 = 8 * 1024 * 1024; // underestimate seems better in practice int const element_size = params.is_e4m3 ? 1 : 2; int const size_one_kvblock = blockN * (params.d + params.dv) * element_size; // printf("block size = %d, element size = %d, headdim = %d, headdim_v = %d, size 1 kblock = %d.\n", blockN, element_size, params.d, params.dv, size_one_kvblock); int const max_kvblocks_in_l2 = size_l2 / size_one_kvblock; BOOL_SWITCH(params.varlen_sort_batches, Sort, [&] { NUM_WARP_SWITCH(num_warps, NumWarps, [&] { flash::prepare_varlen_num_blocks_kernel<<>>( params.seqlen_q, params.seqlen_k, params.seqlen_knew, params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, params.seqused_q, params.seqused_k, params.leftpad_k, params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), params.tile_count_semaphore, params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, // params.num_n_blocks_ptr, params.num_nheads_in_l2_ptr, enable_pdl, params.is_causal, packgqa, max_kvblocks_in_l2); }); }); } ================================================ FILE: hopper/generate_kernels.py ================================================ # Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602 # This file is run to generate the kernel instantiations for the flash_attn kernels # They are written to several files in order to speed up compilation import argparse import itertools from collections import namedtuple from dataclasses import dataclass from pathlib import Path from typing import List, Optional KERNEL_BATCH = namedtuple("Kernel", ["template", "filename"]) DTYPE_MAP = { "fp16": "cutlass::half_t", "bf16": "cutlass::bfloat16_t", "e4m3": "cutlass::float_e4m3_t", } DTYPE_MAP_FWD_SM8x = { "fp16": 
"cutlass::half_t", "bf16": "cutlass::bfloat16_t", } DTYPE_MAP_BWD = { "fp16": "cutlass::half_t", "bf16": "cutlass::bfloat16_t", } SM = [80, 90] # Sm kernels support up to HEAD_DIMENSIONS = [64, 96, 128, 192, 256] PAGEDKV = [False, True] SPLIT = [False, True] SOFTCAP = [False, True] PACKGQA = [False, True] KERNEL_IMPL_TEMPLATE_FWD_SM90 = """#include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream); #endif """ KERNEL_IMPL_TEMPLATE_FWD_SM8x = """#include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif """ KERNEL_IMPL_TEMPLATE_BWD_SM90 = """#include "flash_bwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} template<> void run_mha_bwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params &params, cudaStream_t stream) {{ run_mha_bwd_hdim{HEAD_DIM}<{ARCH}, {DTYPE}, {SOFTCAP}>(params, stream); }} #endif """ KERNEL_IMPL_TEMPLATE_BWD_SM8x = """#include "flash_bwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} template<> void run_mha_bwd_<80, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params &params, cudaStream_t stream) {{ run_mha_bwd_hdim{HEAD_DIM}<80, {DTYPE}, {SOFTCAP}>(params, stream); }} template<> void run_mha_bwd_<86, {DTYPE}, {HEAD_DIM}, {SOFTCAP}>(Flash_bwd_params &params, cudaStream_t stream) {{ run_mha_bwd_hdim{HEAD_DIM}<86, {DTYPE}, {SOFTCAP}>(params, stream); }} #endif #endif """ @dataclass class Kernel: sm: int dtype: str head_dim: int
    head_dim_v: int
    split: bool
    paged_kv: bool
    softcap: bool
    packgqa: bool
    direction: str

    @property
    def template(self) -> str:
        if self.direction == "fwd":
            if self.sm == 90:
                # Always enable PackGQA for PagedKV or Split to reduce compilation
                packgqa = self.packgqa or self.paged_kv or self.split
                return KERNEL_IMPL_TEMPLATE_FWD_SM90.format(
                    ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype],
                    HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v,
                    SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(),
                    SOFTCAP=str(self.softcap).lower(), PACKGQA=str(packgqa).lower()
                )
            else:
                # Always enable PackGQA for Sm8x to reduce compilation
                return KERNEL_IMPL_TEMPLATE_FWD_SM8x.format(
                    DTYPE=DTYPE_MAP[self.dtype],
                    HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v,
                    SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(),
                    SOFTCAP=str(self.softcap).lower(), PACKGQA=str(True).lower()
                )
        elif self.direction == "bwd":
            if self.sm == 90:
                return KERNEL_IMPL_TEMPLATE_BWD_SM90.format(
                    ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,
                    SOFTCAP=str(self.softcap).lower()
                )
            else:
                return KERNEL_IMPL_TEMPLATE_BWD_SM8x.format(
                    DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,
                    SOFTCAP=str(self.softcap).lower()
                )

    @property
    def filename(self) -> str:
        return f"flash_{self.direction}_hdim{self.head_dim}{f'_{self.head_dim_v}' if self.head_dim_v != self.head_dim else ''}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu"


def get_all_kernels() -> List[Kernel]:
    for dtype, head_dim, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):
        # We always enable PackGQA for Sm8x or PagedKV or Split
        # so we should just pass in packgqa=False to avoid the `_packgqa` in the filename.
        if packgqa and (sm < 90 or (sm >= 90 and (paged_kv or split))):
            continue
        if sm >= 90 or dtype in DTYPE_MAP_FWD_SM8x:
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
        if sm == 90 and head_dim == 192:
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
        if sm == 90 and head_dim == 64 and dtype in ["bf16", "fp16"]:
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=256, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")
    for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM):
        yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd")


def batch_hdim(kernels_all) -> List[KERNEL_BATCH]:
    for dtype, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):
        if sm < 90:
            continue
        # Same hdim and hdimv
        kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim == k.head_dim_v]
        if len(kernels) > 0:
            filename = f"flash_fwd_hdimall_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu"
            template = "\n".join([f"#include \"{k.filename}\"" for k in kernels])
            yield KERNEL_BATCH(template, filename)
        # Different hdim and hdimv
        kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim != k.head_dim_v]
        if len(kernels) > 0:
            filename = f"flash_fwd_hdimdiff_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu"
            template = "\n".join([f"#include \"{k.filename}\"" for k in kernels])
            yield KERNEL_BATCH(template, filename)


def batch_softcap(kernels_all) -> List[KERNEL_BATCH]:
    for dtype, head_dim, split, paged_kv, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, PACKGQA, SM):
        if sm >= 90:
            continue
        kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.head_dim == head_dim and k.split == split and k.paged_kv == paged_kv and k.packgqa == packgqa and k.sm == sm]
        if len(kernels) > 0:
            filename = f"flash_fwd_hdim{head_dim}_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}_softcapall{'_packgqa' if packgqa else ''}_sm{sm}.cu"
            template = "\n".join([f"#include \"{k.filename}\"" for k in kernels])
            yield KERNEL_BATCH(template, filename)
    # Bwd
    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
        if sm < 90:
            continue
        kernels = [k for k in kernels_all if k.direction == "bwd" and k.dtype == dtype and k.head_dim == head_dim and k.sm == sm]
        if len(kernels) > 0:
            filename = f"flash_bwd_hdim{head_dim}_{dtype}_softcapall_sm{sm}.cu"
            template = "\n".join([f"#include \"{k.filename}\"" for k in kernels])
            yield KERNEL_BATCH(template, filename)


def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
    prelude = """// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"\n
"""
    (autogen_dir / kernel.filename).write_text(prelude + kernel.template)


def main(output_dir: Optional[str]) -> None:
    output_dir = Path(output_dir) if output_dir is not None else Path(__file__).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    kernels_all = list(get_all_kernels())
    for kernel in kernels_all:
        write_kernel(kernel, output_dir)
    for kernel in batch_hdim(kernels_all):
        write_kernel(kernel, output_dir)
    for kernel in batch_softcap(kernels_all):
        write_kernel(kernel, output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate_kernels",
        description="Generate the flash_attention kernels template instantiations",
    )
    # Set an optional output directory
    parser.add_argument(
        "-o",
        "--output_dir",
        default="instantiations",
        required=False,
        help="Where to generate the kernels "
        " will default to the current directory ",
    )
    args = parser.parse_args()
    main(args.output_dir)

================================================
FILE: hopper/heuristics.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <vector>

inline bool
should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
    // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
    if (varlen_q) return true;
    // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM
    auto round_up = [](int a, int b) { return (a + b - 1) / b * b; };
    float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));
    float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM));
    return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
};

// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    // However, in the case of super long seqlen where each head of KV doesn't even fit into
    // L2 (we assume that L2 size is 50MB), we want to split.
    if (total_mblocks >= 0.8f * num_SMs) {
        int const size_l2 = 50 * 1024 * 1024;
        // Only split if there are enough queries to go over the KV at least twice
        // Don't split if causal
        if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) {
            return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits);
        } else {
            return 1;
        }
    }
    // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
    if (num_n_blocks <= 4) {
        return 1;
    }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        float n_waves = float(total_mblocks * num_splits) / num_SMs;
        float eff = n_waves / ceil(n_waves);
        // printf("num_splits = %d, eff = %f\n", num_splits, eff);
        if (eff > max_efficiency) { max_efficiency = eff; }
        efficiency.push_back(eff);
    }
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}

================================================
FILE: hopper/instantiations/flash_bwd_hdim128_bf16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim128_bf16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim128_bf16_sm90.cu"
#include "flash_bwd_hdim128_bf16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_bwd_hdim128_fp16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim128_fp16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template<>
void run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim128_fp16_sm90.cu"
#include "flash_bwd_hdim128_fp16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_bwd_hdim192_bf16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim192_bf16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim192_bf16_sm90.cu"
#include "flash_bwd_hdim192_bf16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_bwd_hdim192_fp16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim192_fp16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template<>
void run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim192_fp16_sm90.cu"
#include "flash_bwd_hdim192_fp16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_bwd_hdim256_bf16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim256_bf16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim256_bf16_sm90.cu"
#include "flash_bwd_hdim256_bf16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_bwd_hdim256_fp16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim256_fp16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<90, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<90, cutlass::half_t, false>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<80, cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<80, cutlass::half_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<86, cutlass::half_t, true>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template<>
void run_mha_bwd_<90, cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<90, cutlass::half_t, true>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim256_fp16_softcapall_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim256_fp16_sm90.cu"
#include "flash_bwd_hdim256_fp16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_bwd_hdim64_bf16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<80, cutlass::bfloat16_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<86, cutlass::bfloat16_t, false>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim64_bf16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<90, cutlass::bfloat16_t, false>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<80, cutlass::bfloat16_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<86, cutlass::bfloat16_t, true>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<90, cutlass::bfloat16_t, true>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim64_bf16_softcapall_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim64_bf16_sm90.cu"
#include "flash_bwd_hdim64_bf16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_bwd_hdim64_fp16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template<>
void run_mha_bwd_<80, cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<80, cutlass::half_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<86, cutlass::half_t, false>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim64_fp16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template<>
void run_mha_bwd_<90, cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<90, cutlass::half_t, false>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template<>
void run_mha_bwd_<80, cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<80, cutlass::half_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<86, cutlass::half_t, true>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template<>
void run_mha_bwd_<90, cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim64<90, cutlass::half_t, true>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim64_fp16_softcapall_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim64_fp16_sm90.cu"
#include "flash_bwd_hdim64_fp16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_bwd_hdim96_bf16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<80, cutlass::bfloat16_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<86, cutlass::bfloat16_t, false>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim96_bf16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<90, cutlass::bfloat16_t, false>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template<>
void run_mha_bwd_<80, cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<80, cutlass::bfloat16_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<86, cutlass::bfloat16_t, true>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template<>
void run_mha_bwd_<90, cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<90, cutlass::bfloat16_t, true>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim96_bf16_softcapall_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim96_bf16_sm90.cu"
#include "flash_bwd_hdim96_bf16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_bwd_hdim96_fp16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template<>
void run_mha_bwd_<80, cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<80, cutlass::half_t, false>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<86, cutlass::half_t, false>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim96_fp16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template<>
void run_mha_bwd_<90, cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<90, cutlass::half_t, false>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template<>
void run_mha_bwd_<80, cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<80, cutlass::half_t, true>(params, stream);
}
template<>
void run_mha_bwd_<86, cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<86, cutlass::half_t, true>(params, stream);
}
#endif
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template<>
void run_mha_bwd_<90, cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim96<90, cutlass::half_t, true>(params, stream);
}
#endif

================================================
FILE: hopper/instantiations/flash_bwd_hdim96_fp16_softcapall_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_bwd_hdim96_fp16_sm90.cu"
#include "flash_bwd_hdim96_fp16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim128_bf16_paged_sm80.cu"
#include "flash_fwd_hdim128_bf16_paged_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim128_bf16_paged_split_sm80.cu"
#include "flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_sm100.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<100, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim128_bf16_sm80.cu"
#include "flash_fwd_hdim128_bf16_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_bf16_split_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim128_bf16_split_sm80.cu"
#include "flash_fwd_hdim128_bf16_split_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM128
template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim128_fp16_paged_sm80.cu" #include "flash_fwd_hdim128_fp16_paged_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim128_fp16_paged_split_sm80.cu" #include "flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim128_fp16_sm80.cu" #include "flash_fwd_hdim128_fp16_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim128_fp16_split_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim128_fp16_split_sm80.cu" #include "flash_fwd_hdim128_fp16_split_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_bf16_paged_sm80.cu" #include "flash_fwd_hdim192_bf16_paged_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_bf16_paged_split_sm80.cu" #include "flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. 
// This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_bf16_sm80.cu" #include "flash_fwd_hdim192_bf16_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_bf16_split_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_bf16_split_sm80.cu" #include "flash_fwd_hdim192_bf16_split_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_fp16_paged_sm80.cu" #include "flash_fwd_hdim192_fp16_paged_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim192_fp16_paged_split_sm80.cu"
#include "flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim192_fp16_sm80.cu"
#include "flash_fwd_hdim192_fp16_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM192
template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim192_fp16_split_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim192_fp16_split_sm80.cu"
#include "flash_fwd_hdim192_fp16_split_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim256_bf16_paged_sm80.cu"
#include "flash_fwd_hdim256_bf16_paged_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim256_bf16_paged_split_sm80.cu"
#include "flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim256_bf16_sm80.cu"
#include "flash_fwd_hdim256_bf16_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM256
template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim256_bf16_split_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim256_bf16_split_sm80.cu" #include "flash_fwd_hdim256_bf16_split_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim256_fp16_paged_sm80.cu" #include "flash_fwd_hdim256_fp16_paged_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim256_fp16_paged_split_sm80.cu" #include "flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim256_fp16_sm80.cu" #include "flash_fwd_hdim256_fp16_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim256_fp16_split_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim256_fp16_split_sm80.cu" #include "flash_fwd_hdim256_fp16_split_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 256, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 256, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_paged_sm80.cu"
#include "flash_fwd_hdim64_bf16_paged_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_paged_split_sm80.cu"
#include "flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_sm80.cu"
#include "flash_fwd_hdim64_bf16_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_bf16_split_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_split_sm80.cu"
#include "flash_fwd_hdim64_bf16_split_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_paged_sm80.cu"
#include "flash_fwd_hdim64_fp16_paged_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_paged_split_sm80.cu"
#include "flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM64
template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_fp16_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim64_fp16_sm80.cu" #include "flash_fwd_hdim64_fp16_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim64_fp16_split_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim64_fp16_split_sm80.cu" #include "flash_fwd_hdim64_fp16_split_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim96_bf16_paged_sm80.cu" #include "flash_fwd_hdim96_bf16_paged_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim96_bf16_paged_split_sm80.cu" #include "flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim96_bf16_sm80.cu" #include "flash_fwd_hdim96_bf16_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_bf16_split_softcapall_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim96_bf16_split_sm80.cu" #include "flash_fwd_hdim96_bf16_split_softcap_sm80.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif ================================================ FILE: hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim96_fp16_paged_sm80.cu"
#include "flash_fwd_hdim96_fp16_paged_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim96_fp16_paged_split_sm80.cu"
#include "flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, false>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim96_fp16_sm80.cu"
#include "flash_fwd_hdim96_fp16_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_SM8x
#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_launch_template.h"

#ifndef FLASHATTENTION_DISABLE_HDIM96
template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params &params, cudaStream_t stream);
#endif

================================================
FILE: hopper/instantiations/flash_fwd_hdim96_fp16_split_softcapall_sm80.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim96_fp16_split_sm80.cu"
#include "flash_fwd_hdim96_fp16_split_softcap_sm80.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_packgqa_sm90.cu"
#include "flash_fwd_hdim96_bf16_packgqa_sm90.cu"
#include "flash_fwd_hdim128_bf16_packgqa_sm90.cu"
#include "flash_fwd_hdim192_bf16_packgqa_sm90.cu"
#include "flash_fwd_hdim256_bf16_packgqa_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_paged_sm90.cu"
#include "flash_fwd_hdim96_bf16_paged_sm90.cu"
#include "flash_fwd_hdim128_bf16_paged_sm90.cu"
#include "flash_fwd_hdim192_bf16_paged_sm90.cu"
#include "flash_fwd_hdim256_bf16_paged_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim96_bf16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim128_bf16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim192_bf16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim256_bf16_paged_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_paged_split_sm90.cu"
#include "flash_fwd_hdim96_bf16_paged_split_sm90.cu"
#include "flash_fwd_hdim128_bf16_paged_split_sm90.cu"
#include "flash_fwd_hdim192_bf16_paged_split_sm90.cu"
#include "flash_fwd_hdim256_bf16_paged_split_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_sm90.cu"
#include "flash_fwd_hdim96_bf16_sm90.cu"
#include "flash_fwd_hdim128_bf16_sm90.cu"
#include "flash_fwd_hdim192_bf16_sm90.cu"
#include "flash_fwd_hdim256_bf16_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_softcap_sm90.cu"
#include "flash_fwd_hdim96_bf16_softcap_sm90.cu"
#include "flash_fwd_hdim128_bf16_softcap_sm90.cu"
#include "flash_fwd_hdim192_bf16_softcap_sm90.cu"
#include "flash_fwd_hdim256_bf16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_split_sm90.cu"
#include "flash_fwd_hdim96_bf16_split_sm90.cu"
#include "flash_fwd_hdim128_bf16_split_sm90.cu"
#include "flash_fwd_hdim192_bf16_split_sm90.cu"
#include "flash_fwd_hdim256_bf16_split_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_bf16_split_softcap_sm90.cu"
#include "flash_fwd_hdim96_bf16_split_softcap_sm90.cu"
#include "flash_fwd_hdim128_bf16_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_bf16_split_softcap_sm90.cu"
#include "flash_fwd_hdim256_bf16_split_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_e4m3_packgqa_sm90.cu"
#include "flash_fwd_hdim96_e4m3_packgqa_sm90.cu"
#include "flash_fwd_hdim128_e4m3_packgqa_sm90.cu"
#include "flash_fwd_hdim192_e4m3_packgqa_sm90.cu"
#include "flash_fwd_hdim256_e4m3_packgqa_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_e4m3_paged_sm90.cu"
#include "flash_fwd_hdim96_e4m3_paged_sm90.cu"
#include "flash_fwd_hdim128_e4m3_paged_sm90.cu"
#include "flash_fwd_hdim192_e4m3_paged_sm90.cu"
#include "flash_fwd_hdim256_e4m3_paged_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu"
#include "flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu"
#include "flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu"
#include "flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu"
#include "flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_e4m3_paged_split_sm90.cu"
#include "flash_fwd_hdim96_e4m3_paged_split_sm90.cu"
#include "flash_fwd_hdim128_e4m3_paged_split_sm90.cu"
#include "flash_fwd_hdim192_e4m3_paged_split_sm90.cu"
#include "flash_fwd_hdim256_e4m3_paged_split_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_e4m3_sm90.cu"
#include "flash_fwd_hdim96_e4m3_sm90.cu"
#include "flash_fwd_hdim128_e4m3_sm90.cu"
#include "flash_fwd_hdim192_e4m3_sm90.cu"
#include "flash_fwd_hdim256_e4m3_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_e4m3_softcap_sm90.cu"
#include "flash_fwd_hdim96_e4m3_softcap_sm90.cu"
#include "flash_fwd_hdim128_e4m3_softcap_sm90.cu"
#include "flash_fwd_hdim192_e4m3_softcap_sm90.cu"
#include "flash_fwd_hdim256_e4m3_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_e4m3_split_sm90.cu"
#include "flash_fwd_hdim96_e4m3_split_sm90.cu"
#include "flash_fwd_hdim128_e4m3_split_sm90.cu"
#include "flash_fwd_hdim192_e4m3_split_sm90.cu"
#include "flash_fwd_hdim256_e4m3_split_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_e4m3_split_softcap_sm90.cu"
#include "flash_fwd_hdim96_e4m3_split_softcap_sm90.cu"
#include "flash_fwd_hdim128_e4m3_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_e4m3_split_softcap_sm90.cu"
#include "flash_fwd_hdim256_e4m3_split_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_packgqa_sm90.cu"
#include "flash_fwd_hdim96_fp16_packgqa_sm90.cu"
#include "flash_fwd_hdim128_fp16_packgqa_sm90.cu"
#include "flash_fwd_hdim192_fp16_packgqa_sm90.cu"
#include "flash_fwd_hdim256_fp16_packgqa_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_paged_sm90.cu"
#include "flash_fwd_hdim96_fp16_paged_sm90.cu"
#include "flash_fwd_hdim128_fp16_paged_sm90.cu"
#include "flash_fwd_hdim192_fp16_paged_sm90.cu"
#include "flash_fwd_hdim256_fp16_paged_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim96_fp16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim128_fp16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim192_fp16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim256_fp16_paged_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_paged_split_sm90.cu"
#include "flash_fwd_hdim96_fp16_paged_split_sm90.cu"
#include "flash_fwd_hdim128_fp16_paged_split_sm90.cu"
#include "flash_fwd_hdim192_fp16_paged_split_sm90.cu"
#include "flash_fwd_hdim256_fp16_paged_split_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_sm90.cu"
#include "flash_fwd_hdim96_fp16_sm90.cu"
#include "flash_fwd_hdim128_fp16_sm90.cu"
#include "flash_fwd_hdim192_fp16_sm90.cu"
#include "flash_fwd_hdim256_fp16_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_softcap_sm90.cu"
#include "flash_fwd_hdim96_fp16_softcap_sm90.cu"
#include "flash_fwd_hdim128_fp16_softcap_sm90.cu"
#include "flash_fwd_hdim192_fp16_softcap_sm90.cu"
#include "flash_fwd_hdim256_fp16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_split_sm90.cu"
#include "flash_fwd_hdim96_fp16_split_sm90.cu"
#include "flash_fwd_hdim128_fp16_split_sm90.cu"
#include "flash_fwd_hdim192_fp16_split_sm90.cu"
#include "flash_fwd_hdim256_fp16_split_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_fp16_split_softcap_sm90.cu"
#include "flash_fwd_hdim96_fp16_split_softcap_sm90.cu"
#include "flash_fwd_hdim128_fp16_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_fp16_split_softcap_sm90.cu"
#include "flash_fwd_hdim256_fp16_split_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_256_bf16_packgqa_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu"
#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_256_bf16_paged_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_paged_sm90.cu"
#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu"
#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_256_bf16_paged_split_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu"
#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_256_bf16_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_sm90.cu"
#include "flash_fwd_hdim192_128_bf16_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu"
#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_256_bf16_softcap_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu"
#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_256_bf16_split_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_split_sm90.cu"
#include "flash_fwd_hdim192_128_bf16_split_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include "flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu"
#include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu"
#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu"

================================================
FILE: hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu
================================================
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
// Splitting the different template instantiations to different files to speed up compilation.
// This file is auto-generated.
See "generate_kernels.py" #include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_128_e4m3_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim64_256_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim64_256_fp16_paged_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
// Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim64_256_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_hdim64_256_fp16_sm90.cu" #include "flash_fwd_hdim64_512_fp16_sm90.cu" #include "flash_fwd_hdim192_128_fp16_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim64_256_fp16_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. 
See "generate_kernels.py" #include "flash_fwd_hdim64_256_fp16_split_sm90.cu" #include "flash_fwd_hdim64_512_fp16_split_sm90.cu" #include "flash_fwd_hdim192_128_fp16_split_sm90.cu" ================================================ FILE: hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu ================================================ // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. // Splitting the different template instantiations to different files to speed up compilation. // This file is auto-generated. See "generate_kernels.py" #include "flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" ================================================ FILE: hopper/mainloop_bwd_sm80.hpp ================================================ /****************************************************************************** * Copyright (c) 2024, Tri Dao. 
******************************************************************************/ #pragma once #include #include #include #include #include "cute/tensor.hpp" #include "seqlen.h" #include "mask.h" #include "mask.h" #include "softmax.h" #include "utils.h" namespace flash { using namespace cute; template struct CollectiveMainloopBwdSm80 { static constexpr int kStages = Stages; static constexpr int kStages_dO = Stages_dO; static_assert(kStages >= kStages_dO); using TileShape_MNK = TileShape_MNK_; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; static constexpr bool Is_causal = Is_causal_; static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = Has_softcap_; static constexpr bool Varlen = Varlen_; static constexpr int NumMmaWarps = NumMmaWarpGroups * cutlass::NumWarpsPerWarpGroup; static constexpr bool SdP_swapAB = SdP_swapAB_; static constexpr bool dKV_swapAB = dKV_swapAB_; static constexpr bool dQ_swapAB = dQ_swapAB_; static constexpr bool Q_dO_same_stages = kStages == kStages_dO; static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); using SeqlenInfo_t = flash::SeqlenInfoQK; using BlockMN_t = flash::BlockMN; static_assert(ArchTag::kMinComputeCapability >= 80); static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; static constexpr int NumMmaThreads = NumMmaWarps * cutlass::NumThreadsPerWarp; static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler using MMA_Atom_Arch = std::conditional_t< ArchTag::kMinComputeCapability >= 80, std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >, MMA_Atom >; static_assert(NumMmaWarps % AtomLayoutMSdP == 0); static_assert(NumMmaWarps % AtomLayoutNdKV == 0); static_assert(NumMmaWarps % AtomLayoutMdQ == 0); static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarps && 
SdP_swapAB && !dKV_swapAB; static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarps && AtomLayoutMdQ == NumMmaWarps && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS using AtomLayoutSdP = std::conditional_t< !SdP_swapAB, Layout, Int, _1>>, Layout, Int, _1>> >; static constexpr bool MmaSdPEvenN = ((!SdP_swapAB ? kBlockN : kBlockM) / size<1>(AtomLayoutSdP{})) % 16 == 0; using TiledMmaSdP = TiledMMA< MMA_Atom_Arch, AtomLayoutSdP, Tile(AtomLayoutSdP{}))>, Int<(MmaSdPEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutSdP{}))>, _16>>; using AtomLayoutdKV = std::conditional_t< !dKV_swapAB, Layout, Int, _1>>, Layout, Int, _1>> >; static constexpr bool MmadKVEvenN = ((!dKV_swapAB ? kHeadDim : kBlockN) / size<1>(AtomLayoutdKV{})) % 16 == 0; using TiledMmadKV = TiledMMA< MMA_Atom_Arch, AtomLayoutdKV, Tile(AtomLayoutdKV{}))>, Int<(MmadKVEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdKV{}))>, _16>>; using AtomLayoutdQ = std::conditional_t< !dQ_swapAB, Layout, Int, _1>>, Layout, Int, _1>> >; static constexpr bool MmadQEvenN = ((!dQ_swapAB ? kHeadDim : kBlockM) / size<1>(AtomLayoutdQ{})) % 16 == 0; using TiledMmadQ = TiledMMA< MMA_Atom_Arch, AtomLayoutdQ, Tile(AtomLayoutdQ{}))>, Int<(MmadQEvenN ? 16 : 8) * CUTE_STATIC_V(size<1>(AtomLayoutdQ{}))>, _16>>; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction. static constexpr int kBytePerRow = kHeadDim * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ?
2 : 1)); static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory. // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma. // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension // changes the layout. using SmemLayoutAtomQdO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQdO{}, make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutdO = decltype(tile_to_shape(SmemLayoutAtomQdO{}, make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutAtomKV = decltype( composition(Swizzle{}, // TODO: FA2 has a slightly different layout, does it matter? Layout>, Stride, _1>>{})); using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{}))); using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomKV{}, select<1, 2>(TileShape_MNK{}))); // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. static constexpr int kPBlockN = kBlockN % 64 == 0 ? 64 : (kBlockN % 32 == 0 ? 32 : 16); static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); static constexpr int kSwizzlePdS = 3; using SmemLayoutAtomPdS = decltype( composition(Swizzle{}, Layout, Int>, Stride, _1>>{})); using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds, // it's still a valid smem address. 
using SmemLayoutLSE = cute::Layout, Int>, cute::Stride<_1, Int>>; using SmemLayoutLSEMma = std::conditional_t< SdP_swapAB, cute::Layout, Int, Int>, cute::Stride<_0, _1, Int>>, cute::Layout, Int, Int>, cute::Stride<_1, _0, Int>> >; // Note this is the transpose in terms of the view, not in terms of memory. using SmemLayoutQt = decltype(cute::composition(SmemLayoutQ{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), make_stride(Int{}, _1{}, Int{})))); using SmemLayoutdOt = decltype(cute::composition(SmemLayoutdO{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), make_stride(Int{}, _1{}, Int{})))); using SmemLayoutKt = decltype(cute::composition(SmemLayoutK{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), make_stride(Int{}, _1{})))); using SmemLayoutPdSt = decltype(cute::composition(SmemLayoutPdS{}, make_layout(make_shape(Int{}, Int{}), make_stride(Int{}, _1{})))); // Thread layout, 256 or 384 threads per row using R2SLayoutAtomdQaccum = Layout>>; using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{}, Layout>{})); // Val layout, 1 vals per store using SmemCopyAtom = Copy_Atom; using SmemCopyAtomTransposed = Copy_Atom; // For the case where the N dimension of MmaSdP is divisible by 8 but not by 16 using SmemCopyAtomHalf = Copy_Atom; // For the case where the N dimension of MmadQ is divisible by 8 but not by 16 using SmemCopyAtomTransposedHalf = Copy_Atom; // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt. // If PdS_major is MN, then we need to "transpose" the write. // TODO: check this write using R2SCopyAtomPdS = Copy_Atom, Element>; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading // from the same address by the same threadblock. This is slightly faster. 
using GmemCopyStruct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL_ZFILL, AutoVectorizingCopyWithAssumedAlignment<128> >; using GmemCopyAtom = Copy_Atom; static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; using GmemTiledCopyQKV = decltype( make_tiled_copy(GmemCopyAtom{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per read using GmemCopyAtomLSE = Copy_Atom; using GmemLayoutAtomLSE = Layout>>; using GmemTiledCopyLSE = decltype(make_tiled_copy(GmemCopyAtomLSE{}, GmemLayoutAtomLSE{}, Layout>{})); // Val layout, 4 vals per store // So that we don't have to check if we overshot kBlockM when we load Q // static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) using StrideQKV = cute::Stride; using ShapeLSE = cute::Shape; // (seqlen, head, batch) using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; // These are tuned for speed. They don't affect correctness. // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 // this helps quite a bit to not have to do causal masking for most of the iterations. // For hdim 192, separating masking iterations results in register spills. // static constexpr bool SeparateMaskingIterations = kHeadDim <= 64; static constexpr bool SeparateMaskingIterations = false; // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each // thread needs to keep statistics for (kBlockM / 4) rows. 
If !SdP_swapAB, each thread only needs to keep // statistics for 2 rows. // static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64; // static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64; static constexpr bool ShuffleLSE = SdP_swapAB && false; static constexpr bool ShuffledPsum = SdP_swapAB && false; static constexpr bool Share_QV_Smem = V_in_regs; using SmemP_t = std::conditional_t, cute::array_aligned>>; struct TensorStorageSharedQV : cute::aligned_struct<128> { cute::array_aligned> smem_k; union { cute::array_aligned> smem_v; cute::array_aligned> smem_q; }; cute::array_aligned> smem_do; cute::array_aligned, 128> smem_lse; cute::array_aligned, 128> smem_dpsum; SmemP_t smem_p; cute::array_aligned> smem_ds; }; struct TensorStorageSeparateQV : cute::aligned_struct<128> { cute::array_aligned> smem_k; cute::array_aligned> smem_v; cute::array_aligned> smem_q; cute::array_aligned> smem_do; cute::array_aligned, 128> smem_lse; cute::array_aligned, 128> smem_dpsum; SmemP_t smem_p; cute::array_aligned> smem_ds; }; using TensorStorage = std::conditional_t; // Host side kernel arguments struct Arguments { Element const* const ptr_Q; ShapeQKV const shape_Q; StrideQKV const stride_Q; Element const* const ptr_K; ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; StridedQaccum const stride_dQaccum; float const* const ptr_LSE_log2; ShapeLSE const shape_LSE; StrideLSE const stride_LSE_log2; float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; int const window_size_left, window_size_right, attention_chunk; float const softcap_val; int const num_batch; int* const dq_semaphore; int const* const cu_seqlens_q = nullptr; int const* const cu_seqlens_k = nullptr; int const* const seqused_q = nullptr; int const*
const seqused_k = nullptr; }; // Device side kernel params struct Params { Element const* const ptr_Q; ShapeQKV const shape_Q; StrideQKV const stride_Q; Element const* const ptr_K; ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; StridedQaccum stride_dQaccum; cutlass::FastDivmod qhead_per_khead_divmod; float const* const ptr_LSE_log2; ShapeLSE const shape_LSE; StrideLSE const stride_LSE_log2; float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; int const window_size_left, window_size_right; cutlass::FastDivmod attention_chunk_divmod; float const softcap_val; int const num_batch; int *const dq_semaphore; int const *const cu_seqlens_q = nullptr; int const *const cu_seqlens_k = nullptr; int const *const seqused_q = nullptr; int const *const seqused_k = nullptr; }; static Params to_underlying_arguments(Arguments const& args) { if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } // Avoid dividing by zero cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); attention_chunk_divmod.divisor = args.attention_chunk; // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). // In the backward, we need to multiply by // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale. 
// Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale // (the original softmax_scale) at the end. return {args.ptr_Q, args.shape_Q, args.stride_Q, args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.shape_V, args.stride_V, args.ptr_dO, args.shape_dO, args.stride_dO, args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), args.window_size_left, args.window_size_right, attention_chunk_divmod, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k}; } template CUTLASS_DEVICE bool mma(Params const& params, FrgTensordKV& tdKrdK, FrgTensordKV& tdVrdV, int thread_idx, cute::tuple block_coord, SharedStorage& shared_storage ) { static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); int n_block = get<0>(block_coord); int bidh = get<1>(block_coord); int bidb = get<2>(block_coord); SeqlenInfo_t seqlen_info{ bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; auto m_block_min_max = BlockMN_t::get_m_block_min_max( seqlen_info, n_block, bidb, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); int const m_block_min = get<0>(m_block_min_max); int const m_block_max = get<1>(m_block_min_max); // It's possible to have m_block_max <= m_block_min. 
Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return false; } } Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{}); Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{}); Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{}); Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{}); Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{}); Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{}); Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{}); Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{}); Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q), params.shape_Q, params.stride_Q)(_, _, bidh, !is_varlen_q ? 
            bidb : 0);
        Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_dO, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0);
        Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
        Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0);
        Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0);
        Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0);
        Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen_q ? bidb : 0);

        Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)
        Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (M, K, _)
        Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)
        Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{}));  // (N, K)
        Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_));  // (M, _)
        Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_));  // (M, _)
        Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_));  // (M * K, _)

        GmemTiledCopyQKV gmem_tiled_copy_QKV;
        auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx);
        auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{});  // For index calculation
        GmemTiledCopyLSE gmem_tiled_copy_lse;
        auto gmem_thr_copy_lse = gmem_tiled_copy_lse.get_thread_slice(thread_idx);
        R2STiledCopydQaccum r2s_tiled_copy_dQaccum;
        auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);

        Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
        Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
        Tensor tdOgdO = gmem_thr_copy_QKV.partition_S(gdO);
        Tensor tdOsdO = gmem_thr_copy_QKV.partition_D(sdO);
        Tensor tLSEgLSE = gmem_thr_copy_lse.partition_S(gLSE);
        Tensor tLSEsLSE = gmem_thr_copy_lse.partition_D(sLSE);
        Tensor tLSEgdPsum = gmem_thr_copy_lse.partition_S(gdPsum);
        Tensor tLSEsdPsum = gmem_thr_copy_lse.partition_D(sdPsum);
        // We can reuse r2s_thr_copy_dQaccum for this partitioning
        Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum);
        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); }

        TiledMmaSdP tiled_mma_SdP;
        TiledMmadKV tiled_mma_dKV;
        TiledMmadQ tiled_mma_dQ;

        auto thr_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx);
        auto thr_mma_dKV = tiled_mma_dKV.get_thread_slice(thread_idx);
        auto thr_mma_dQ = tiled_mma_dQ.get_thread_slice(thread_idx);

        // Allocate "fragments/descriptors"
        // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda,
        // because some partition_fragment_A/B don't compile.
        // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function
        Tensor tdPrV = mma_partition_fragment_AB(thr_mma_SdP, sV);

        // Copy Atom retiling
        auto smem_copy_atom_SdP_B = cute::conditional_return(SmemCopyAtom{}, SmemCopyAtomHalf{});
        auto smem_tiled_copy_QdO = cute::conditional_return(make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP), make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP));
        auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(thread_idx);
        Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
        Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);

        auto smem_tiled_copy_KV = cute::conditional_return(make_tiled_copy_B(smem_copy_atom_SdP_B, tiled_mma_SdP), make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_SdP));
        auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(thread_idx);
        Tensor tSsK = smem_thr_copy_KV.partition_S(sK);
        Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);

        auto r2s_tiled_copy_PdS = make_tiled_copy_C(R2SCopyAtomPdS{}, tiled_mma_SdP);
        auto r2s_thr_copy_PdS = r2s_tiled_copy_PdS.get_thread_slice(thread_idx);
        Tensor tPsP = r2s_thr_copy_PdS.partition_D(cute::conditional_return(sP, sPt));  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        Tensor tdSsdS = r2s_thr_copy_PdS.partition_D(cute::conditional_return(sdS, sdSt));  // ((Atom,AtomNum),PIPE_M,PIPE_N)
        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(r2s_thr_copy_PdS); print(sP); printf("\n"); print(sPt); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); }

        auto smem_copy_atom_dKV_B = cute::conditional_return(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{});
        auto smem_tiled_copy_PdSt = cute::conditional_return(make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV), make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV));
        auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(thread_idx);
        Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
        Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);

        auto smem_tiled_copy_QdOt = cute::conditional_return(make_tiled_copy_B(smem_copy_atom_dKV_B, tiled_mma_dKV), make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dKV));
        auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(thread_idx);
        Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);
        Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);

        auto smem_tiled_copy_dS = cute::conditional_return(
            make_tiled_copy_A(SmemCopyAtom{}, tiled_mma_dQ),
            make_tiled_copy_B(cute::conditional_return(SmemCopyAtom{}, SmemCopyAtomHalf{}), tiled_mma_dQ));
        auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(thread_idx);
        Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);

        auto smem_tiled_copy_Kt = cute::conditional_return(
            make_tiled_copy_B(cute::conditional_return(SmemCopyAtomTransposed{}, SmemCopyAtomTransposedHalf{}), tiled_mma_dQ),
            make_tiled_copy_A(SmemCopyAtomTransposed{}, tiled_mma_dQ));
        auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(thread_idx);
        Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);

        // thr_mma_SdP.partition_C(sLSEMma) has shape (MMA=4, MMA_M, MMA_N, PIPE), we only take the col indices
        // or row indices, depending on whether SdP_swapAB.
        Tensor tSsLSEMma = logical_divide(thr_mma_SdP.partition_C(sLSEMma), Shape<_2>{});  // (2, 2, MMA_M, MMA_N, PIPE)
        Tensor tSsLSE = group_modes<0, 2>(cute::conditional_return(
            tSsLSEMma(make_coord(_0{}, _), _, _0{}, _),  // (2, MMA_M, PIPE)
            tSsLSEMma(make_coord(_, _0{}), _0{}, _, _)));  // (2, MMA_N, PIPE)
        Tensor tSsdPsumMma = logical_divide(thr_mma_SdP.partition_C(sdPsumMma), Shape<_2>{});
        Tensor tSsdPsum = group_modes<0, 2>(cute::conditional_return(
            tSsdPsumMma(make_coord(_0{}, _), _, _0{}, _),  // (2, MMA_M, PIPE)
            tSsdPsumMma(make_coord(_, _0{}), _0{}, _, _)));  // (2, MMA_N, PIPE)
        // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); }
        // If we want to split the stats among the 8 threads that share the same rows.
        static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tSsLSE))::value, 8);

        // Predicates
        Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}));
        Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);
        Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ);
        Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ)));
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); }
        Tensor cLSE = cute::make_identity_tensor(select<0>(TileShape_MNK{}));
        Tensor tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE);
        Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOsdO)));
        #pragma unroll
        for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_dO); }

        int const seqlen_q = seqlen_info.seqlen_q;
        int const seqlen_k = seqlen_info.seqlen_k;

        flash::Mask mask(
            thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/,
            params.attention_chunk_divmod, params.qhead_per_khead_divmod
        );

        {
            Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)
            Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
            Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)
            Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
            // Predicates
            Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));
            Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);
            Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV);
            Tensor tKpK = make_tensor(make_shape(size<2>(tKsK)));
            Tensor tVpV = make_tensor(make_shape(size<2>(tVsV)));
            #pragma unroll
            for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); }
            #pragma unroll
            for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_V); }
            // Do we need bound check to make sure the row doesn't go above kBlockN
            static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;
            // static_assert(EvenN);  // It simplifies the loading of K and V
            // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit
            // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time.
            // int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN
            //     ? seqlen_info.seqlen_k - n_block * kBlockN
            //     : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN));
            // // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockN dimension
            // flash::copy(
            //     gmem_tiled_copy_QKV, tVgV, tVsV, t0KVcKV, tKVpKV, seqlenk_row_limit);
            int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{}));
            #pragma unroll
            for (int m = 0; m < size<1>(tVsV); ++m) {
                // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked
                if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) {
                    bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit;
                    #pragma unroll
                    for (int k = 0; k < size<2>(tVsV); ++k) {
                        cute::copy(gmem_tiled_copy_QKV.with(tVpV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k));
                    }
                }
            }
            if constexpr (V_in_regs) { flash::cp_async_fence(); }
            // flash::copy(
            //     gmem_tiled_copy_QKV, tKgK, tKsK, t0KVcKV, tKVpKV, seqlenk_row_limit);
            #pragma unroll
            for (int m = 0; m < size<1>(tKsK); ++m) {
                if (EvenN || m < size<1>(tKsK) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) {
                    bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit;
                    #pragma unroll
                    for (int k = 0; k < size<2>(tKsK); ++k) {
                        cute::copy(gmem_tiled_copy_QKV.with(tKpK(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k));
                    }
                }
            }
            flash::cp_async_fence();
        }

        if constexpr (V_in_regs) {
            flash::cp_async_wait<1>();
            __syncthreads();
            Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV);
            Tensor tdPsV_copy_view = smem_thr_copy_KV.partition_S(sV);
            cute::copy(smem_tiled_copy_KV, tdPsV_copy_view, tdPrV_copy_view);
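            // Illustrative sketch (not part of this kernel): the K/V loads above, and the Q/dO loads in
            // load_Q_LSE / load_dO_dPsum, compare thread 0's row coordinate (t0KVcKV / t0QcQ, known at
            // compile time) against a runtime limit from which this thread's own row offset has already
            // been subtracted. Assuming a thread's row for copy slice m equals its own row offset plus
            // thread 0's row for slice m, that folded test agrees with the direct "global row < seqlen"
            // test. The helper names below are hypothetical and model only the integer arithmetic, not CuTe.
```cpp
#include <cassert>

// Direct form: global row = block * kBlock + thread_row0 + t0_row_m, compared against seqlen.
bool row_in_bounds_direct(int seqlen, int block, int kBlock, int thread_row0, int t0_row_m) {
    assert(block >= 0 && kBlock > 0 && thread_row0 >= 0 && t0_row_m >= 0);
    int const global_row = block * kBlock + thread_row0 + t0_row_m;
    return global_row < seqlen;
}

// Folded form, as in the kernel: every runtime term moves into one per-thread limit
// (like seqlenk_row_limit), so the left-hand side t0_row_m can stay a compile-time constant.
bool row_in_bounds_folded(int seqlen, int block, int kBlock, int thread_row0, int t0_row_m) {
    assert(block >= 0 && kBlock > 0 && thread_row0 >= 0 && t0_row_m >= 0);
    int const row_limit = seqlen - block * kBlock - thread_row0;
    return t0_row_m < row_limit;
}
```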
            __syncthreads();  // Sync to avoid loading Q to smem_q, which overlaps with smem_v
        }

        // Do we need bound check to make sure the row doesn't go above kBlockM
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0;

        auto load_Q_LSE = [&] (int const m_block, int const smem_pipe_write) {
            // if (cute::thread0()) { printf("Inside load_Q_LSE, m_block = %d, smem_pipe_write = %d\n", m_block, smem_pipe_write); }
            Tensor tQsQ_cur = tQsQ(_, _, _, smem_pipe_write);
            Tensor tQgQ_cur = tQgQ(_, _, _, m_block);
            // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit
            // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time.
            // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM
            //     ? seqlen_info.seqlen_q - m_block * kBlockM
            //     : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM));
            // Need Clear_OOB_MN to be true here since the gemm will sum over the kBlockM dimension
            // flash::copy(
            //     gmem_tiled_copy_QKV, tQgQ(_, _, _, m_block), tQsQ_cur, t0QcQ, tQpQ, seqlenq_row_limit);
            int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{}));
            #pragma unroll
            for (int m = 0; m < size<1>(tQsQ); ++m) {
                // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
                if (EvenM || m < size<1>(tQsQ) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) {
                    bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit;
                    #pragma unroll
                    for (int k = 0; k < size<2>(tQsQ); ++k) {
                        cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tQgQ_cur(_, m, k), tQsQ_cur(_, m, k));
                    }
                }
            }
            Tensor tLSEgLSE_cur = tLSEgLSE(_, _, m_block);
            Tensor tLSEsLSE_cur = tLSEsLSE(_, _, smem_pipe_write);
            // We made sure LSE length is padded so we read `kBlockM` elements so that all
            // elements in sLSE are filled. Without this we might have uninitialized sLSE values.
            #pragma unroll
            for (int m = 0; m < size<1>(tLSEsLSE); ++m) {
                if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) {
                    cute::copy(gmem_tiled_copy_lse, tLSEgLSE_cur(_, m), tLSEsLSE_cur(_, m));
                }
            }
        };

        auto load_dO_dPsum = [&] (int const m_block, int const smem_pipe_write) {
            // if (cute::thread0()) { printf("Inside load_dO_dPsum, m_block = %d, smem_pipe_write = %d\n", m_block, smem_pipe_write); }
            Tensor tdOsdO_cur = tdOsdO(_, _, _, smem_pipe_write);
            Tensor tdOgdO_cur = tdOgdO(_, _, _, m_block);
            // int const seqlenq_row_limit = -int(get<0>(tQcQ(_0{}, _0{}, _0{}))) + (EvenM
            //     ? seqlen_info.seqlen_q - m_block * kBlockM
            //     : std::min(seqlen_info.seqlen_q - m_block * kBlockM, kBlockM));
            // flash::copy(
            //     gmem_tiled_copy_QKV, tdOgdO(_, _, _, m_block), tdOsdO_cur, t0QcQ, tQpQ, seqlenq_row_limit);
            int const seqlenq_row_limit = seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{}));
            #pragma unroll
            for (int m = 0; m < size<1>(tdOsdO); ++m) {
                // If kBlockM doesn't evenly divide the tiled copy, only the last `m` needs to be checked
                if (EvenM || m < size<1>(tdOsdO) - 1 || get<0>(tQcQ(_0{}, m, _0{})) < kBlockM) {
                    bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit;
                    #pragma unroll
                    for (int k = 0; k < size<2>(tdOsdO); ++k) {
                        cute::copy(gmem_tiled_copy_QKV.with(tdOpdO(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k));
                    }
                }
            }
            Tensor tLSEgdPsum_cur = tLSEgdPsum(_, _, m_block);
            Tensor tLSEsdPsum_cur = tLSEsdPsum(_, _, smem_pipe_write);
            #pragma unroll
            for (int m = 0; m < size<1>(tLSEsdPsum); ++m) {
                if (get<0>(tLSEcLSE(_0{}, m)) < kBlockM) {
                    cute::copy(gmem_tiled_copy_lse, tLSEgdPsum_cur(_, m), tLSEsdPsum_cur(_, m));
                }
            }
        };

        int m_block = m_block_min;

        // Note, using the for_each() function here to ensure `stage` is of type Int.
        for_each(make_int_sequence{}, [&] (auto stage) {
            static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0;
            static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1;
            if constexpr (!Is_last_stage || kStages == 1) {
                if (Is_first_stage || m_block + stage < m_block_max) {
                    load_Q_LSE(m_block + stage, stage);
                }
            }
            // We want the fence outside the if statement to have a fixed number of cp.async commits,
            // so that we can wait with the correct number of outstanding commits.
            cute::cp_async_fence();
            if constexpr (stage < kStages_dO) {
                if (Is_first_stage || m_block + stage < m_block_max) {
                    load_dO_dPsum(m_block + stage, stage);
                }
                cute::cp_async_fence();
            }
        });

        int smem_pipe_read = 0, smem_pipe_read_do = 0, smem_pipe_write = kStages - 1, smem_pipe_write_do = 0;

        auto load_Q_next = [&] {
            // if (cute::thread0()) { printf("m_block = %d, m_block_max = %d, smem_pipe_write = %d\n", m_block, m_block_max, smem_pipe_write); }
            if (m_block + (kStages > 1 ? kStages - 1 : 1) < m_block_max) {
                load_Q_LSE(m_block + (kStages > 1 ? kStages - 1 : 1), kStages > 1 ? smem_pipe_write : 0);
            }
            cute::cp_async_fence();
        };

        auto load_dO_next = [&] {
            // int smem_pipe_write_do_cur = Q_dO_same_stages ? smem_pipe_write : smem_pipe_write_do;
            if (m_block + kStages_dO < m_block_max) {
                // load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do_cur : 0);
                load_dO_dPsum(m_block + kStages_dO, kStages_dO > 1 ? smem_pipe_write_do : 0);
            }
            cute::cp_async_fence();
        };

        clear(tdKrdK);
        clear(tdVrdV);

        auto bwd_step = [&](int m_block, auto mask_fn) {
            Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{}));
            clear(tSrS);
            flash::cp_async_wait<(kStages > 1) ?
                1 : 0>();
            __syncthreads();
            Tensor tSrQ = mma_partition_fragment_AB(thr_mma_SdP, sQ(_, _, _0{}));
            Tensor tSrK = mma_partition_fragment_AB(thr_mma_SdP, sK);
            // if (cute::thread0()) { print(tiled_mma_SdP); print(tSrS); printf("\n"); print(tSrQ); printf("\n"); print(tSrK); printf("\n"); print(tSsQ); printf("\n"); print(tSsK); printf("\n"); }
            flash::gemm_sm80(
                tSrS, tSrQ, tSrK, tSsQ(_, _, _, kStages > 1 ? smem_pipe_read : 0), tSsK,
                tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, nullptr /*hook*/);
            Tensor tLSErLSE = cute::conditional_return(make_fragment_like(tSsLSE(_, _0{})), make_tensor(Int{}));
            if constexpr (!ShuffleLSE) {
                cute::copy(tSsLSE(_, kStages > 1 ? smem_pipe_read : 0), tLSErLSE);
            } else {
                #pragma unroll
                for (int i = 0; i < kStatsPerThread; ++i) {
                    // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values
                    tLSErLSE(i) = tSsLSE((thread_idx % 32) / 4 + i * 8, kStages > 1 ? smem_pipe_read : 0);
                }
            }
            if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); }

            // Reshape tSrS from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
            Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout()));
            // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh
            // if (cute::thread0()) { print_tensor(scores); }
            auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }();
            mask_fn(tSrS, m_block);
            #pragma unroll
            for (int mi = 0; mi < size<0>(scores); ++mi) {
                float const lse_scaled = [&] {
                    if constexpr (!ShuffleLSE) return tLSErLSE(mi);
                    else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4));
                }();
                #pragma unroll
                for (int ni = 0; ni < size<1>(scores); ++ni) {
                    scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled);
                }
            }

            Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{}));
            clear(tdPrdP);
            int
                smem_pipe_read_do_cur = Q_dO_same_stages ? smem_pipe_read : smem_pipe_read_do;
            flash::cp_async_wait<(kStages_dO > 1) ? 1 : 0>();
            __syncthreads();
            auto hook = cute::conditional_return<(kStages > 1)>(load_Q_next, nullptr);
            Tensor tdPrdO = mma_partition_fragment_AB(thr_mma_SdP, sdO(_, _, _0{}));
            Tensor tdPrV_cur = cute::conditional_return(tdPrV, mma_partition_fragment_AB(thr_mma_SdP, sV));
            flash::gemm_sm80(
                tdPrdP, tdPrdO, tdPrV_cur, tdPsdO(_, _, _, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tdPsV,
                tiled_mma_SdP, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV, hook);
            Tensor tLSErdPsum = cute::conditional_return(make_fragment_like(tSsdPsum(_, _0{})), make_tensor(Int{}));
            if constexpr (!ShuffledPsum) {
                cute::copy(tSsdPsum(_, kStages_dO > 1 ? smem_pipe_read_do_cur : 0), tLSErdPsum);
            } else {
                #pragma unroll
                for (int i = 0; i < kStatsPerThread; ++i) {
                    tLSErdPsum(i) = tSsdPsum((thread_idx % 32) / 4 + i * 8, kStages_dO > 1 ? smem_pipe_read_do_cur : 0);
                }
            }

            // Reshape tdPrdP from (4, MMA_N, MMA_M) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
            Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
            #pragma unroll
            for (int mi = 0; mi < size<0>(dS); ++mi) {
                float const dP_sum_cur = [&] {
                    if constexpr (!ShuffledPsum) return tLSErdPsum(mi);
                    else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4));
                }();
                #pragma unroll
                for (int ni = 0; ni < size<1>(dS); ++ni) {
                    dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur);
                    if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); }
                }
            }
            // if (cute::thread0()) { print_tensor(dS); }

            // Convert scores from fp32 to fp16/bf16
            Tensor rP = make_tensor_like(tSrS);
            flash::convert_type_out(tSrS, rP);
            if constexpr (!Mma_dKV_is_RS) {
                Tensor tPaP = r2s_thr_copy_PdS.retile_S(rP);  // ((Atom,AtomNum), MMA_N, MMA_N)
                cute::copy(r2s_tiled_copy_PdS, tPaP, tPsP);
            }
            Tensor rdS = make_tensor_like(tdPrdP);
            flash::convert_type_out(tdPrdP, rdS);
            if constexpr (!Mma_dKV_is_RS) { __syncthreads(); }  // Make sure P is written
            // For hdim 64, it's faster to write to smem_dS first before the dV gemm
            Tensor tdSadS = r2s_thr_copy_PdS.retile_S(rdS);  // ((Atom,AtomNum), MMA_N, MMA_N)
            cute::copy(r2s_tiled_copy_PdS, tdSadS, tdSsdS);

            Tensor tdVrdO = mma_partition_fragment_AB(thr_mma_dKV, sdOt(_, _, _0{}));
            Tensor tdVsdO_cur = tdVsdOt(_, _, _, kStages_dO > 1 ? smem_pipe_read_do_cur : 0);
            if constexpr (Mma_dKV_is_RS) {
                Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout()));
                flash::gemm_rs_sm80(tdVrdV, tdVrP, tdVrdO, tdVsdO_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
            } else {
                Tensor tdVrP = mma_partition_fragment_AB(thr_mma_dKV, sPt);
                flash::gemm_sm80(
                    tdVrdV, tdVrP, tdVrdO, tdVsPt, tdVsdO_cur,
                    tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt, nullptr);
            }
            // if (cute::thread0()) { print_tensor(tdVrdV); }
            __syncthreads();  // make sure sdS is written
            auto do_mma_dQ = [&] (auto hook) {
                Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{}));
                clear(tdQrdQ);
                Tensor tdQrdS = mma_partition_fragment_AB(thr_mma_dQ, sdS);
                Tensor tdQrK = mma_partition_fragment_AB(thr_mma_dQ, sKt);
                flash::gemm_sm80(
                    tdQrdQ, tdQrdS, tdQrK, tdQsdS, tdQsKt, tiled_mma_dQ,
                    // smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, load_dO_next);
                    smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt, hook);
                // if (cute::thread0()) { print_tensor(tdQrdQ); }
                // We can reuse r2s_thr_copy_dQaccum for this partitioning
                Tensor tdQrdQ_atomic = r2s_thr_copy_dQaccum.retile_S(tdQrdQ);
                Tensor tdQgdQaccum_atomic = tdQgdQaccum(_, _, m_block);
                static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic)));
                #pragma unroll
                for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); }
            };
            // If kStages == 1, we want to do Mma_dK first so we can start loading Q for the next iteration
            if constexpr (kStages > 1) {
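                // Illustrative sketch (not part of this kernel): a scalar host-side model of the
                // softmax-backward arithmetic in bwd_step, with masking and softcap omitted. P is recomputed
                // as exp2(S * softmax_scale_log2 - LSE_log2), assuming the saved LSE is already in log2 units,
                // and dS = P * (dP - dPsum) with dPsum_i = sum_j P_ij * dP_ij (equal to rowsum(dO * O)).
                // A quick consistency check: without masking, every row of dS sums to zero.
                // All function names below are hypothetical.
```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Mat = std::vector<std::vector<float>>;

// log2(sum_j 2^(S_ij * scale_log2)) per row: stands in for the LSE_log2 saved by the forward pass.
std::vector<float> row_lse_log2(const Mat& S, float scale_log2) {
    std::vector<float> lse(S.size());
    for (std::size_t i = 0; i < S.size(); ++i) {
        assert(!S[i].empty());
        float mx = -std::numeric_limits<float>::infinity();
        for (float s : S[i]) mx = std::max(mx, s * scale_log2);
        float sum = 0.f;
        for (float s : S[i]) sum += std::exp2(s * scale_log2 - mx);  // max-subtracted for stability
        lse[i] = mx + std::log2(sum);
    }
    return lse;
}

// P_ij = exp2(S_ij * scale_log2 - LSE_log2_i), mirroring the exp2f line in bwd_step.
Mat recompute_P(const Mat& S, const std::vector<float>& lse_log2, float scale_log2) {
    Mat P = S;
    for (std::size_t i = 0; i < S.size(); ++i)
        for (std::size_t j = 0; j < S[i].size(); ++j)
            P[i][j] = std::exp2(S[i][j] * scale_log2 - lse_log2[i]);
    return P;
}

// dPsum_i = sum_j P_ij * dP_ij.
std::vector<float> row_dPsum(const Mat& P, const Mat& dP) {
    std::vector<float> out(P.size(), 0.f);
    for (std::size_t i = 0; i < P.size(); ++i)
        for (std::size_t j = 0; j < P[i].size(); ++j)
            out[i] += P[i][j] * dP[i][j];
    return out;
}

// dS_ij = P_ij * (dP_ij - dPsum_i), mirroring `dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur)`.
Mat compute_dS(const Mat& P, const Mat& dP, const std::vector<float>& dPsum) {
    Mat dS = P;
    for (std::size_t i = 0; i < P.size(); ++i)
        for (std::size_t j = 0; j < P[i].size(); ++j)
            dS[i][j] = P[i][j] * (dP[i][j] - dPsum[i]);
    return dS;
}
```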
                do_mma_dQ(load_dO_next);
            }
            Tensor tdKrQ = mma_partition_fragment_AB(thr_mma_dKV, sQt(_, _, _0{}));
            Tensor tdKsQ_cur = tdKsQt(_, _, _, kStages > 1 ? smem_pipe_read : 0);
            if constexpr (Mma_dKV_is_RS) {
                Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout()));
                flash::gemm_rs_sm80(tdKrdK, tdKrdS, tdKrQ, tdKsQ_cur, tiled_mma_dKV, smem_tiled_copy_QdOt, smem_thr_copy_QdOt);
            } else {
                Tensor tdKrdS = mma_partition_fragment_AB(thr_mma_dKV, sdSt);
                flash::gemm_sm80(
                    tdKrdK, tdKrdS, tdKrQ, tdKsdSt, tdKsQ_cur,
                    tiled_mma_dKV, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt,
                    cute::conditional_return<(kStages > 1)>(nullptr, load_dO_next));
            }
            if constexpr (kStages == 1) {
                __syncthreads();
                do_mma_dQ(load_Q_next);
            }
            // if (cute::thread0()) { print_tensor(tdKrdK); }

            smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0;
            smem_pipe_read_do = smem_pipe_read_do < kStages_dO - 1 ? smem_pipe_read_do + 1 : 0;
            smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0;
            smem_pipe_write_do = smem_pipe_write_do < kStages_dO - 1 ? smem_pipe_write_do + 1 : 0;
        };

        // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
        // this helps quite a bit to not have to do causal masking for most of the iterations.
        if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) {
            auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); };
            int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1;
            CUTLASS_PRAGMA_NO_UNROLL
            for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) {
                bwd_step(m_block, mask_fn);
            }
        }

        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations
            ? m_block_max
            : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM);

        auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); };
        CUTLASS_PRAGMA_NO_UNROLL
        for (; m_block < m_block_max_before_local_mask; ++m_block) {
            bwd_step(m_block, mask_fn);
        }

        if constexpr (Is_local && SeparateMaskingIterations) {
            auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); };
            CUTLASS_PRAGMA_NO_UNROLL
            for (; m_block < m_block_max; ++m_block) {
                bwd_step(m_block, mask_fn);
            }
        }

        // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); }
        #pragma unroll
        for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; }

        return true;
    }

};

} // namespace flash

================================================
FILE: hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include
#include
#include
#include
#include

#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"

#include "cutlass/gemm/collective/builders/sm90_common.inl"

#include "named_barrier.hpp"
#include "seqlen.h"
#include "block.h"
#include "mask.h"
#include "softmax.h"
#include "utils.h"
#include "copy_sm90_bulk_reduce.hpp"

namespace flash {

using namespace cute;

template
struct CollectiveMainloopBwdSm90 {

    static constexpr int kStages = Stages;
    static constexpr int kStages_dO = Stages_dO;
    static constexpr int kStages_dS = Stages_dS;
    static_assert(kStages >= kStages_dO);
    static_assert(Stages_dS == 1 || Stages_dS == kStages);
    static_assert(!Mma_dP_is_RS || SdP_swapAB_);  // If Mma_dP_is_RS, we need SdP_SwapAB
    using ClusterShape = ClusterShape_;
    using TileShape_MNK = TileShape_MNK_;
    using Element = Element_;
    using ElementAccum = ElementAccum_;
    using ArchTag = ArchTag_;
    static constexpr bool Is_causal = Is_causal_;
    static constexpr bool Is_local = Is_local_;
    static constexpr bool Has_softcap = Has_softcap_;
    static constexpr bool Varlen = Varlen_;

    static constexpr bool SdP_swapAB = SdP_swapAB_;
    static constexpr bool dKV_swapAB = dKV_swapAB_;
    static constexpr bool dQ_swapAB = dQ_swapAB_;

    static constexpr bool Q_dO_same_stages = kStages == kStages_dO;

    static constexpr int kBlockM = get<0>(TileShape_MNK{});
    static constexpr int kBlockN = get<1>(TileShape_MNK{});
    static constexpr int kHeadDim = get<2>(TileShape_MNK{});

    using SeqlenInfo_t = flash::SeqlenInfoQK;
    using BlockMN_t = flash::BlockMN;

    static_assert(ArchTag::kMinComputeCapability >= 90);
    static_assert(get<0>(ClusterShape{}) == 1 && get<2>(ClusterShape{}) == 1);

    static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarp * 2;

    static_assert(NumMmaWarpGroups % AtomLayoutMSdP == 0);
    static_assert(NumMmaWarpGroups % AtomLayoutNdKV == 0);
    static_assert(NumMmaWarpGroups % AtomLayoutMdQ == 0);
    static constexpr bool Mma_dKV_is_RS = AtomLayoutMSdP == 1 && AtomLayoutNdKV == NumMmaWarpGroups && SdP_swapAB && !dKV_swapAB;
    static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == NumMmaWarpGroups && AtomLayoutMdQ == NumMmaWarpGroups && !SdP_swapAB && !dQ_swapAB;  // If dQ_swapAB we can't use RS

    static constexpr GMMA::Major PdS_Major = GMMA::Major::K;
    // static constexpr GMMA::Major PdS_Major = GMMA::Major::MN;
    static constexpr GMMA::Major PdSt_Major = PdS_Major == GMMA::Major::K ? GMMA::Major::MN : GMMA::Major::K;

    using TileShapeAtomSdP = std::conditional_t<
        !SdP_swapAB,
        Shape, Int, Int>,
        Shape, Int, Int>
    >;
    using AtomLayoutSdP = std::conditional_t<
        !SdP_swapAB,
        Layout, Int, _1>>,
        Layout, Int, _1>>
    >;
    using TiledMmaSdP = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector(),
        AtomLayoutSdP{}));

    using TiledMmadPRS = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector(),
        AtomLayoutSdP{}));

    using TileShapeAtomdKV = std::conditional_t<
        !dKV_swapAB,
        Shape, Int, Int>,
        Shape, Int, Int>
    >;
    using AtomLayoutdKV = std::conditional_t<
        !dKV_swapAB,
        Layout, Int, _1>>,
        Layout, Int, _1>>
    >;
    using TiledMmadKV = decltype(cute::make_tiled_mma(
        std::conditional_t<
            Mma_dKV_is_RS,
            decltype(cute::GMMA::rs_op_selector()),
            decltype(cute::GMMA::ss_op_selector())
        >{},
        AtomLayoutdKV{}));

    using TileShapeAtomdQ = std::conditional_t<
        !dQ_swapAB,
        Shape, Int, Int>,
        Shape, Int, Int>
    >;
    using AtomLayoutdQ = std::conditional_t<
        !dQ_swapAB,
        Layout, Int, _1>>,
        Layout, Int, _1>>
    >;
    using TiledMmadQ = decltype(cute::make_tiled_mma(
        std::conditional_t<
            Mma_dQ_is_RS,
            decltype(cute::GMMA::rs_op_selector()),
            decltype(cute::GMMA::ss_op_selector())
        >{},
        AtomLayoutdQ{}));

    // We need to accommodate both Q and Q^T (and dO and dO^T) in shared memory.
    // Q & dO are used in the SdP Mma and Q^T and dO^T are used in the dKV Mma.
    // Since this is GMMA::Major::K, the M dimension (kBlockM) doesn't matter for the layout, only the K dimension
    // changes the layout.
    using SmemLayoutAtomQdO = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>());  // for dKV_Mma
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQdO{}, make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{})));
    using SmemLayoutdO = decltype(tile_to_shape(SmemLayoutAtomQdO{}, make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{})));

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>());
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{})));

    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{})));

    using SmemLayoutAtomPdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector, Int>());
    using SmemLayoutPdS = decltype(tile_to_shape(
        SmemLayoutAtomPdS{},
        make_shape(Int{}, Int{}, Int{}),
        std::conditional_t, cute::Step<_2, _1, _3>>{}));

    // Need stride to be multiple of 32, otherwise we get error (misaligned address) when doing TMA if e.g. kBlockM=80
    // We set stride to be multiple of 64 so that if ShuffleLSE, even if threads read from sLSE but out of bounds,
    // it's still a valid smem address.
    using SmemLayoutLSE = cute::Layout, Int>, cute::Stride<_1, Int>>;
    using SmemLayoutLSEMma = std::conditional_t<
        SdP_swapAB,
        cute::Layout, Int, Int>, cute::Stride<_0, _1, Int>>,
        cute::Layout, Int, Int>, cute::Stride<_1, _0, Int>>
    >;

    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutQt = decltype(cute::composition(SmemLayoutQ{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), make_stride(Int{}, _1{}, Int{}))));
    using SmemLayoutdOt = decltype(cute::composition(SmemLayoutdO{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), make_stride(Int{}, _1{}, Int{}))));
    using SmemLayoutKt = decltype(cute::composition(SmemLayoutK{}, make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), make_stride(Int{}, _1{}))));
    using SmemLayoutPdSt = decltype(cute::composition(SmemLayoutPdS{}, make_layout(make_shape(Int{}, Int{}, Int{}), make_stride(Int{}, _1{}, Int{}))));

    // Thread layout, 256 or 384 threads per row
    // We split into NumMmaWarpGroups so that we can do Bulk reduce add for each WG separately.
    using R2SLayoutAtomdQaccum = Layout, Int>>;
    using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom, ElementAccum>{}, R2SLayoutAtomdQaccum{},
                                                         Layout>{}));  // Val layout, 4 vals per store
    using SmemLayoutdQaccum = Layout, Int>>;

    static constexpr int kNumPdSStore = kBlockM * kBlockN / NumMmaThreads;
    // If !SdP_swapAB, the accum registers hold P / dS, otherwise they hold Pt / dSt.
    // If PdS_major is MN, then we need to "transpose" the write.
using SmemCopyAtomPdS = Copy_Atom< std::conditional_t<(!SdP_swapAB) ^ (PdS_Major == GMMA::Major::MN), std::conditional_t, std::conditional_t >, Element >; using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape{}))); using GmemTiledCopyKV = cute::SM90_TMA_LOAD; using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) using StrideQKV = cute::Stride; using ShapeLSE = cute::Shape; // (seqlen, head, batch) using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; using TMA_QdO = decltype(make_tma_copy_A_sm90( GmemTiledCopyQdO{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), take<0, 2>(SmemLayoutQ{}), TileShape_MNK{}, ClusterShape{})); // mcast along N mode for this M load, if any using TMA_K = decltype(make_tma_copy_B_sm90( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), SmemLayoutK{}, TileShape_MNK{}, ClusterShape{})); // no mcast for KV using TMA_V = decltype(make_tma_copy_B_sm90( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQKV{}), SmemLayoutV{}, TileShape_MNK{}, ClusterShape{})); // no mcast for KV using MainloopPipeline = typename cutlass::PipelineTmaAsync; using PipelineState = typename MainloopPipeline::PipelineState; using MainloopPipeline_dO = typename cutlass::PipelineTmaAsync; using PipelineState_dO = typename MainloopPipeline_dO::PipelineState; // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(take<0, 2>(SmemLayoutQ{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(SmemLayoutK{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(SmemLayoutV{}) 
        * cutlass::sizeof_bits_v / 8);
    static constexpr uint32_t TmaTransactionBytesLSE = static_cast(size(select<0>(SmemLayoutLSE{})) * cutlass::sizeof_bits_v / 8);

    // These are tuned for speed. They don't affect correctness.
    // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64
    // this helps quite a bit to not have to do causal masking for most of the iterations.
    // For hdim 192, separating masking iterations results in register spills.
    static constexpr bool SeparateMaskingIterations = kHeadDim <= 64;
    // Do we keep the LSE and dPsum in each thread, or split them across 8 threads that share them and then
    // shuffle to get the value whenever we need? This can reduce register pressure when SdP_swapAB, where each
    // thread needs to keep statistics for (kBlockM / 4) rows. If !SdP_swapAB, each thread only needs to keep
    // statistics for 2 rows.
    static constexpr bool ShuffleLSE = SdP_swapAB && kHeadDim <= 64;
    static constexpr bool ShuffledPsum = SdP_swapAB && kHeadDim <= 64;
    static constexpr bool dQacc_use_TMA = kHeadDim < 256;
    // For hdim256, we want to slice the dQ MMA (64 x 256 on 2 WGs) into two (64 x 128 on 2 WGs) so that we can
    // do atomic add on one half before doing the other half of the MMA, to reduce register pressure.
    static constexpr bool Slice_dQKV_Mma = kHeadDim == 256 && !dQacc_use_TMA && dQ_swapAB && AtomLayoutMdQ == 1 && NumMmaWarpGroups == 2;
    static_assert(!(Deterministic && Slice_dQKV_Mma), "Deterministic mode not supported with Slice_dQKV_Mma");

    static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{});
    static constexpr size_t SmemAlignmentdS = cutlass::detail::alignment_for_swizzle(SmemLayoutPdS{});
    // Without this SmemAlignment, with hdim 256 we get "misaligned address" error in TMA
    static constexpr size_t SmemAlignmentQKVdO = kHeadDim % 256 == 0 ? 256 : 128;
    static constexpr size_t SmemAlignmentV = !Mma_dP_is_RS ?
SmemAlignmentQKVdO : cutlass::detail::alignment_for_swizzle(SmemLayoutV{}); static_assert(SmemAlignmentP >= 128 && SmemAlignmentdS >= 128, "Require at least 128B alignment"); // TODO: do we have to worry that smem_dk and smem_dv in the epilogue don't line up w smem_k and smem_v due to alignment? using SmemdQacc_t = std::conditional_t, cute::array_aligned>>; using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; struct TensorStorage : cute::aligned_struct { cute::array_aligned, SmemAlignmentQKVdO> smem_k; cute::array_aligned, SmemAlignmentV> smem_v; SmemdQacc_t smem_dqacc; cute::array_aligned, SmemAlignmentQKVdO> smem_q; cute::array_aligned, SmemAlignmentQKVdO> smem_do; cute::array_aligned, 128> smem_lse; cute::array_aligned, 128> smem_dpsum; SmemP_t smem_p; cute::array_aligned, SmemAlignmentdS> smem_ds; }; // Host side kernel arguments struct Arguments { Element const* const ptr_Q; ShapeQKV const shape_Q; StrideQKV const stride_Q; Element const* const ptr_K; ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; StridedQaccum const stride_dQaccum; float const* const ptr_LSE_log2; ShapeLSE const shape_LSE; StrideLSE const stride_LSE_log2; float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale; int const window_size_left, window_size_right, attention_chunk; float const softcap_val; int const num_batch; int* const dq_semaphore; int const* const cu_seqlens_q = nullptr; int const* const cu_seqlens_k = nullptr; int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; }; // Device side kernel params struct Params { ShapeQKV const shape_Q; ShapeQKV const shape_K; ShapeQKV const shape_V; ShapeQKV const shape_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; 
StridedQaccum stride_dQaccum; cutlass::FastDivmod qhead_per_khead_divmod; TMA_QdO tma_load_Q, tma_load_dO; TMA_K tma_load_K; TMA_V tma_load_V; float const* const ptr_LSE_log2; ShapeLSE const shape_LSE; StrideLSE const stride_LSE_log2; float const* const ptr_dPsum; StrideLSE const stride_dPsum; float const softmax_scale, softmax_scale_log2; int const window_size_left, window_size_right; cutlass::FastDivmod attention_chunk_divmod; float const softcap_val; int const num_batch; int* const dq_semaphore; int const* const cu_seqlens_q = nullptr; int const* const cu_seqlens_k = nullptr; int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; }; static Params to_underlying_arguments(Arguments const& args) { Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); TMA_QdO tma_load_Q = make_tma_copy_A_sm90( GmemTiledCopyQdO{}, mQ, SmemLayoutQ{}(_, _, _0{}), TileShape_MNK{}, ClusterShape{}); // mcast along N mode for this M load, if any Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_dO, args.stride_dO); TMA_QdO tma_load_dO = make_tma_copy_A_sm90( GmemTiledCopyQdO{}, mdO, SmemLayoutdO{}(_, _, _0{}), TileShape_MNK{}, ClusterShape{}); // mcast along N mode for this M load, if any Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); TMA_K tma_load_K = make_tma_copy_B_sm90( GmemTiledCopyKV{}, mK, SmemLayoutK{}, TileShape_MNK{}, ClusterShape{}); // no mcast for KV Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_V, args.stride_V); TMA_V tma_load_V = make_tma_copy_B_sm90( GmemTiledCopyKV{}, mV, SmemLayoutV{}, TileShape_MNK{}, ClusterShape{}); // no mcast for KV if constexpr (Deterministic) { assert(args.dq_semaphore != nullptr); } // Avoid dividing by zero cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? 
            args.attention_chunk : 1);
        attention_chunk_divmod.divisor = args.attention_chunk;
        // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.
        // Right after this, we multiply by log2(e) before applying exp2.
        // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val
        // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)
        // (assigning it to params.softmax_scale_log2).
        // In the backward, we need to multiply by
        // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale.
        // Instead we multiply by (1 - tanh^2) and multiply dK by params.softmax_scale
        // (the original softmax_scale) at the end; dV = P^T dO needs no such factor.
        return {args.shape_Q, args.shape_K, args.shape_V, args.shape_dO,
                args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum,
                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
                tma_load_Q, tma_load_dO, tma_load_K, tma_load_V,
                args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,
                args.softmax_scale,
                !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),
                args.window_size_left, args.window_size_right, attention_chunk_divmod,
                !Has_softcap ?
                    0.f : args.softmax_scale / args.softcap_val,
                args.num_batch, args.dq_semaphore,
                args.cu_seqlens_q, args.cu_seqlens_k, args.seqused_q, args.seqused_k};
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_load_dO.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());
        cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());
    }

    template
    CUTLASS_DEVICE void
    load(Params const& params,
         MainloopPipeline pipeline_q,
         MainloopPipeline_dO pipeline_do,
         PipelineState& smem_pipe_write,
         PipelineState_dO& smem_pipe_write_do,
         SharedStorage &shared_storage,
         SchedulerPrefetch const& scheduler_prefetch,
         cute::tuple block_coord
         ) {

        auto [n_block, bidh, bidb] = block_coord;
        SeqlenInfo_t seqlen_info{
            bidb, get<0>(params.shape_Q), size<0>(params.shape_K),
            params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k
        };
        auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max(
            seqlen_info, n_block, bidb,
            params.window_size_left, params.window_size_right, 0 /*sink_token_length*/);
        // It's possible to have m_block_max <= m_block_min. Loading Q, K can cause illegal memory access.
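The softcap bookkeeping in `to_underlying_arguments` relies on two identities. First, folding softmax_scale / softcap_val and softcap_val * log2(e) into the stored parameters produces the same exp2 exponent as the unfused formula. Second, the derivative of the capped logit with respect to the raw score is (1 - tanh^2) * softmax_scale. The sketch below is plain host-side C++ in double precision, not the device code, and it checks both identities numerically. `SoftcapParams`, `make_softcap_params` and the other helpers are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <cmath>

// log2(e), spelled out because M_LOG2E is not part of standard C++.
constexpr double kLog2e = 1.4426950408889634;

// Hypothetical mirror of the two fused constants stored in Params when Has_softcap is true.
struct SoftcapParams {
    double softcap_val;         // softmax_scale / softcap
    double softmax_scale_log2;  // softcap * log2(e)
};

SoftcapParams make_softcap_params(double softmax_scale, double softcap) {
    return {softmax_scale / softcap, softcap * kLog2e};
}

// Fused form: tanh(s * softcap_val) * softmax_scale_log2, the exponent passed to exp2.
double fused_logit_log2(double s, const SoftcapParams& p) {
    return std::tanh(s * p.softcap_val) * p.softmax_scale_log2;
}

// Unfused form: tanh(s * softmax_scale / softcap) * softcap, then converted to base 2.
double reference_logit_log2(double s, double softmax_scale, double softcap) {
    return std::tanh(s * softmax_scale / softcap) * softcap * kLog2e;
}

// d/ds [tanh(s * softmax_scale / softcap) * softcap] = (1 - tanh^2) * softmax_scale.
double capped_logit_grad(double s, double softmax_scale, double softcap) {
    double const t = std::tanh(s * softmax_scale / softcap);
    return (1.0 - t * t) * softmax_scale;
}
```

Because the kernel multiplies dS only by (1 - tanh^2), the leftover softmax_scale factor still has to be applied once to dK at the end, as the comment above describes.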
if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { scheduler_prefetch(); return; } } Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSE{}); Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSE{}); int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); // Prepare the TMA loads uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? 
bidb : 0); Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) Tensor gdO = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mdO), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) Tensor gK = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) Tensor gV = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) Tensor gLSE = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mLSE), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) Tensor gdPsum = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded), mdPsum), select<0>(TileShape_MNK{}), make_coord(_)); // (M, _) Tensor sK_x = make_tensor(sK.data(), make_layout(sK.layout(), Layout<_1>{})); Tensor gK_x = make_tensor(gK.data(), make_layout(gK.layout(), Layout<_1>{})); Tensor sV_x = make_tensor(sV.data(), make_layout(sV.layout(), Layout<_1>{})); Tensor gV_x = make_tensor(gV.data(), make_layout(gV.layout(), Layout<_1>{})); // auto [tQgQ, tQsQ] = tma_partition(params.tma_load_Q, block_rank_in_cluster, Layout{}, // group_modes<0, 2>(sQ), group_modes<0, 2>(gQ)); // (TMA, k), (TMA, PIPE) // auto [tdOgdO, tdOsdO] = tma_partition(params.tma_load_dO, block_rank_in_cluster, Layout{}, // group_modes<0, 2>(sdO), group_modes<0, 2>(gdO)); // (TMA, k), (TMA, PIPE) auto block_tma_Q = params.tma_load_Q.get_slice(cluster_local_block_id.y); auto block_tma_dO = params.tma_load_dO.get_slice(cluster_local_block_id.y); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); Tensor tdOgdO = group_modes<0, 3>(block_tma_dO.partition_S(gdO)); Tensor tdOsdO = group_modes<0, 3>(block_tma_dO.partition_D(sdO)); auto [tKgK, tKsK] = 
tma_partition(params.tma_load_K, _0{}, Layout<_1>{}, group_modes<0, 2>(sK_x), group_modes<0, 2>(gK_x)); // (TMA), (TMA) auto [tVgV, tVsV] = tma_partition(params.tma_load_V, _0{}, Layout<_1>{}, group_modes<0, 2>(sV_x), group_modes<0, 2>(gV_x)); // (TMA), (TMA) auto bulk_copy = Copy_Traits{}; uint16_t mcast_mask_qdo = 0; if constexpr (cute::is_same_v) { auto block_layout = Layout{}; // (m,n) -> block_id for (int n = 0; n < size<1>(block_layout); ++n) { mcast_mask_qdo |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, _0{})); } } int m_block = m_block_min; int lane_predicate = cute::elect_one_sync(); if (lane_predicate) { pipeline_q.producer_acquire(smem_pipe_write); copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), tQgQ(_, m_block), tQsQ(_, smem_pipe_write.index())); copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)), gLSE(_, m_block), sLSE(_, smem_pipe_write.index())); } // // Wait for the MMA warpgroups to say that smem_k and smem_v are ready // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); if (lane_predicate) { // Copy K tile and V tile from GMEM to SMEM. shared_storage.pipelines.barrier_KV.arrive_and_expect_tx(TmaTransactionBytesK + TmaTransactionBytesV); copy(params.tma_load_K.with(reinterpret_cast(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tKgK, tKsK); copy(params.tma_load_V.with(reinterpret_cast(shared_storage.pipelines.barrier_KV), 0 /*mcast_mask*/), tVgV, tVsV); #pragma unroll (kHeadDim < 256 ? 
2 : 1) for (; m_block < m_block_max - 1; ++m_block) { // If Q and dO have the same number of stages, we can use the same pipeline state variable // to reduce registers PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return(smem_pipe_write, smem_pipe_write_do); pipeline_do.producer_acquire(smem_pipe_write_do_cur); copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index())); copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)), gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index())); if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; } ++smem_pipe_write; pipeline_q.producer_acquire(smem_pipe_write); copy(params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), tQgQ(_, m_block + 1), tQsQ(_, smem_pipe_write.index())); copy(bulk_copy.with(*pipeline_q.producer_get_barrier(smem_pipe_write)), gLSE(_, m_block + 1), sLSE(_, smem_pipe_write.index())); } } scheduler_prefetch(); if (lane_predicate) { PipelineState_dO smem_pipe_write_do_cur = cute::conditional_return(smem_pipe_write, smem_pipe_write_do); pipeline_do.producer_acquire(smem_pipe_write_do_cur); copy(params.tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur), mcast_mask_qdo, TMA::CacheHintSm90::EVICT_LAST), tdOgdO(_, m_block), tdOsdO(_, smem_pipe_write_do_cur.index())); copy(bulk_copy.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do_cur)), gdPsum(_, m_block), sdPsum(_, smem_pipe_write_do_cur.index())); if constexpr (!Q_dO_same_stages) { ++smem_pipe_write_do; } ++smem_pipe_write; } if constexpr (Q_dO_same_stages) { smem_pipe_write_do = smem_pipe_write; } } /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, 
              PipelineState& smem_pipe_write) {
        static_assert(Q_dO_same_stages, "Q and dO must have the same number of stages");
        // Need to copy since pipeline_q.producer_tail(smem_pipe_write) will increment smem_pipe_write
        PipelineState smem_pipe_write_do = smem_pipe_write;
        // Issue the epilogue waits
        if (cute::elect_one_sync()) {
            /* This helps avoid early exit of blocks in Cluster
             * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
             * then would just be acquired since the phase was still inverted from make_producer_start_state
             */
            pipeline_q.producer_tail(smem_pipe_write);
            pipeline_do.producer_tail(smem_pipe_write_do);
        }
    }

    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
    CUTLASS_DEVICE void
    load_tail(MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do,
              PipelineState& smem_pipe_write, PipelineState_dO& smem_pipe_write_do) {
        // Issue the epilogue waits
        if (cute::elect_one_sync()) {
            /* This helps avoid early exit of blocks in Cluster
             * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
             * then would just be acquired since the phase was still inverted from make_producer_start_state
             */
            pipeline_q.producer_tail(smem_pipe_write);
            pipeline_do.producer_tail(smem_pipe_write_do);
        }
    }

    template
    CUTLASS_DEVICE void
    store_dq(Params const& params,
             SharedStorage &shared_storage,
             cute::tuple block_coord
             ) {
        if constexpr (!dQacc_use_TMA) { return; }

        auto [n_block, bidh, bidb] = block_coord;
        SeqlenInfo_t seqlen_info{
            bidb, get<0>(params.shape_Q), size<0>(params.shape_K),
            params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k
        };
        auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max(
            seqlen_info, n_block, bidb, params.window_size_left,
            params.window_size_right, 0 /*sink_token_length*/);
        // It's possible to have m_block_max <= m_block_min. Exit early.
        // Though if local and deterministic, we still need to increment the dq semaphore.
        if constexpr ((Is_causal || Is_local || Varlen) && !(Is_local && Deterministic)) {
            if (m_block_max <= m_block_min) { return; }
        }

        Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{});
        static constexpr int dQ_TMA_num_bytes = CUTE_STATIC_V(size<0>(sdQ)) * sizeof(ElementAccum);

        bool const is_varlen = Varlen && params.cu_seqlens_q;
        Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)),
                                      params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
        Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_));  // (M * K, _)
        Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{});  // (M * K / WG, WG, _)

        int const num_batch = params.num_batch;
        int const num_head = get<2>(params.shape_Q);
        int *lock_ptr = !Deterministic ? nullptr : params.dq_semaphore + bidb * num_head + bidh;
        using Barrier = cutlass::GenericBarrier;
        bool const lane_predicate = cute::elect_one_sync();
        int m_block = m_block_min;
        constexpr int kBlockM = get<0>(TileShape_MNK{});
        constexpr int kBlockN = get<1>(TileShape_MNK{});
        int n_block_global_max = cute::ceil_div(seqlen_info.seqlen_k, kBlockN);
        #pragma unroll 2
        for (; m_block < m_block_max; ++m_block) {
            if constexpr (Deterministic) {
                if constexpr(Is_causal) {
                    int n_block_max_for_m_block = std::min(n_block_global_max, cute::ceil_div((m_block + 1) * kBlockM + seqlen_info.seqlen_k - seqlen_info.seqlen_q, kBlockN));
                    Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block_max_for_m_block - 1 - n_block);
                } else {
                    Barrier::wait_eq(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head, n_block);
                }
            }
            #pragma unroll
            for (int warpgroup_idx = 0; warpgroup_idx < NumMmaWarpGroups; ++warpgroup_idx) {
                cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warpgroup_idx /*id*/);  // sdQ full, to be written to gmem
                if (lane_predicate) {
                    SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdQ(_, warpgroup_idx).data()), raw_pointer_cast(gdQaccum(_, warpgroup_idx, m_block).data()), dQ_TMA_num_bytes, static_cast(TMA::CacheHintSm90::EVICT_LAST));
                    tma_store_arrive();
                }
            }
            // Note, the for_each() function is required here to ensure `warpgroup_idx` is of type Int.
            for_each(make_int_sequence{}, [&] (auto warpgroup_idx) {
                if (lane_predicate) { tma_store_wait(); }
                cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warpgroup_idx /*id*/);  // sdQ empty, ready to be written to
            });
            if constexpr (Deterministic) {
                Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
            }
        }
        if constexpr (Is_local && Deterministic) {
            int const m_block_global_max = cute::ceil_div(seqlen_info.seqlen_q, kBlockM);
            #pragma unroll 2
            for (; m_block < m_block_global_max; ++m_block) {
                Barrier::arrive_inc(lock_ptr, threadIdx.x % cutlass::NumThreadsPerWarp, m_block * num_batch * num_head);
            }
        }
    }

    CUTLASS_DEVICE void
    mma_init() {
        // We're not currently using this because we're not using a persistent scheduler
        // // Tell producer (warp 0) that smem_k and smem_v are ready
        // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/);
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        if constexpr (dQacc_use_TMA) {
            if (warp_idx_in_warpgroup == 0) {
                cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);  // sdQ empty, ready to be written to
            }
        }
    }

    template
    CUTLASS_DEVICE bool
mma(Params const& params, MainloopPipeline pipeline_q, MainloopPipeline_dO pipeline_do, PipelineState& smem_pipe_read, PipelineState_dO& smem_pipe_read_do, FrgTensordKV& tdKrdK, FrgTensordKV& tdVrdV, int thread_idx, int &work_idx, cute::tuple block_coord, SharedStorage& shared_storage ) { static_assert(is_rmem::value, "dK and dV tensor must be rmem resident."); int n_block = get<0>(block_coord); int bidb = get<2>(block_coord); SeqlenInfo_t seqlen_info{ bidb, get<0>(params.shape_Q), size<0>(params.shape_K), params.cu_seqlens_q, params.cu_seqlens_k, params.seqused_q, params.seqused_k }; auto [m_block_min, m_block_max] = BlockMN_t::get_m_block_min_max( seqlen_info, n_block, bidb, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/); // It's possible to have m_block_max <= m_block_min. Exit early if constexpr (Is_causal || Is_local || Varlen) { if (m_block_max <= m_block_min) { return false; } } Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); Tensor sdO = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdO{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); Tensor sQt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQt{}); Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_do.data()), SmemLayoutdOt{}); Tensor sKt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutKt{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdS{}); Tensor sP_pi = cute::as_position_independent_swizzle_tensor(sP); Tensor sPt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutPdSt{}); Tensor sPt_pi = cute::as_position_independent_swizzle_tensor(sPt); 
Tensor sdS = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdS{}); Tensor sdS_pi = cute::as_position_independent_swizzle_tensor(sdS); Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_ds.data()), SmemLayoutPdSt{}); Tensor sdSt_pi = cute::as_position_independent_swizzle_tensor(sdSt); Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dqacc.data()), SmemLayoutdQaccum{}); Tensor sLSEMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_lse.data()), SmemLayoutLSEMma{}); Tensor sdPsumMma = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_dpsum.data()), SmemLayoutLSEMma{}); static_assert(stride<0>(typename TiledMmaSdP::ALayout{}) == 0 and stride<0>(typename TiledMmaSdP::BLayout{}) == 0 and size<0>(typename TiledMmaSdP::ALayout{}) == cutlass::NumThreadsPerWarpGroup and size<0>(typename TiledMmaSdP::BLayout{}) == cutlass::NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); constexpr int MmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); Layout warp_group_thread_layout_dq = make_layout(make_shape(Int{}), make_stride(Int{})); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); TiledMmaSdP tiled_mma_SdP; using TiledMmadP = std::conditional_t; TiledMmadP tiled_mma_dP; TiledMmadKV tiled_mma_dKV; TiledMmadQ tiled_mma_dQ; auto wg_mma_SdP = tiled_mma_SdP.get_slice(warp_group_thread_layout(warp_group_idx)); auto wg_mma_dP = tiled_mma_dP.get_slice(warp_group_thread_layout(warp_group_idx)); auto thread_mma_SdP = tiled_mma_SdP.get_thread_slice(thread_idx); auto wg_mma_dKV = tiled_mma_dKV.get_slice(warp_group_thread_layout(warp_group_idx)); auto wg_mma_dQ = tiled_mma_dQ.get_slice(warp_group_thread_layout_dq(warp_group_idx)); auto smem_tiled_copy_PdS = 
make_tiled_copy_C(SmemCopyAtomPdS{}, tiled_mma_SdP); auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(thread_idx); R2STiledCopydQaccum r2s_tiled_copy_dQaccum; auto r2s_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_thread_slice(thread_idx); Tensor tdQsdQaccum = r2s_thr_copy_dQaccum.partition_D(sdQ); // if (thread_idx == 0) { print(sdQ); printf("\n"); print(tdQsdQaccum); printf("\n"); } // Allocate "fragments/descriptors" // We have to use the templated mma_partition_fragment_AB instead of cute::conditional_return or lambda, // because some partition_fragment_A/B don't compile. // https://stackoverflow.com/questions/50051473/if-constexpr-in-c17-does-not-work-in-a-non-templated-function Tensor tSrQ = mma_partition_fragment_AB(wg_mma_SdP, sQ); Tensor tSrK = mma_partition_fragment_AB(wg_mma_SdP, sK); Tensor tdPrdO = mma_partition_fragment_AB(wg_mma_SdP, sdO); Tensor tdPrV = mma_partition_fragment_AB(wg_mma_dP, sV); Tensor tdVrdO = mma_partition_fragment_AB(wg_mma_dKV, sdOt); Tensor tdKrQ = mma_partition_fragment_AB(wg_mma_dKV, sQt); Tensor tdQrdS = mma_partition_fragment_AB(wg_mma_dQ, sdS); Tensor tdQrK = mma_partition_fragment_AB(wg_mma_dQ, sKt); Tensor tPsP = smem_thr_copy_PdS.partition_D(cute::conditional_return(sP_pi, sPt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N) Tensor tdSsdS = smem_thr_copy_PdS.partition_D(cute::conditional_return(sdS_pi, sdSt_pi)); // ((Atom,AtomNum),PIPE_M,PIPE_N) // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_PdS); print(sP_pi); printf("\n"); print(sPt_pi); printf("\n"); print(tPsP); printf("\n"); print(tdSsdS); printf("\n"); } // thread_mma_SdP.partition_C(sLSEMma) has shape ((2, 2, V), MMA_M, MMA_N, PIPE), we only take the col indices // or row indices, depending on whether SdP_swapAB. 
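ShuffleLSE and ShuffledPsum split the row statistics across the 8 threads that share the same rows. Under this scheme, lane `l` keeps only the entries `(l % 32) / 4 + 8 * i`. It then fetches row `mi` with `__shfl_sync(0xffffffff, reg[mi / 8], (mi % 8) * 4 + l % 4)`. The host-side emulation below is plain C++. The 32-lane warp, the `LaneRegs` alias and both helper names are assumptions made for illustration. It checks that this index arithmetic hands statistic `mi` back to every lane.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr int kWarpSize = 32;

// Registers of a whole warp: regs[lane][slot] (hypothetical host-side model).
template <std::size_t N>
using LaneRegs = std::array<std::array<float, N>, kWarpSize>;

// Each lane keeps stats[(lane % 32) / 4 + i * 8] in register slot i,
// mirroring the load loop of the shuffled LSE / dPsum path.
template <std::size_t N>
LaneRegs<N> distribute_stats(const float* stats) {
    LaneRegs<N> regs{};
    for (int lane = 0; lane < kWarpSize; ++lane) {
        for (std::size_t i = 0; i < N; ++i) {
            regs[lane][i] = stats[(lane % 32) / 4 + i * 8];
        }
    }
    return regs;
}

// Emulates __shfl_sync(0xffffffff, reg[mi / 8], (mi % 8) * 4 + lane % 4):
// the calling lane reads slot mi / 8 of the source lane.
template <std::size_t N>
float fetch_stat(const LaneRegs<N>& regs, int lane, int mi) {
    int const src_lane = (mi % 8) * 4 + lane % 4;
    return regs[src_lane][mi / 8];
}
```

The source lane satisfies `src_lane / 4 == mi % 8`, so slot `mi / 8` holds entry `(mi % 8) + 8 * (mi / 8) == mi`. Each thread therefore stores roughly one eighth of the statistics and pays one shuffle per row.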
Tensor tLSEsLSE = cute::conditional_return( group_modes<0, 2>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), // (2, MMA_M, PIPE) group_modes<0, 3>(thread_mma_SdP.partition_C(sLSEMma)(make_coord(_, _0{}, _), _0{}, _, _))); // (2, V, MMA_N, PIPE) Tensor tLSEsdPsum = cute::conditional_return( group_modes<0, 2>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_0{}, _, _0{}), _, _0{}, _)), group_modes<0, 3>(thread_mma_SdP.partition_C(sdPsumMma)(make_coord(_, _0{}, _), _0{}, _, _))); // if (blockIdx.x == 0 && threadIdx.x == 128) { print(sLSEMma); printf("\n"); print(tLSEsLSE); printf("\n"); } // If we want to split the stats among the 8 threads that share the same rows. static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(tLSEsLSE))::value, 8); auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); }; int bidh = get<1>(block_coord); int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; // For the case where we do atomicAdd directly to gdQaccum instead of using TMA bool const is_varlen = Varlen && params.cu_seqlens_q; Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? 
bidb : 0); Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) // We can reuse r2s_thr_copy_dQaccum for this partitioning Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum); // if (blockIdx.x == 0 && threadIdx.x == 128) { print(mdQaccum); printf("\n"); print(gdQaccum_); printf("\n"); print(gdQaccum); printf("\n"); print(tdQgdQaccum); printf("\n"); } flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, params.attention_chunk_divmod, params.qhead_per_khead_divmod ); int m_block = m_block_min; clear(tdKrdK); clear(tdVrdV); // tiled_mma_dKV.accumulate_ = GMMA::ScaleOut::Zero; cutlass::ConsumerToken barrier_token = static_cast(shared_storage.pipelines.barrier_KV.try_wait(work_idx % 2)); if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.pipelines.barrier_KV.wait(work_idx % 2); } if constexpr (Mma_dP_is_RS) { using SmemCopyAtomV = Copy_Atom; auto smem_tiled_copy_V = make_tiled_copy_A(SmemCopyAtomV{}, tiled_mma_dP); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx); Tensor tdPrV_copy_view = smem_thr_copy_V.retile_D(tdPrV); Tensor tdPsV_copy_view = smem_thr_copy_V.partition_S(cute::as_position_independent_swizzle_tensor(sV)); cute::copy(smem_tiled_copy_V, tdPsV_copy_view, tdPrV_copy_view); } auto bwd_step = [&](int m_block, auto mask_fn) { Tensor tSrS = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{})); consumer_wait(pipeline_q, smem_pipe_read); flash::gemm(tiled_mma_SdP, tSrQ(_, _, _, smem_pipe_read.index()), tSrK, tSrS); Tensor tLSErLSE = cute::conditional_return(make_fragment_like(tLSEsLSE(_, _0{})), make_tensor(Int{})); if constexpr (!ShuffleLSE) { cute::copy(tLSEsLSE(_, smem_pipe_read.index()), tLSErLSE); } else { #pragma unroll for (int i = 0; i < 
kStatsPerThread; ++i) {
                    // It's ok to read OOB, since we made sure sLSE is large enough and we won't use the OOB values
                    tLSErLSE(i) = tLSEsLSE((thread_idx % 32) / 4 + i * 8, smem_pipe_read.index());
                }
            }
            Tensor tdPrdP = partition_fragment_C(tiled_mma_SdP, select(TileShape_MNK{}));
            PipelineState_dO smem_pipe_read_do_cur = cute::conditional_return(smem_pipe_read, smem_pipe_read_do);
            consumer_wait(pipeline_do, smem_pipe_read_do_cur);
            flash::gemm(tiled_mma_dP, tdPrdO(_, _, _, smem_pipe_read_do_cur.index()), tdPrV, tdPrdP);
            warpgroup_wait<1>();
            if constexpr (Has_softcap) { flash::apply_softcap(tSrS, params.softcap_val); }

            // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
            Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout()));
            // dtanh needs to happen before masking, otherwise we get 1 - (-inf)^2 = NaN in the dtanh
            auto dtanh = [&] { if constexpr (Has_softcap) return flash::calculate_dtanh(scores); else return nullptr; }();
            mask_fn(tSrS, m_block);
            #pragma unroll
            for (int mi = 0; mi < size<0>(scores); ++mi) {
                float const lse_scaled = [&] {
                    if constexpr (!ShuffleLSE) return tLSErLSE(mi);
                    else return __shfl_sync(0xffffffff, tLSErLSE(mi / 8), (mi % 8) * 4 + (thread_idx % 4));
                }();
                #pragma unroll
                for (int ni = 0; ni < size<1>(scores); ++ni) {
                    scores(mi, ni) = exp2f(scores(mi, ni) * params.softmax_scale_log2 - lse_scaled);
                }
            }

            Tensor tLSErdPsum = cute::conditional_return(make_fragment_like(tLSEsdPsum(_, _0{})), make_tensor(Int{}));
            if constexpr (!ShuffledPsum) {
                cute::copy(tLSEsdPsum(_, smem_pipe_read_do_cur.index()), tLSErdPsum);
            } else {
                #pragma unroll
                for (int i = 0; i < kStatsPerThread; ++i) {
                    tLSErdPsum(i) = tLSEsdPsum((thread_idx % 32) / 4 + i * 8, smem_pipe_read_do_cur.index());
                }
            }

            warpgroup_wait<0>();
            // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N))
            Tensor dS = make_tensor(tdPrdP.data(), scores.layout());
            #pragma unroll
            for (int mi = 0; mi < size<0>(dS); ++mi) {
                float const dP_sum_cur = [&] {
                    if constexpr (!ShuffledPsum) return tLSErdPsum(mi);
                    else return __shfl_sync(0xffffffff, tLSErdPsum(mi / 8), (mi % 8) * 4 + (thread_idx % 4));
                }();
                #pragma unroll
                for (int ni = 0; ni < size<1>(dS); ++ni) {
                    dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur);
                    if constexpr (Has_softcap) { dS(mi, ni) *= dtanh(mi, ni); }
                }
            }

            // Convert scores from fp32 to fp16/bf16
            Tensor rP = make_tensor_like(tSrS);
            flash::convert_type_out(tSrS, rP);
            if constexpr (!Mma_dKV_is_RS) {
                // Need to sync to make sure P has already been used in the previous iteration before writing new values
                if constexpr (kStages_dS == 1) {
                    cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/);
                }
                Tensor tPaP = smem_thr_copy_PdS.retile_S(rP);  // ((Atom,AtomNum), MMA_N, MMA_N)
                cute::copy(smem_tiled_copy_PdS, tPaP, tPsP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())));
            }
            Tensor rdS = make_tensor_like(tdPrdP);
            flash::convert_type_out(tdPrdP, rdS);
            // If there's double buffering on dS, we don't need to sync here.
            // Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
            // But because both WGs have to sync at the end of the loop and double buffering,
            // this race condition is not possible.
            // This sync is to ensure (1) P is written in case of !Mma_dKV_is_RS and
            // (2) dS is already read by the Mma in the previous iteration in case of Mma_dKV_is_RS.
if constexpr (!Mma_dKV_is_RS || (kStages_dS == 1 && Mma_dKV_is_RS)) { cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); } // For hdim 64, it's faster to write to smem_dS first before the dV gemm Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index()))); if constexpr (!Slice_dQKV_Mma) { // Most cases take this path, except for hdim256 where we want to slice to reduce register pressure if constexpr (Mma_dKV_is_RS) { Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); flash::gemm(tiled_mma_dKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); } else { Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); } // SMEM fence to make sure sdS is written before it's read by WGMMA cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO if constexpr (Mma_dKV_is_RS) { Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); flash::gemm(tiled_mma_dKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); } else { Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _,
smem_pipe_read.index()), tdKrdK); } if constexpr (dQacc_use_TMA) { int const warp_group_idx = flash::canonical_warp_group_idx_nosync() - 1; cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQEmptyWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem Tensor taccdQrdQ = r2s_thr_copy_dQaccum.retile_S(tdQrdQ); cute::copy(r2s_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum); cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::dQFullWG1) + warp_group_idx /*id*/); // sdQ full, to be written to gmem } else { // We can reuse r2s_thr_copy_dQaccum for this partitioning Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); static_assert(CUTE_STATIC_V(size(tdQrdQ_atomic)) == CUTE_STATIC_V(size(tdQgdQaccum_atomic))); #pragma unroll for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } } } else { // Slice_dQKV_Mma static_assert(!(Slice_dQKV_Mma && Mma_dKV_is_RS)); Tensor tdVrP = mma_partition_fragment_AB(wg_mma_dKV, sPt); Tensor tdVrP_cur = tdVrP(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); cutlass::arch::fence_view_async_shared(); cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(BwdNamedBarriers::PdS) /*id*/); Tensor tdQrdQ = partition_fragment_C(tiled_mma_dQ, select(TileShape_MNK{})); Tensor tdQrdS_cur = tdQrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); flash::gemm(tiled_mma_dKV, tdVrP_cur, tdVrdO(_, _, _, smem_pipe_read_do_cur.index()), tdVrdV); Tensor tdQrdQ_atomic = recast(r2s_thr_copy_dQaccum.retile_S(tdQrdQ)); Tensor tdQgdQaccum_atomic = 
recast(tdQgdQaccum(_, _, _, m_block)); #pragma unroll for (int i = 0; i < size(tdQrdQ_atomic) / 2; ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } Tensor tdKrdS = mma_partition_fragment_AB(wg_mma_dKV, sdSt); Tensor tdKrdS_cur = tdKrdS(_, _, _, cute::conditional_return(_0{}, smem_pipe_read.index())); flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); pipeline_do.consumer_release(smem_pipe_read_do_cur); // release dO flash::gemm(tiled_mma_dQ, tdQrdS_cur, tdQrK, tdQrdQ); #pragma unroll for (int i = size(tdQrdQ_atomic) / 2; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } flash::gemm(tiled_mma_dKV, tdKrdS_cur, tdKrQ(_, _, _, smem_pipe_read.index()), tdKrdK); } warpgroup_wait<0>(); pipeline_q.consumer_release(smem_pipe_read); // release Q ++smem_pipe_read; if constexpr (!Q_dO_same_stages) { ++smem_pipe_read_do; } }; // We have separate iterations with causal masking. Not necessary for hdim 128 but for hdim 64 // this helps quite a bit to not have to do causal masking for most of the iterations. if constexpr ((Is_causal || Is_local) && SeparateMaskingIterations) { auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; static constexpr int kBlockM = get<0>(TileShape_MNK{}); int const m_block_masking_max = ((n_block + 1) * kBlockN - 1 + seqlen_q - seqlen_k - params.window_size_right) / kBlockM + 1; CUTLASS_PRAGMA_NO_UNROLL for (; m_block < std::min(m_block_max, m_block_masking_max); ++m_block) { bwd_step(m_block, mask_fn); } } static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); int const m_block_max_before_local_mask = !Is_local || !SeparateMaskingIterations ? 
m_block_max : std::min(m_block_max, (n_block * kBlockN + seqlen_q - seqlen_k + params.window_size_left) / kBlockM); auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; CUTLASS_PRAGMA_NO_UNROLL for (; m_block < m_block_max_before_local_mask; ++m_block) { bwd_step(m_block, mask_fn); } if constexpr (Is_local && SeparateMaskingIterations) { auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply(tSrS, m_block, n_block); }; CUTLASS_PRAGMA_NO_UNROLL for (; m_block < m_block_max; ++m_block) { bwd_step(m_block, mask_fn); } } // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(tdVrdV); } #pragma unroll for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.softmax_scale; } if constexpr (Q_dO_same_stages) { smem_pipe_read_do = smem_pipe_read; } ++work_idx; return true; } }; } // namespace flash ================================================ FILE: hopper/mainloop_fwd_sm80.hpp ================================================ /****************************************************************************** * Copyright (c) 2024, Tri Dao. 
******************************************************************************/ #pragma once #include #include #include #include #include "cute/tensor.hpp" #include "seqlen.h" #include "block.h" #include "mask.h" #include "pack_gqa.h" #include "paged_kv.h" #include "rotary.h" #include "utils.h" namespace flash { using namespace cute; template struct CollectiveMainloopFwdSm80 { static constexpr int kStages = Stages; static_assert(kStages > 0, "kStages must be greater than 0"); using TileShape_MNK = TileShape_MNK_; using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; static constexpr bool Is_causal = Is_causal_; static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = Has_softcap_; static constexpr bool Varlen = Varlen_; static constexpr bool PagedKV = PagedKV_; static constexpr bool AppendKV = AppendKV_; static constexpr bool PackGQA = PackGQA_; static constexpr bool Split = Split_; static constexpr bool Transpose_V = Is_FP8; static_assert(ArchTag::kMinComputeCapability >= 80); static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80; static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); using SeqlenInfo_t = flash::SeqlenInfoQKNewK; using BlockMN_t = flash::BlockMN; using MMA_Atom_Arch = std::conditional_t< ArchTag::kMinComputeCapability >= 80, std::conditional_t< std::is_same_v, MMA_Atom, MMA_Atom >, MMA_Atom >; using TiledMma = TiledMMA< MMA_Atom_Arch, Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group Tile, _16, _16>>; static constexpr int NumMmaThreads = size(TiledMma{}); static constexpr int NumProducerThreads = NumMmaThreads; // For compatibility with TileScheduler static constexpr int kGmemElemsPerLoad =
sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction. static constexpr int kBytePerRow = kHeadDim * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); using SmemLayoutAtomQKV = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQKV{}, select<0, 2>(TileShape_MNK{}))); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomQKV{}, make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutV = decltype(tile_to_shape( SmemLayoutAtomQKV{}, make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutVt = decltype( composition(SmemLayoutV{}, make_ordered_layout(make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), Step<_2, _1, _3>{}))); using SmemCopyAtom = Copy_Atom; using SmemCopyAtomTransposed = Copy_Atom; // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading // from the same address by the same threadblock. This is slightly faster.
using GmemCopyAtom = Copy_Atom, AutoVectorizingCopyWithAssumedAlignment<128> >, Element>; static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; using GmemTiledCopyQKV = decltype( make_tiled_copy(GmemCopyAtom{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per read // So that we don't have to check if we overshot kBlockM when we load Q static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0); // For AppendKV, we want each thread to have at least 2 loads in the K direction since in the case of // non-interleaved rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), // each thread will load twice from the same row. static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); static constexpr int kBlockKGmemAppend = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRowAppend = kBlockKGmemAppend / kGmemElemsPerLoad; static_assert(NumMmaThreads % kGmemThreadsPerRowAppend == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRowAppend"); // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.
static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRowAppend == 0, "kGmemThreadsPerRowAppend must divide NumThreadsPerWarp"); using GmemLayoutAtomAppend = Layout, Int>, Stride, _1>>; // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtomAppend{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRowAppend"); using GmemTiledCopyAppendKV = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomAppend{}, Layout>>{})); // Val layout, 8 or 16 vals per store using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) using StrideQK = cute::Stride; using StrideV = StrideQK; // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; using StrideQPacked = std::conditional_t, _1, int64_t, int64_t>>; using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) using StridePageTable = cute::Stride; using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) using StrideRotary = cute::Stride; using StrideDescale = cute::Stride; static constexpr bool Share_QV_Smem = Q_in_regs; struct TensorStorageSharedQV : cute::aligned_struct<128> { union { cute::array_aligned> smem_v; cute::array_aligned> smem_q; }; cute::array_aligned> smem_k; }; struct TensorStorageSeparateQV : cute::aligned_struct<128> { cute::array_aligned> smem_v; cute::array_aligned> smem_k; cute::array_aligned> smem_q; }; using TensorStorage = std::conditional_t; // Host side kernel arguments struct Arguments { Element const* const ptr_Q; ShapeQKV const shape_Q; StrideQK const stride_Q; Element* const ptr_K; // Not Element const* since we might append to KV cache in-place ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; StrideQK const stride_K_new; Element 
const* const ptr_V_new; StrideV const stride_V_new; Element const* const ptr_Qv; StrideQK const stride_Qv; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; Element const* const ptr_rotary_sin; StrideRotary const stride_rotary_sin; bool const is_rotary_interleaved; int const* const ptr_pagetable; ShapePageTable const shape_pagetable; StridePageTable const stride_pagetable; float const softmax_scale; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; int const window_size_left = -1, window_size_right = -1, attention_chunk = 0; float const softcap_val; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; int const* const cu_seqlens_k = nullptr; int const* const cu_seqlens_k_new = nullptr; int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; int const* const seqlens_rotary = nullptr; }; // Device side kernel params struct Params { Element const* const ptr_Q; ShapeQKV const shape_Q; StrideQK const stride_Q; ShapeQPacked const shape_Q_packed; StrideQPacked const stride_Q_packed; Element* const ptr_K; ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; Element const* const ptr_rotary_sin; StrideRotary const stride_rotary_sin; bool const is_rotary_interleaved; int const* const ptr_pagetable; ShapePageTable const shape_pagetable; StridePageTable const stride_pagetable; cutlass::FastDivmod page_size_divmod; cutlass::FastDivmod qhead_per_khead_divmod; float const softmax_scale_log2; float const* ptr_q_descale, 
*ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; float const softcap_val; int const window_size_left, window_size_right; cutlass::FastDivmod attention_chunk_divmod; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; int const* const cu_seqlens_k = nullptr; int const* const cu_seqlens_k_new = nullptr; int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; int const* const seqlens_rotary = nullptr; }; static Params to_underlying_arguments(Arguments const& args) { // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); auto const shape_Q_packed = cute::conditional_return( args.shape_Q, make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)) ); auto const stride_Q_packed = cute::conditional_return( args.stride_Q, make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) ); if (get<1>(args.shape_rotary) > 0) { assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); } assert(args.num_splits >= 1); // Avoid dividing by zero cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1); attention_chunk_divmod.divisor = args.attention_chunk; // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). 
return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, cutlass::FastDivmod(int(get<0>(args.shape_K))), cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val, args.window_size_left, args.window_size_right, attention_chunk_divmod, !Split ? 1 : args.num_splits, args.kv_batch_idx, args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new, args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary}; } template CUTLASS_DEVICE bool mma(Params const& params, FrgTensorO& tOrO, Softmax& softmax, int const thread_idx, SeqlenInfo_t const& seqlen_info, cute::tuple block_coord, SharedStorage& shared_storage ) { static_assert(is_rmem::value, "O tensor must be rmem resident."); static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda int const m_block = get<0>(block_coord); int const bidh = get<1>(block_coord); int const bidb = get<2>(block_coord); int const split_idx = get<3>(block_coord); int const bidh_kv = !PackGQA ? 
params.qhead_per_khead_divmod.divide(bidh) : bidh; auto n_block_min_max = BlockMN_t::get_n_block_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, params.window_size_left, params.window_size_right, params.attention_chunk_divmod, params.qhead_per_khead_divmod); int const n_block_min = get<0>(n_block_min_max); int const n_block_max = get<1>(n_block_min_max); // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { return false; } } Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); Tensor sVt = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K + seqlen_info.offset_k * get<0>(params.stride_K)), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V + seqlen_info.offset_k * get<0>(params.stride_V)), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0); Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(thread_idx); auto gmem_thr0_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(_0{}); // For index calculation Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); TiledMma tiled_mma; auto thr_mma = tiled_mma.get_slice(thread_idx); // Allocate "fragments/descriptors" Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // Copy Atom retiling auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(thread_idx); auto smem_tiled_copy_V = make_tiled_copy_B(SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); // Predicates Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); #pragma unroll for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; int n_block = n_block_max - 1; // Prologue: load Q, K, V // If persistent, we don't need to wait for the previous work_idx to finish // since we assume that all MMA threads sync in the epilogue before 
writing to smem_o. // So if any thread gets there, all threads must have finished the previous MMA and at least started // writing to smem_o. // If persistent, we need to sync to make sure all threads have finished with smem_o before writing to smem_v if constexpr (Share_QV_Smem) { __syncthreads(); } if constexpr (!PackGQA) { Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); Tensor cQ = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); Tensor t0QcQ = gmem_thr0_copy_QKV.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); #pragma unroll for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); } // Instead of passing in tQcQ, we pass in t0QcQ and subtract the offset from the limit // (seqlen_q - m_block * kBlockM). This is because the entries of t0QcQ are known at compile time. // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs flash::copy( gmem_tiled_copy_QKV, tQgQ, tQsQ, t0QcQ, tQpQ, seqlen_info.seqlen_q - m_block * kBlockM - get<0>(tQcQ(_0{}, _0{}, _0{})) ); } else { using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element>; PackGQAt::load_Q(mQ, sQ, params.qhead_per_khead_divmod, thread_idx, seqlen_q, m_block); } cute::cp_async_fence(); using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, 0 /*bidb_kv_idx, not used since we don't
use TMA for Sm8x*/ ); auto load_K = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; if constexpr (!PagedKV) { // Do we need a bound check to make sure the row doesn't go above kBlockN static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_write); // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time. int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN ? seqlen_info.seqlen_k - n_block * kBlockN : (!Seqlenk_mask ? kBlockN : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN))); // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. flash::copy( gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK_cur, t0KVcKV, tKVpKV, seqlenk_row_limit); } else { paged_kv_manager.template load_page_table(n_block); paged_kv_manager.template load_K(n_block, sK(_, _, smem_pipe_write)); } }; auto load_V = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; if constexpr (!PagedKV) { // Do we need a bound check to make sure the row doesn't go above kBlockN static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_write); // We don't call flash::copy since it doesn't support bound checking // to avoid overshooting kBlockN when writing to smem.
Tensor tVgV_cur = tVgV(_, _, _, n_block); int const seqlenk_row_limit = seqlen_info.seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{})); #pragma unroll for (int m = 0; m < size<1>(tVsV); ++m) { // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKVcKV(_0{}, m, _0{})) < kBlockN) { bool const predicate_n = !Seqlenk_mask || get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; #pragma unroll for (int k = 0; k < size<2>(tVsV); ++k) { cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV_cur(_, m, k), tVsV_cur(_, m, k)); } } } } else { paged_kv_manager.template load_V(n_block, sV(_, _, smem_pipe_write)); } }; auto preprocess_Q = [&] { if constexpr (!AppendKV) { flash::cp_async_wait(); } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_q, seqlen_info.seqlen_rotary); int const qhead_per_khead = !PackGQA ? 
1 : params.qhead_per_khead_divmod.divisor; if (params.is_rotary_interleaved) { auto [tRrCos, tRrSin] = cute::conditional_return( rotary.template load_cos_sin(m_block), rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) ); flash::cp_async_wait(); __syncthreads(); rotary.apply_Q_interleaved(sQ, tRrCos, tRrSin, m_block, qhead_per_khead); } else { auto [tRrCosCont, tRrSinCont] = cute::conditional_return( rotary.template load_cos_sin(m_block), rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod) ); flash::cp_async_wait(); __syncthreads(); rotary.apply_Q_contiguous(sQ, tRrCosCont, tRrSinCont, m_block, qhead_per_khead); } } else { flash::cp_async_wait(); } } if constexpr (Q_in_regs) { __syncthreads(); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(sQ); cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view); } }; // If Share_QV_Smem, we load Q, then load 1 stage of K, then (optionally) rotate Q and // read from smem_q to registers, then load V. // If !Share_QV_Smem, we load Q, load all stages of K & V, then (optionally) rotate Q. if constexpr (Share_QV_Smem) { load_K(n_block, 0, cute::true_type{} /*Seqlenk_mask*/); cute::cp_async_fence(); preprocess_Q(); __syncthreads(); // Make sure all threads have read smem_q before loading V } // For persistent, make sure all threads have finished reading smem_o if constexpr (!Share_QV_Smem) { __syncthreads(); } // Note: we use the for_each() function here to ensure `stage` is of type Int.
for_each(make_int_sequence{}, [&] (auto stage) { static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; if constexpr (!Share_QV_Smem || !Is_first_stage) { if (Is_first_stage || n_block - stage >= n_block_min) { load_K(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); } // We want the fence outside the if statement to have a fixed number of cp.async commits, // so that we can wait with the correct number of outstanding commits. cute::cp_async_fence(); } if constexpr (!Is_last_stage) { if (Is_first_stage || n_block - stage >= n_block_min) { load_V(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); } cute::cp_async_fence(); } }); if constexpr (!Share_QV_Smem) { preprocess_Q(); } flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, params.attention_chunk_divmod, params.qhead_per_khead_divmod ); float softcap_val = params.softcap_val; if constexpr (Has_softcap && Is_FP8) { float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)]; float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)]; softcap_val *= q_descale * k_descale; } // Softcapping needs to happen before masking since if we apply after masking, softcapping can turn // -inf to e.g. -50.0, which can affect the attention softmax. auto scoremod_premask_fn = [&](auto& tSrS) { if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); } }; int smem_pipe_read = 0, smem_pipe_write = kStages - 1; auto load_K_next = [&] { if (n_block - kStages >= n_block_min) { load_K(n_block - kStages, kStages > 1 ?
smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/); } cute::cp_async_fence(); }; auto sync = [&] { flash::cp_async_wait(); __syncthreads(); }; clear(tOrO); auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; static constexpr bool Check_inf = decltype(check_inf_type)::value; Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); clear(tSrS); sync(); auto load_V_next = [&] { if (n_block - kStages + 1 >= n_block_min) { load_V(n_block - kStages + 1, kStages > 1 ? smem_pipe_write : 0, cute::bool_constant{} /*Seqlenk_mask*/); } cute::cp_async_fence(); }; Tensor tSrQ_cur = cute::conditional_return(tSrQ, thr_mma.partition_fragment_A(sQ)); Tensor tSrK = thr_mma.partition_fragment_B(sK(_, _, _0{})); flash::gemm_sm80( tSrS, tSrQ_cur, tSrK, tSsQ, tSsK(_, _, _, kStages > 1 ? smem_pipe_read : 0), tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K, load_V_next ); smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; scoremod_premask_fn(tSrS); // Faster to load_K before gemm if we only have 1 stage if constexpr (kStages == 1) { sync(); load_K_next(); } mask_fn(tSrS, n_block); Tensor scores_scale = softmax.template max_get_scale(tSrS); softmax.template online_softmax(tSrS); if constexpr (Is_FP8) { flash::permute_Cregs_fp8(tSrS); } Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } if constexpr (kStages > 1) { sync(); } Tensor tOrV = thr_mma.partition_fragment_B(sVt(_, _, _0{})); flash::gemm_rs_sm80(tOrO, tOrP, tOrV, tOsVt(_, _, _, kStages > 1 ? 
smem_pipe_read : 0), tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); if constexpr (kStages > 1) { load_K_next(); } smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0; }; auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); --n_block; if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask( seqlen_info, m_block, n_block_min, params.window_size_right, params.attention_chunk_divmod, params.qhead_per_khead_divmod); #pragma unroll 1 for (; n_block >= n_block_min_causal_local_mask; --n_block) { fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/); } } int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask( seqlen_info, m_block, n_block_min, params.window_size_left, params.attention_chunk_divmod, params.qhead_per_khead_divmod); auto no_mask_fn = [](auto& tSrS, int n_block) { }; #pragma unroll 1 for (; n_block >= n_block_min_before_local_mask; --n_block) { fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/); } // Separate masking iterations on the left for local attention if constexpr (Is_local) { auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; #pragma unroll 1 for (; n_block >= n_block_min; --n_block) { fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/); } } float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 
1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); softmax.rescale_o(tOrO, scores_scale); if constexpr (Is_FP8) { flash::permute_output_fp8(tOrO); } return true; } template CUTLASS_DEVICE bool store_kv_new(Params const& params, int const thread_idx, SharedStorage &shared_storage, SeqlenInfo_t const& seqlen_info, cute::tuple block_coord ) { auto [m_block, bidh, bidb, split_idx] = block_coord; auto n_block_new_min_max = BlockMN_t::get_n_block_k_new_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, params.window_size_left, params.window_size_right, params.attention_chunk_divmod, params.qhead_per_khead_divmod); int const n_block_new_min = get<0>(n_block_new_min_max); int const n_block_new_max = get<1>(n_block_new_min_max); if (n_block_new_max <= n_block_new_min) { return false; } Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutV{}); int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb]; bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; Tensor mKnew = make_tensor(make_gmem_ptr(params.ptr_K_new), params.shape_K_new, params.stride_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); Tensor mVnew = make_tensor(make_gmem_ptr(params.ptr_V_new), params.shape_K_new, params.stride_V_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0); Tensor gKnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor gVnew = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mVnew), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); int const seqlen_k_new = seqlen_info.seqlen_k_new; using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_k_new, seqlen_info.seqlen_rotary); using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ ); static_assert(std::is_same_v); static_assert(!PagedKV || std::is_same_v); GmemTiledCopyQKV gmem_tiled_copy_kv_g2s; auto gmem_thr_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(thread_idx); auto 
gmem_thr0_copy_kv_g2s = gmem_tiled_copy_kv_g2s.get_thread_slice(_0{}); // Only for index calculation GmemTiledCopyAppendKV gmem_tiled_copy_kv_s2g; auto gmem_thr_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(thread_idx); auto gmem_thr0_copy_kv_s2g = gmem_tiled_copy_kv_s2g.get_thread_slice(_0{}); // Only for index calculation Tensor tKgKnew = gmem_thr_copy_kv_g2s.partition_S(gKnew); Tensor tKsKg2s = gmem_thr_copy_kv_g2s.partition_S(sK); Tensor tKsKs2g = gmem_thr_copy_kv_s2g.partition_S(sK); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tKgK = gmem_thr_copy_kv_s2g.partition_D(gK); Tensor tVgVnew = gmem_thr_copy_kv_g2s.partition_S(gVnew); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tVsVg2s = gmem_thr_copy_kv_g2s.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tVsVs2g = gmem_thr_copy_kv_s2g.partition_S(sV); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tVgV = gmem_thr_copy_kv_s2g.partition_D(gV); Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tKcKg2s = gmem_thr_copy_kv_g2s.partition_D(cK); Tensor t0KcKg2s = gmem_thr0_copy_kv_g2s.partition_D(cK); Tensor tKpKg2s = make_tensor(make_shape(size<2>(tKsKg2s))); Tensor tKcKs2g = gmem_thr_copy_kv_s2g.partition_D(cK); Tensor t0KcKs2g = gmem_thr0_copy_kv_s2g.partition_D(cK); Tensor tKpKs2g = make_tensor(make_shape(size<2>(tKsKs2g))); #pragma unroll for (int k = 0; k < size(tKpKg2s); ++k) { tKpKg2s(k) = get<1>(tKcKg2s(_0{}, _0{}, k)) < get<1>(params.shape_K); } #pragma unroll for (int k = 0; k < size(tKpKs2g); ++k) { tKpKs2g(k) = get<1>(tKcKs2g(_0{}, _0{}, k)) < get<1>(params.shape_K); } auto load_K_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; Tensor tKsK_cur = tKsKg2s(_, _, _, smem_pipe_write); int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, 
_0{}, _0{}))) + (EvenN ? seqlen_k_new - n_block * kBlockN : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN))); // We don't need to clear the sK smem tiles since we won't write them out flash::copy( gmem_tiled_copy_kv_g2s, tKgKnew(_, _, _, n_block), tKsK_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit); }; auto load_V_new = [&] (int const n_block, int const smem_pipe_write, auto need_seqlenk_masking_type) { static constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value; static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; Tensor tVsV_cur = tVsVg2s(_, _, _, smem_pipe_write); int const seqlenk_row_limit = -int(get<0>(tKcKg2s(_0{}, _0{}, _0{}))) + (EvenN ? seqlen_k_new - n_block * kBlockN : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k_new - n_block * kBlockN, kBlockN))); // We don't need to clear the sV smem tiles since we won't write them out flash::copy( gmem_tiled_copy_kv_g2s, tVgVnew(_, _, _, n_block), tVsV_cur, t0KcKg2s, tKpKg2s, seqlenk_row_limit); }; auto store_K = [&] (int const n_block, int const smem_pipe_read) { int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); if (get<1>(params.shape_rotary) <= 0) { Tensor tKsK_cur = tKsKs2g(_, _, _, smem_pipe_read); if constexpr (!PagedKV) { Tensor tKgK_cur = tKgK(_, _, _, n_block); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( gmem_tiled_copy_kv_s2g, tKsK_cur, tKgK_cur, tKcKs2g, tKpKs2g, std::min(seqlen_k_new - n_block * kBlockN, kBlockN) ); } else { paged_kv_manager.store_K(n_block, tKsK_cur); } } else { Tensor gK_cur = gK(_, _, n_block); auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr); if (params.is_rotary_interleaved) { auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block); rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCos, tRrSin, tPrKPtr, n_block); } else { auto [tRrCosCont, tRrSinCont] = 
rotary.template load_cos_sin(n_block); rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read), gK_cur, tKpKs2g, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K)); } } }; auto store_V = [&] (int const n_block, int const smem_pipe_read) { int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); Tensor tVsV_cur = tVsVs2g(_, _, _, smem_pipe_read); if constexpr (!PagedKV) { Tensor tVgV_cur = tVgV(_, _, _, n_block); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( gmem_tiled_copy_kv_s2g, tVsV_cur, tVgV_cur, tKcKs2g, tKpKs2g, n_limit); } else { paged_kv_manager.store_V(n_block, tVsV_cur); } }; int n_block = n_block_new_max - 1; // Note, using the for_each() function here to ensure `stage` is of type Int. for_each(make_int_sequence{}, [&] (auto stage) { static constexpr bool Is_first_stage = CUTE_STATIC_V(stage) == 0; static constexpr bool Is_last_stage = CUTE_STATIC_V(stage) == kStages - 1; if (Is_first_stage || n_block - stage >= n_block_new_min) { load_K_new(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); } cute::cp_async_fence(); // If persistent, need to sync to make sure all threads have finished with smem_o before writing to smem_v if constexpr (Is_first_stage) { __syncthreads(); } if constexpr (!Is_last_stage) { if (Is_first_stage || n_block - stage >= n_block_new_min) { load_V_new(n_block - stage, stage, cute::bool_constant{} /*Seqlenk_mask*/); } cute::cp_async_fence(); } }); int smem_pipe_read = 0, smem_pipe_write = kStages - 1; #pragma unroll 1 for (; n_block >= n_block_new_min; --n_block) { if constexpr (PagedKV) { paged_kv_manager.template load_page_table(n_block); } flash::cp_async_wait(); __syncthreads(); store_K(n_block, kStages > 1 ? smem_pipe_read : 0); if (n_block - kStages + 1 >= n_block_new_min) { load_V_new(n_block - kStages + 1, kStages > 1 ? 
smem_pipe_write : 0, cute::bool_constant{} /*Seqlenk_mask*/); } cute::cp_async_fence(); smem_pipe_write = smem_pipe_write < kStages - 1 ? smem_pipe_write + 1 : 0; flash::cp_async_wait(); __syncthreads(); store_V(n_block, kStages > 1 ? smem_pipe_read : 0); smem_pipe_read = smem_pipe_read < kStages - 1 ? smem_pipe_read + 1 : 0; if (n_block - kStages >= n_block_new_min) { load_K_new(n_block - kStages, kStages > 1 ? smem_pipe_write : 0, cute::false_type{} /*Seqlenk_mask*/); } cute::cp_async_fence(); } return true; } }; } // namespace flash ================================================ FILE: hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include #include #include #include #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" #include "cutlass/gemm/collective/builders/sm90_common.inl" #include "named_barrier.hpp" #include "seqlen.h" #include "block.h" #include "mask.h" #include "pack_gqa.h" #include "paged_kv.h" #include "rotary.h" #include "utils.h" #include "sm90_pipeline_no_cluster.hpp" namespace flash { using namespace cute; template struct CollectiveMainloopFwdSm90 { static constexpr int kStages = Stages; using ClusterShape = ClusterShape_; using TileShape_MNK = TileShape_MNK_; using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; using TileShape_MNK_QV = Shape(TileShape_MNK{})), decltype(get<1>(TileShape_MNK{})), Int>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v;; static constexpr bool Is_causal = Is_causal_; static constexpr bool Is_local = Is_local_; static constexpr bool Has_softcap = 
Has_softcap_; static constexpr bool Varlen = Varlen_; static constexpr bool PagedKVNonTMA = PagedKVNonTMA_; static constexpr bool AppendKV = AppendKV_; static constexpr bool HasQv = HasQv_; static constexpr bool PackGQA = PackGQA_; static constexpr bool Split = Split_; static constexpr bool V_colmajor = V_colmajor_; static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; static constexpr bool Use_TMA_Q = !PackGQA; static constexpr bool Use_TMA_KV = !PagedKVNonTMA; static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; static constexpr bool LargeHeadDimV = kHeadDimV > 256; static_assert(ArchTag::kMinComputeCapability >= 90); static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 && !V_colmajor ? GMMA::Major::MN : GMMA::Major::K; static constexpr cute::GMMA::Major TmaMajorV = !V_colmajor ? GMMA::Major::MN : GMMA::Major::K; static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); using SeqlenInfo_t = flash::SeqlenInfoQKNewK; using BlockMN_t = flash::BlockMN; static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. static constexpr bool MmaQK_is_RS = false; // We can have MmaPV with P in smem instead of in rmem to reduce register pressure at the cost of more smem.
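/* To put rough numbers on the trade-off in the comment above: when P is read from shared memory (SS), it occupies a single kBlockM x kBlockN buffer (smem_p in the TensorStorage structs defined further down); when P stays in registers (RS), the same tile is spread over the MMA threads as the A operand. A back-of-the-envelope sketch, assuming an even split of the tile across threads and 16-bit elements packed two per 32-bit register. smem_bytes_for_P and regs_per_thread_for_P are hypothetical helpers, and the 128 x 128 tile with 256 MMA threads is an illustrative configuration, not one taken from the kernel's tile-size choices.

```cpp
// Rough storage model for P in the P @ V GEMM; names are hypothetical.

// SS variant: one kBlockM x kBlockN tile of P staged in shared memory.
constexpr int smem_bytes_for_P(int kBlockM, int kBlockN, int elem_bytes) {
    return kBlockM * kBlockN * elem_bytes;
}

// RS variant: the same tile held in registers as the A operand, assuming it is
// split evenly over num_mma_threads threads and packed into 32-bit registers.
constexpr int regs_per_thread_for_P(int kBlockM, int kBlockN, int elem_bytes,
                                    int num_mma_threads) {
    return kBlockM * kBlockN * elem_bytes / (4 * num_mma_threads);
}

// Illustrative 128 x 128 tile of 16-bit P with two MMA warpgroups (256 threads):
static_assert(smem_bytes_for_P(128, 128, 2) == 32 * 1024);    // 32 KB of smem
static_assert(regs_per_thread_for_P(128, 128, 2, 256) == 32); // 32 registers per thread
```

This ignores the fp32 S accumulator, which the QK GEMM produces in registers in both variants, so it only captures the extra cost of keeping P itself in registers. */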
static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); // Slightly faster in this case to have WG1 use RS instead of SS to avoid waiting for the P smem write static constexpr bool MmaPV_use_RS_WG1 = !MmaPV_is_RS && kHeadDim == 64 && kHeadDimV == 512; using AtomLayoutQK = Layout, _1, _1>>; using TiledMmaQK = decltype(cute::make_tiled_mma( std::conditional_t< !MmaQK_is_RS, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector()) >{}, AtomLayoutQK{})); using AtomLayoutPV = std::conditional_t< !LargeHeadDimV, AtomLayoutQK, Layout, _1>> >; using TiledMmaPV = decltype(cute::make_tiled_mma( std::conditional_t< !MmaPV_is_RS, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector()) >{}, AtomLayoutPV{})); using TiledMmaQV = decltype(cute::make_tiled_mma( cute::GMMA::ss_op_selector(), AtomLayoutQK{})); // For hdim64,512, WG1 can use RS but WG2 must use SS using TiledMmaPV_RS = decltype(cute::make_tiled_mma( cute::GMMA::rs_op_selector(), AtomLayoutPV{})); static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); static constexpr int NumMmaThreads = size(TiledMmaPV{}); static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? 
cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0); static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutK = decltype(tile_to_shape( SmemLayoutAtomK{}, make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVt = decltype(tile_to_shape( SmemLayoutAtomVt{}, make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector, decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVtMma = decltype(tile_to_shape( SmemLayoutAtomVtMma{}, make_shape(Int{}, shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); using SmemLayoutQv = decltype(tile_to_shape(SmemLayoutAtomQv{}, select<0, 2>(TileShape_MNK_QV{}))); using SmemLayoutAtomVMmaQV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); using SmemLayoutVMmaQV = decltype(tile_to_shape( 
SmemLayoutAtomVMmaQV{}, make_shape(shape<1>(TileShape_MNK_QV{}), Int{}, Int{}))); static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{})); // Only used if we're using cp.async to load V using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int>()); using SmemLayoutVCpAsync = decltype(tile_to_shape( SmemLayoutAtomVCpAsync{}, make_shape(shape<1>(TileShape_MNK{}), Int{}, Int{}))); using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); // Only for LargeHeadDimV where WG0 sends WG1 the scales using SmemLayoutScale = cute::Layout, Int>>; using SmemCopyAtomP = Copy_Atom; // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. // For FP16/BF16 we don't do any transposing. static_assert(!Transpose_V || (kHeadDimV % 32 == 0 && kBlockN % 32 == 0)); static constexpr bool kHeadDimV_multiple_64 = kHeadDimV % 64 == 0; // Either kHeadDimV is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose). 
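/* Together, the two Transpose_V static_asserts (one above, one just below) admit a shape only if kHeadDimV and kBlockN are both multiples of 32 and at least one of them is a multiple of 64; the 64 x 32 transpose block is used whenever kHeadDimV allows it. A hypothetical constexpr helper that mirrors this selection logic, using the "64 x 32" / "32 x 64" notation of the comment above and returning {0, 0} for rejected shapes:

```cpp
// Hypothetical mirror of the Transpose_V shape rules; not part of flash-attention.
struct TransposeBlock {
    int dim0;  // first extent in the "64 x 32" / "32 x 64" notation
    int dim1;  // second extent
};

constexpr TransposeBlock transpose_block_for(int kHeadDimV, int kBlockN) {
    // First static_assert: both extents must be multiples of 32.
    if (kHeadDimV % 32 != 0 || kBlockN % 32 != 0) { return {0, 0}; }
    // Preferred: kHeadDimV is a multiple of 64, use the 64 x 32 block.
    if (kHeadDimV % 64 == 0) { return {64, 32}; }
    // Otherwise kBlockN must be a multiple of 64, use the 32 x 64 block.
    if (kBlockN % 64 == 0) { return {32, 64}; }
    // The second static_assert would fail.
    return {0, 0};
}

static_assert(transpose_block_for(256, 112).dim0 == 0);  // 112 is not a multiple of 32
static_assert(transpose_block_for(128, 96).dim0 == 64);  // kHeadDimV % 64 == 0
static_assert(transpose_block_for(96, 128).dim1 == 64);  // falls back to 32 x 64
static_assert(transpose_block_for(96, 96).dim0 == 0);    // neither is a multiple of 64
```

The helper only encodes which shapes pass the asserts; the LDSM/STSM thread and value layouts that implement each block are the ones defined just below. */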
static_assert(!Transpose_V || (kHeadDimV_multiple_64 || kBlockN % 64 == 0)); using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; using LDSM_value_shape = Shape<_2, _2, _1, _4>; using LDSM_value_stride = Stride<_1, _2, _16, _4>; using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; using S2RTiledCopyVt = decltype(make_tiled_copy( Copy_Atom{}, Layout{}, Layout{})); using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; using STSM_value_shape = Shape<_1, _4, _2, _2>; using STSM_value_stride = Stride<_0, _1, _4, _8>; using STSM_divide_shape = Shape<_8, _16>; // These will not permute the columns of V (the kHeadDimV dimension) but incur bank conflicts, // so they are a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS). // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue. // using STSM_value_shape = Shape<_2, _4, _1, _2>; // using STSM_value_stride = Stride<_4, _1, _0, _8>; // using STSM_divide_shape = Shape<_16, _16>; using R2STiledCopyV = decltype(make_tiled_copy( Copy_Atom{}, Layout{}, Layout{})); using GmemTiledCopyQ = cute::SM90_TMA_LOAD; using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); // We use CpAsync for K and V if PagedKVNonTMA and AppendKV, since TMA doesn't work there static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction.
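/* The hdim=128 example in the comment above can be checked against kGmemElemsPerLoad (defined just above) and the kBytePerHalfRow, kBlockKGmem, kGmemThreadsPerRow and GmemLayoutAtom definitions that follow. A hypothetical constexpr mirror of that arithmetic, assuming kHeadDim == kHeadDimV (so kHeadDimGCD is simply the head dimension), 16-byte vectorized loads, 256 MMA threads (two warpgroups) and a 128-row tile; plan_gmem_loads and GmemLoadPlan are illustrative names, not part of flash-attention:

```cpp
// Illustrative mirror of the gmem copy-layout arithmetic; names are hypothetical.
struct GmemLoadPlan {
    int elems_per_load;   // kGmemElemsPerLoad = 16 bytes / sizeof(Element)
    int block_k_gmem;     // kBlockKGmem, in elements
    int threads_per_row;  // kGmemThreadsPerRow
    int loads_m;          // vectorized loads per thread along the rows of the tile
    int loads_k;          // vectorized loads per thread along the head dimension
};

constexpr GmemLoadPlan plan_gmem_loads(int head_dim, int elem_bytes,
                                       int num_threads, int tile_rows) {
    int const elems_per_load = 16 / elem_bytes;
    int const byte_per_half_row = head_dim / 2 * elem_bytes;
    // 128 bytes if that divides half a row, else 64, else 32; converted to elements.
    int const block_k_gmem =
        (byte_per_half_row % 128 == 0 ? 128 : (byte_per_half_row % 64 == 0 ? 64 : 32)) / elem_bytes;
    int const threads_per_row = block_k_gmem / elems_per_load;
    // GmemLayoutAtom covers (num_threads / threads_per_row) rows per pass.
    int const rows_per_pass = num_threads / threads_per_row;
    return {elems_per_load, block_k_gmem, threads_per_row,
            tile_rows / rows_per_pass, head_dim / block_k_gmem};
}

// hdim = 128 with 16-bit elements: 4 loads along M and 2 along K, as the comment says.
constexpr GmemLoadPlan kHdim128 = plan_gmem_loads(128, 2, 256, 128);
static_assert(kHdim128.block_k_gmem == 64 && kHdim128.threads_per_row == 8);
static_assert(kHdim128.loads_m == 4 && kHdim128.loads_k == 2);
```

For hdim=64 the same formulas give 32-element chunks and still two vectorized loads per thread along K, in line with the rotary requirement stated in the comment that follows. */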
// We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will // load twice from the same row. static constexpr int kBytePerHalfRow = kHeadDimGCD / 2 * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKVNonTMA where // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); using GmemLayoutAtom = Layout, Int>, Stride, _1>>; // If AppendKV, we'll be loading Q for rotary, and we assume divisibility to avoid predication static_assert(!AppendKV || kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumMmaThreads / kGmemThreadsPerRow"); using GmemTiledCopyAppendKV = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per store using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) using StrideQK = cute::Stride; using StrideV = std::conditional_t>; // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; using StrideQPacked = std::conditional_t, _1, int64_t, int64_t>>; using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) using StridePageTable = cute::Stride; using ShapeRotary = cute::Shape; // (seqlen_ro, rotary_dim // 2) using StrideRotary = cute::Stride; using StrideDescale = 
cute::Stride; using TMA_Q = decltype(make_tma_copy_A_sm90( GmemTiledCopyQ{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), SmemLayoutQ{}, TileShape_MNK{}, ClusterShape{})); using TMA_K = decltype(make_tma_copy_B_sm90( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), take<0, 2>(SmemLayoutK{}), TileShape_MNK{}, ClusterShape{})); // mcast along M mode for this N load, if any using TMA_V = decltype(make_tma_copy( GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), take<0, 2>(SmemLayoutVt{}), select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any using TMA_Qv_ = decltype(make_tma_copy_A_sm90( GmemTiledCopyQ{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), SmemLayoutQv{}, TileShape_MNK_QV{}, ClusterShape{})); using TMA_Qv = std::conditional_t; // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesQv = static_cast(size(SmemLayoutQv{}) * cutlass::sizeof_bits_v / 8); using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; using MainloopPipelineK = std::conditional_t>; using MainloopPipelineV = std::conditional_t>; using MainloopPipelineVt = std::conditional_t>; // We always use TMA for K_new and V_new using MainloopPipelineKVNew = PipelineTmaAsync; using PipelineState = cutlass::PipelineState; // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned // and have sQ being 
position_independent_swizzle_tensor. // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static constexpr size_t SmemAlignmentQv = Use_TMA_Q ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQv{}); static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; using SmemQv_t = std::conditional_t, cute::array_aligned, SmemAlignmentQv>>; // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". 
struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; SmemQv_t smem_qv; }; struct TensorStorageWithPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; SmemQv_t smem_qv; SmemP_t smem_p; }; struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; SmemQv_t smem_qv; SmemP_t smem_p; SmemScale_t smem_scale; }; using TensorStorageNoTranspose = std::conditional_t< MmaPV_is_RS, TensorStorageWithoutPNoTranspose, std::conditional_t >; static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{}); static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, "Require at least 128B alignment"); struct TensorStorageTransposeV : cute::aligned_struct { cute::array_aligned, SmemAlignmentV> smem_v; cute::array_aligned, SmemAlignmentVt> smem_vt; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; SmemQv_t smem_qv; SmemScale_t smem_scale; }; using TensorStorage = std::conditional_t; // These are tuned for speed. They don't affect correctness. static constexpr bool UseSchedulerBarrier = (IntraWGOverlap ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? 
kHeadDim <= 128 : kHeadDim >= 128) : NumMmaWarpGroups == 2) && !LargeHeadDimV; static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor) && IntraWGOverlap; // Host side kernel arguments struct Arguments { Element const* const ptr_Q; ShapeQKV const shape_Q; StrideQK const stride_Q; Element* const ptr_K; // not Element const* since we might append to KV cache in-place ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; Element const* const ptr_Qv; StrideQK const stride_Qv; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; Element const* const ptr_rotary_sin; StrideRotary const stride_rotary_sin; bool const is_rotary_interleaved; int const* const ptr_pagetable; ShapePageTable const shape_pagetable; StridePageTable const stride_pagetable; float const softmax_scale; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; int const window_size_left = -1, window_size_right = -1, attention_chunk = 0; float const softcap_val; int const num_splits; int const* const kv_batch_idx = nullptr; int const* const cu_seqlens_q = nullptr; int const* const cu_seqlens_k = nullptr; int const* const cu_seqlens_k_new = nullptr; int const* const seqused_q = nullptr; int const* const seqused_k = nullptr; int const* const leftpad_k = nullptr; int const* const seqlens_rotary = nullptr; }; // Device side kernel params struct Params { Element const* const ptr_Q; ShapeQKV const shape_Q; StrideQK const stride_Q; ShapeQPacked const shape_Q_packed; StrideQPacked const stride_Q_packed; Element* const ptr_K; ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; int32_t const headdim_v; StrideV const stride_V; 
        Element const* const ptr_K_new;
        ShapeQKV const shape_K_new;
        StrideQK const stride_K_new;
        Element const* const ptr_V_new;
        StrideV const stride_V_new;
        Element const* const ptr_Qv;
        StrideV const stride_Qv;
        ShapeQPacked const shape_Qv_packed;
        StrideQPacked const stride_Qv_packed;
        Element const* const ptr_rotary_cos;
        ShapeRotary const shape_rotary;
        StrideRotary const stride_rotary_cos;
        Element const* const ptr_rotary_sin;
        StrideRotary const stride_rotary_sin;
        bool const is_rotary_interleaved;
        int const* const ptr_pagetable;
        ShapePageTable const shape_pagetable;
        StridePageTable const stride_pagetable;
        cutlass::FastDivmod page_size_divmod;
        cutlass::FastDivmod blockN_per_page_size_divmod;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_Q tma_load_Q;
        TMA_K tma_load_K;
        TMA_V tma_load_V;
        TMA_K tma_load_K_new;
        TMA_V tma_load_V_new;
        TMA_Qv tma_load_Qv;
        float const softmax_scale_log2;
        float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale;
        StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale;
        float const softcap_val;
        int const window_size_left, window_size_right;
        cutlass::FastDivmod attention_chunk_divmod;
        int const num_splits;
        int const* const kv_batch_idx = nullptr;
        int const* const cu_seqlens_q = nullptr;
        int const* const cu_seqlens_k = nullptr;
        int const* const cu_seqlens_k_new = nullptr;
        int const* const seqused_q = nullptr;
        int const* const seqused_k = nullptr;
        int const* const leftpad_k = nullptr;
        int const *const seqlens_rotary = nullptr;
    };

    static Params
    to_underlying_arguments(Arguments const& args) {
        Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q);
        TMA_Q tma_load_Q = make_tma_copy_A_sm90(
            GmemTiledCopyQ{},
            mQ,
            SmemLayoutQ{},
            TileShape_MNK{},
            ClusterShape{}); // no mcast for Q
        Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K);
        TMA_K tma_load_K = make_tma_copy_B_sm90(
            GmemTiledCopyKV{},
            mK,
            take<0, 2>(SmemLayoutK{}),
            TileShape_MNK{},
            ClusterShape{}); // mcast along M mode for this N load, if any
        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V),
                                make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)),
                                select<1, 0, 2, 3>(args.stride_V));
        TMA_V tma_load_V = make_tma_copy(
            GmemTiledCopyKV{},
            mV,
            take<0, 2>(SmemLayoutVt{}),
            select<1, 2>(TileShape_MNK_PV{}),
            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
        Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new);
        TMA_K tma_load_K_new = make_tma_copy_B_sm90(
            GmemTiledCopyKV{},
            cute::conditional_return(mKnew, mK),
            take<0, 2>(SmemLayoutK{}),
            TileShape_MNK{},
            ClusterShape{}); // mcast along M mode for this N load, if any
        Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new),
                                   make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)),
                                   select<1, 0, 2, 3>(args.stride_V_new));
        TMA_V tma_load_V_new = make_tma_copy(
            GmemTiledCopyKV{},
            cute::conditional_return(mVnew, mV),
            take<0, 2>(SmemLayoutVt{}),
            select<1, 2>(TileShape_MNK_PV{}),
            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
        auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q));
        Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv);
        TMA_Qv tma_load_Qv = [&] {
            if constexpr (HasQv) {
                return make_tma_copy_A_sm90(
                    GmemTiledCopyQ{},
                    mQv,
                    SmemLayoutQv{},
                    TileShape_MNK_QV{},
                    ClusterShape{}); // no mcast for Qv
            } else {
                return nullptr;
            }
        }();
        // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size)
        int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K));
        auto const shape_Q_packed = cute::conditional_return(
            args.shape_Q,
            make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q))
        );
        auto const stride_Q_packed = cute::conditional_return(
            args.stride_Q,
            make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q))
        );
        auto const shape_Qv_packed = cute::conditional_return(
            shape_Qv,
            make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv))
        );
        auto const stride_Qv_packed = cute::conditional_return(
            args.stride_Qv,
            make_stride(make_stride(get<2>(args.stride_Qv), get<0>(args.stride_Qv)), get<1>(args.stride_Qv), get<2>(args.stride_Qv) * qhead_per_khead, get<3>(args.stride_Qv))
        );
        if (get<1>(args.shape_rotary) > 0) {
            assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr);
        }
        assert(args.num_splits >= 1);
        int page_size = !args.ptr_pagetable ? 1 : get<0>(args.shape_K);
        if (!PagedKVNonTMA && args.ptr_pagetable != nullptr) {
            assert(page_size % kBlockN == 0);
            assert(!args.leftpad_k);
        }
        // Avoid dividing by zero
        cutlass::FastDivmod attention_chunk_divmod(args.attention_chunk >= 1 ? args.attention_chunk : 1);
        attention_chunk_divmod.divisor = args.attention_chunk;
        // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.
        // Right after this, we multiply by log2(e) before applying exp2.
        // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val
        // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)
        // (assigning it to params.softmax_scale_log2).
        return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed,
                args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V,
                args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new,
                args.ptr_Qv, args.stride_Qv, shape_Qv_packed, stride_Qv_packed,
                args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos,
                args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved,
                args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable,
                cutlass::FastDivmod(page_size),  // page_size_divmod
                cutlass::FastDivmod(!args.ptr_pagetable ? 1 : cute::ceil_div(page_size, kBlockN)),  // blockN_per_page_size_divmod
                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
                tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv,
                !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E),
                args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale,
                args.stride_q_descale, args.stride_k_descale, args.stride_v_descale,
                !Has_softcap ? 0.f : args.softmax_scale / args.softcap_val,
                args.window_size_left, args.window_size_right, attention_chunk_divmod,
                !Split ? 1 : args.num_splits,
                args.kv_batch_idx,
                args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
                args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary};
    }

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& params) {
        if constexpr (Use_TMA_Q) {
            cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor());
            if constexpr (HasQv) {
                cute::prefetch_tma_descriptor(params.tma_load_Qv.get_tma_descriptor());
            }
        }
        if constexpr (Use_TMA_KV) {
            cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor());
            cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor());
        }
        if constexpr (AppendKV) {
            cute::prefetch_tma_descriptor(params.tma_load_K_new.get_tma_descriptor());
            cute::prefetch_tma_descriptor(params.tma_load_V_new.get_tma_descriptor());
        }
    }

    template
    CUTLASS_DEVICE void
    load(Params const& params,
         MainloopPipelineK pipeline_k,
         MainloopPipelineV pipeline_v,
         MainloopPipelineVt pipeline_vt,
         PipelineState& smem_pipe_write,
         SharedStorage &shared_storage,
         SchedulerPrefetch const& scheduler_prefetch,
         SeqlenInfo_t const& seqlen_info,
         cute::tuple block_coord,
         int &work_idx
         ) {

        // some of these are captured in lambda so can't use structured binding
        int const m_block = get<0>(block_coord);
        int const bidh = get<1>(block_coord);
        int const bidb = get<2>(block_coord);
        int const split_idx = get<3>(block_coord);
        auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max(
            seqlen_info, m_block, bidb, split_idx, params.num_splits,
            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,
            params.qhead_per_khead_divmod);
        // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access.
        if constexpr (Is_causal || Is_local || Varlen || Split) {
            if (n_block_max <= n_block_min) {
                scheduler_prefetch();
                return;
            }
        }

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
        Tensor sK_pi = as_position_independent_swizzle_tensor(sK);
        // as_position_independent_swizzle_tensor makes address calculation easier when we do LDSM & STSM to transpose.
        // But it requires smem_vt and smem_v to be aligned to e.g. 512 bytes.
        Tensor sVt = [&] {
            if constexpr (!Transpose_V) {
                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{});
            } else {
                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}));
            }
        }();
        // Only used if Transpose_V
        Tensor sV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}));
        // Only used if we're using cp.async to load V
        Tensor sVcpasync = [&] {
            if constexpr (!Transpose_V) {
                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{}));
            } else {
                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{}));
            }
        }();
        Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{});

        int const thread_idx = threadIdx.x % NumProducerThreads;
        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
        int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];

        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

        bool const is_varlen_q = Varlen && params.cu_seqlens_q;
        bool const is_varlen_k = Varlen && params.cu_seqlens_k;
        Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
        Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _);
        auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K));
        Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _);

        Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
        // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); }
        Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _));  // (N, K, _, _)
        Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _));  // (K, N, _, _)

        auto block_tma_Q = params.tma_load_Q.get_slice(_0{});
        Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ));  // (TMA)
        Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ));  // (TMA)
        if (Use_TMA_Q && thread_idx == 0) { prefetch(params.tma_load_Q, tQgQ); }
        // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually
        auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x);
        Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA));  // (TMA, k, batch)
        Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K.partition_D(sK));  // (TMA, PIPE)
        auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x);
        Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA));  // (TMA, k, batch)
        Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt));  // (TMA, PIPE)
        auto [tQvgQv, tQvsQv] = [&] {
            if constexpr (HasQv) {
                auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q));
                Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0);
                Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{}));  // (M, Kv)
                auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{});
                Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv));  // (TMA)
                Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv));  // (TMA)
                return cute::make_tuple(tQvgQv, tQvsQv);
            } else {
                return cute::make_tuple(nullptr, nullptr);
            }
        }();

        // This is used to index into the batch dimension of mK and mV
        int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0;

        using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}),
                                                NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>;
        PagedKVManager_t paged_kv_manager(
            params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,
            params.ptr_K, params.shape_K, params.stride_K,
            params.ptr_V, params.headdim_v, params.stride_V,
            params.page_size_divmod, params.blockN_per_page_size_divmod,
            bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, bidb_kv_idx
        );

        // Set up for transposing V, only used if Transpose_V
        S2RTiledCopyVt s2r_tiled_copy_vt;
        R2STiledCopyV r2s_tiled_copy_v;
        auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(thread_idx);
        auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(thread_idx);
        // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8, kStages)
        Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32, kStages)
        // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64), kStages)
        Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{}));  // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64), kStages)
        CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_));
        CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_));
        CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_));
        CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_));
        CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_));
        CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_));
        // Faster to have 2 LDSM.T, byte permute, STSM for better ILP
        static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1;
        Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_) - 1>(tTranssVt_), Shape>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages)
        Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_) - 1>(tTranssV_), Shape>{});  // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2), kStages)
        auto transpose_V = [&](int stage) {
            if constexpr (Transpose_V) {
                #pragma unroll
                for (int i = 0; i < size<1, 1>(tTranssVt); ++i) {
                    Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{}), _0{}));
                    static_assert(size<0>(tTransrV) == 16);
                    Tensor tTransrV_64 = recast(tTransrV);
                    cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i), stage), tTransrV);
                    #pragma unroll
                    for (int j = 0; j < size(tTransrV_64); ++j) {
                        uint32_t upper = tTransrV_64[j].x;
                        uint32_t lower = tTransrV_64[j].y;
                        tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420);
                        tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531);
                    }
                    cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i), stage));
                }
            }
        };

        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v) {
            auto block_layout = Layout{}; // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }

        auto load_K = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) {
            pipeline_k.producer_acquire(smem_pipe_write);
            if constexpr (!PagedKVNonTMA) {
                auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_K_TMA();
                copy(params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
                    tKgK_TMA(_, n_block_idx, bidb_kv_idx), tKsK_TMA(_, smem_pipe_write.index()));
            } else {
                constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
                paged_kv_manager.template load_K(n_block, sK_pi(_, _, smem_pipe_write.index()));
                pipeline_k.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
            }
        };

        auto load_V = [&] (int const n_block, auto const& smem_pipe_write, auto need_seqlenk_masking_type) {
            auto pipeline_v_load = cute::conditional_return(pipeline_v, pipeline_vt);
            pipeline_v_load.producer_acquire(smem_pipe_write);
            if constexpr (!PagedKVNonTMA) {
                auto [n_block_idx, bidb_kv_idx] = paged_kv_manager.get_indices_for_V_TMA();
                copy(params.tma_load_V.with(*pipeline_v_load.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_LAST),
                    tVgVt_TMA(_, n_block_idx, bidb_kv_idx), tVsVt_TMA(_, smem_pipe_write.index()));
            } else {
                constexpr bool Seqlenk_mask = decltype(need_seqlenk_masking_type)::value;
                paged_kv_manager.template load_V(n_block, sVcpasync(_, _, smem_pipe_write.index()));
                pipeline_v_load.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive);
            }
        };

        auto copy_Vt_to_V = [&] (auto const& smem_pipe_write) {
            // Instead of maintaining smem_pipe_read as a separate variable, we can just use smem_pipe_write,
            // and exploit the invariant that smem_pipe_write.phase() == smem_pipe_read.phase() ^ 1.
            // This saves 1 or 2 registers.
            PipelineState smem_pipe_read{smem_pipe_write.index(), smem_pipe_write.phase() ^ 1, smem_pipe_write.count()};
            pipeline_vt.consumer_wait(smem_pipe_read);
            pipeline_v.producer_acquire(smem_pipe_write);
            transpose_V(smem_pipe_write.index());
            // SMEM fence to make sure V is transposed before math
            cutlass::arch::fence_view_async_shared();
            pipeline_v.producer_commit(smem_pipe_write);
            // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized
            // before calling. Without this we get race conditions.
            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, cutlass::arch::ReservedNamedBarriers::TransposeBarrier /*id*/);
            pipeline_vt.consumer_release(smem_pipe_read);
        };

        int n_block = n_block_max - 1;

        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        // If this is true, we're guaranteed that only the first warp will execute this function
        static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
        bool should_load_KV = !Use_TMA_KV || ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync());

        if (should_load_KV) {
            if constexpr (PagedKVNonTMA) {
                paged_kv_manager.template load_page_table(n_block);
            } else {
                paged_kv_manager.template load_page_table_TMA(n_block);
            }
            if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); }
            // if (thread_idx == 0) { printf("Producer: main load, before load_K, index = %d\n", smem_pipe_write.index());}
            load_K(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/);
            // if (thread_idx == 0) { printf("Producer: main load, after load K, index = %d\n", smem_pipe_write.index());}
        }

        if constexpr (Use_TMA_Q) {
            // Wait for the MMA warpgroups to signal that smem_q is ready
            if (SingleProducerWarp || warp_idx_in_warpgroup == 0) {
                cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/);
            }

            if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) {
                shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
                copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST),
                    tQgQ, tQsQ);
                if constexpr (HasQv) {
                    shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv);
                    copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST),
                        tQvgQv, tQvsQv);
                }
            }
        } else {  // Load Q with cp.async
            cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/);
            Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);
            Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ);
            using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>;
            PackGQAt::load_Q(mQ, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block);
            auto &barrier_Q = shared_storage.pipelines.barrier_Q;
            cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q));
            barrier_Q.arrive();
            if constexpr (HasQv) {
                Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);
                Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv);
                using PackGQAt = flash::PackGQAManager(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>;
                PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block);
                auto &barrier_Qv = shared_storage.pipelines.barrier_Qv;
                cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Qv));
                barrier_Qv.arrive();
            }
        }

        // Wait for the MMA WGs to signal that smem_v is ready and V can be copied from gmem
        // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the
        // TMA store on O first, call TMA multicast load on V, before CTA 1 can finish TMA store on O.
        // if (thread_idx == 0) { printf("Producer: main load, before barrier_O, work_idx = %d\n", work_idx);}
        shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);
        // if (thread_idx == 0) { printf("Producer: main load, after barrier_O\n");}

        if constexpr (!Transpose_V && !IntraWGOverlap) {
            if (should_load_KV) { load_V(n_block, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); }
        }
        int n_block_prev = n_block;
        --n_block;
        #pragma unroll (!Transpose_V && Use_TMA_KV ? 2 : 1)
        for (; n_block >= n_block_min; --n_block) {
            PipelineState smem_pipe_write_v = smem_pipe_write; // copy the state, write_v is always 1 step behind
            ++smem_pipe_write;
            if (should_load_KV) {
                if constexpr (PagedKVNonTMA) {
                    paged_kv_manager.template load_page_table(n_block);
                } else {
                    paged_kv_manager.load_page_table_TMA(n_block);
                }
                if constexpr (Transpose_V) { load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/); }
                load_K(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/);
                if constexpr (!Transpose_V) {
                    if constexpr (IntraWGOverlap) {
                        load_V(n_block_prev, smem_pipe_write_v, cute::true_type{} /*Seqlenk_mask*/);
                    } else {
                        load_V(n_block, smem_pipe_write, cute::false_type{} /*Seqlenk_mask*/);
                    }
                }
            }
            n_block_prev = n_block;
            if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write_v); }
        }
        scheduler_prefetch();
        if constexpr (!Transpose_V && IntraWGOverlap) {
            if (should_load_KV) { load_V(n_block_prev, smem_pipe_write, cute::true_type{} /*Seqlenk_mask*/); }
        }
        if constexpr (Transpose_V) { copy_Vt_to_V(smem_pipe_write); }
        ++smem_pipe_write;
        // At the end, all threads have the correct smem_pipe_write.
        ++work_idx;
    }

    template
    CUTLASS_DEVICE void
    load_tail(MainloopPipelineK pipeline_k, MainloopPipelineV pipeline_v, MainloopPipelineVt pipeline_vt,
              PipelineState& smem_pipe_write, SharedStorage &shared_storage, int const work_idx) {
        // If we don't wait for barrier_O here, when using Cluster, CTA0 might exit early and CTA1 will
        // try to arrive on barrier_O of CTA0, causing "unspecified launch failure".
        shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);
        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        // Issue the epilogue waits
        // TODO: check if this should be called by 1 thread or more
        if (warp_idx_in_warpgroup == 0 && cute::elect_one_sync()) {
            /* This helps avoid early exit of blocks in Cluster
            *  Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used
            *  then would just be acquired since the phase was still inverted from make_producer_start_state
            */
            pipeline_k.producer_tail(smem_pipe_write);
            pipeline_v.producer_tail(smem_pipe_write);
            if constexpr (Transpose_V) { pipeline_vt.producer_tail(smem_pipe_write); }
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_sync() {
        if constexpr (UseSchedulerBarrier) {
            cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);
        }
    }

    CUTLASS_DEVICE void
    warp_scheduler_barrier_arrive() {
        if constexpr (UseSchedulerBarrier) {
            static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
            int const cur_WG = flash::canonical_warp_group_idx_nosync() - 1;
            int const next_WG = NumMmaWarpGroups == 2
                ? 1 - cur_WG
                : (cur_WG < NumMmaWarpGroups - 1 ? cur_WG + 1 : 0);
            cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) + next_WG /*id*/);
        }
    }

    CUTLASS_DEVICE void
    mma_init() {
        int warp_group_idx = flash::canonical_warp_group_idx_nosync();
        // Tell producers that smem_q is ready
        if (!LargeHeadDimV || warp_group_idx == 1) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/);
        }
        if (LargeHeadDimV && warp_group_idx > 1) {
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/);
        }
        if constexpr (UseSchedulerBarrier) {
            // We have NamedBarrier for up to 3 WGs
            static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
            // WG1 needs the very first signal to start
            if (warp_group_idx == 1) {
                cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/);
            }
        }
    }

    template
    CUTLASS_DEVICE bool
    mma(Params const& params,
        MainloopPipelineK pipeline_k,
        MainloopPipelineV pipeline_v,
        PipelineState& smem_pipe_read,
        FrgTensorO& tOrO,
        Softmax& softmax,
        int const thread_idx,
        int &work_idx,
        SeqlenInfo_t const& seqlen_info,
        cute::tuple block_coord,
        SharedStorage& shared_storage
        ) {
        static_assert(is_rmem::value, "O tensor must be rmem resident.");
        static constexpr int kBlockM = get<0>(TileShape_MNK{});
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda
        int const m_block = get<0>(block_coord);
        int const bidh = get<1>(block_coord);
        int const bidb = get<2>(block_coord);
        int const split_idx = get<3>(block_coord);
        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
        auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max(
            seqlen_info, m_block, bidb, split_idx, params.num_splits,
            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,
            params.qhead_per_khead_divmod);
        // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier
        if constexpr (Is_causal || Is_local || Varlen || Split) {
            if (n_block_max <= n_block_min) { return false; }
        }

        Tensor sQ = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutQ{});
        Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
        Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{});
        Tensor sP = [&] {
            if constexpr (MmaPV_is_RS) {
                // We might not have smem_p if MmaPV_is_RS, just use smem_q as a placeholder since we don't use it
                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{});
            } else {
                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{});
            }
        }();
        Tensor sScale = [&] {
            if constexpr (LargeHeadDimV) {
                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{});
            } else { // won't be used, just a placeholder
                return make_tensor(make_smem_ptr(static_cast(nullptr)), SmemLayoutScale{});
            }
        }();
        Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{});
        Tensor sVMmaQV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVMmaQV{});

        if constexpr (!MmaQK_is_RS) {
            static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and
                          stride<0>(typename TiledMmaQK::BLayout{}) == 0 and
                          size<0>(typename TiledMmaQK::ALayout{}) == cutlass::NumThreadsPerWarpGroup and
                          size<0>(typename TiledMmaQK::BLayout{}) == cutlass::NumThreadsPerWarpGroup,
                "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
        }
        static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup;
        Layout warp_group_thread_layout = make_layout(make_shape(Int{}),
                                                      make_stride(Int{}));

        int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
        TiledMmaQK tiled_mma_qk;
        TiledMmaPV tiled_mma_pv;
        TiledMmaQV tiled_mma_qv;
        auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx));
        auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx));
        auto wg_mma_qv = tiled_mma_qv.get_slice(warp_group_thread_layout(warp_group_idx));

        auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk);
        auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx);

        // Allocate "fragments/descriptors"
        Tensor tSrQ = wg_mma_qk.partition_fragment_A(sQ);
        Tensor tSrK = wg_mma_qk.partition_fragment_B(sK);
        Tensor tOrV = wg_mma_pv.partition_fragment_B(sV);
        Tensor tOsP = wg_mma_pv.partition_fragment_A(sP);
        Tensor tSrQv = wg_mma_qv.partition_fragment_A(sQv);
        Tensor tSrV = wg_mma_qv.partition_fragment_B(sVMmaQV);
        Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP));

        // For storing scales to smem, only used when LargeHeadDimV
        auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx);
        Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
        Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
        auto store_scales = [&](auto& scales, int stage) {
            static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row)));
            #pragma unroll
            for (int mi = 0; mi < size(taccOcO_row); ++mi) {
                if (get<1>(taccOcO_row(_0{})) == 0) {
                    sScale(get<0>(taccOcO_row(mi)), stage) = scales(mi);
                }
            }
        };

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
        };

        int const seqlen_q = seqlen_info.seqlen_q;
        int const seqlen_k = seqlen_info.seqlen_k;
        int n_block = n_block_max - 1;

        flash::Mask mask(
            thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/,
            params.attention_chunk_divmod, params.qhead_per_khead_divmod
        );

        float softcap_val = params.softcap_val;
        if constexpr (Has_softcap && Is_FP8) {
            float const q_descale = params.ptr_q_descale == nullptr ? 1.0f : params.ptr_q_descale[bidb * get<0>(params.stride_q_descale) + bidh_kv * get<1>(params.stride_q_descale)];
            float const k_descale = params.ptr_k_descale == nullptr ? 1.0f : params.ptr_k_descale[bidb * get<0>(params.stride_k_descale) + bidh_kv * get<1>(params.stride_k_descale)];
            softcap_val *= q_descale * k_descale;
        }
        // Softcapping needs to happen before masking since if we apply after masking, softcapping
        // can turn -inf to e.g. -50.0, which can affect the attention softmax.
        auto scoremod_premask_fn = [&](auto& tSrS) {
            if constexpr (Has_softcap) { flash::apply_softcap(tSrS, softcap_val); }
        };

        auto write_P_to_smem = [&](auto& tOrP) {
            if constexpr (LargeHeadDimV) {
                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/);
            }
            cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP);
        };

        auto arrive_on_P_write_barrier = [&] {
            cutlass::arch::fence_view_async_shared();
            __syncwarp();  // Only need syncwarp since each warp is using its own P values for MmaPV
            if constexpr (LargeHeadDimV) {
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/);
            }
        };

        auto &barrier_Q = shared_storage.pipelines.barrier_Q;
        if constexpr (!AppendKV) {
            barrier_Q.wait(work_idx % 2);
        } else {
            if (get<1>(params.shape_rotary) > 0) {  // Apply rotary to Q
                using Rotary_t = Rotary;
                Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,
                                params.ptr_rotary_sin, params.stride_rotary_sin,
                                params.is_rotary_interleaved, thread_idx, seqlen_q,
                                seqlen_info.seqlen_rotary);
                Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ);
                int const qhead_per_khead = !PackGQA ?
                    1 : params.qhead_per_khead_divmod.divisor;
                if (params.is_rotary_interleaved) {
                    auto [tRrCos, tRrSin] = cute::conditional_return(
                        rotary.template load_cos_sin(m_block),
                        rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod)
                    );
                    barrier_Q.wait(work_idx % 2);
                    rotary.apply_Q_interleaved(sQ_pi, tRrCos, tRrSin, m_block, qhead_per_khead);
                } else {
                    auto [tRrCosCont, tRrSinCont] = cute::conditional_return(
                        rotary.template load_cos_sin(m_block),
                        rotary.template load_cos_sin_packgqa(m_block, params.qhead_per_khead_divmod)
                    );
                    barrier_Q.wait(work_idx % 2);
                    rotary.apply_Q_contiguous(sQ_pi, tRrCosCont, tRrSinCont, m_block, qhead_per_khead);
                }
                // SMEM fence to make sure the rotated Q is visible to GMMA
                cutlass::arch::fence_view_async_shared();
                cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK, static_cast(FwdNamedBarriers::QueryRotated) /*id*/);
            } else {
                barrier_Q.wait(work_idx % 2);
            }
        }

        if constexpr (MmaQK_is_RS) {
            using SmemCopyAtomQ = Copy_Atom;
            auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma_qk);
            auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx);
            Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
            Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(cute::as_position_independent_swizzle_tensor(sQ));
            cute::copy(smem_tiled_copy_Q, tSsQ_copy_view, tSrQ_copy_view);
        }

        if constexpr (IntraWGOverlap) {
            Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read);
            flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
            warpgroup_wait<0>();
            pipeline_k.consumer_release(smem_pipe_read);
            if constexpr (HasQv) {
                shared_storage.pipelines.barrier_Qv.wait(work_idx % 2);
                consumer_wait(pipeline_v, smem_pipe_read);
                flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS);
            }
            scoremod_premask_fn(tSrS);
            mask.template apply(tSrS, m_block, n_block);

            Tensor scores_scale = softmax.template max_get_scale(tSrS);
            // Don't need to store
            // scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f
            softmax.template online_softmax(tSrS);
            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }
            Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout()));
            Tensor tOrP = make_tensor_like(tOrP_acc);
            convert_type_out(tOrP_acc, tOrP);
            if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }
            if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); }
            if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); }
            --n_block;

            // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter
            clear(tOrO);
            // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero;

            // Each step does gemm0 for iter n_block, gemm1 for iter n_block + 1, and softmax for iter n_block.
            auto fwd_step = [&](int const n_block, auto mask_fn, auto check_inf_type) {
                static constexpr bool Check_inf = decltype(check_inf_type)::value;
                PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count());
                ++smem_pipe_read;
                Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{}));
                if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); }
                warp_scheduler_barrier_sync();
                flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
                if constexpr(!HasQv) {
                    if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); }
                }
                flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
                warp_scheduler_barrier_arrive();
                warpgroup_wait<1>();
                pipeline_k.consumer_release(smem_pipe_read);  // release K
                if constexpr (HasQv) {
                    warpgroup_wait<0>();
                    pipeline_v.consumer_release(smem_pipe_read_v);  // release V
                    consumer_wait(pipeline_v, smem_pipe_read);
                    flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _,
                                smem_pipe_read.index()), tSrS);
                }
                scoremod_premask_fn(tSrS);
                mask_fn(tSrS, n_block);
                cute::copy(softmax.template max_get_scale(tSrS), scores_scale);
                if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); }
                softmax.template online_softmax(tSrS);
                if constexpr (!HasQv) {
                    warpgroup_wait<0>();
                    pipeline_v.consumer_release(smem_pipe_read_v);  // release V
                }
                if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }
                convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP);
                if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }
                if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); }
                if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
                if constexpr (!MmaPV_is_RS) { arrive_on_P_write_barrier(); }
            };

            if constexpr (Is_causal || Is_local) {  // Separate iterations with causal or local masking
                auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); };
                int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask(
                    seqlen_info, m_block, n_block_min, params.window_size_right,
                    params.attention_chunk_divmod, params.qhead_per_khead_divmod);
                #pragma unroll 1
                for (; n_block >= n_block_min_causal_local_mask; --n_block) {
                    fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/);
                }
            }

            int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask(
                seqlen_info, m_block, n_block_min, params.window_size_left,
                params.attention_chunk_divmod, params.qhead_per_khead_divmod);
            auto no_mask_fn = [](auto& tSrS, int n_block) { };
            #pragma unroll 1
            for (; n_block >= n_block_min_before_local_mask; --n_block) {
                fwd_step(n_block, no_mask_fn, cute::false_type{} /*check_inf*/);
            }
            // Separate masking iterations on the left for local attention
            if constexpr (Is_local) {
                auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); };
                #pragma unroll 1
                for (; n_block >= n_block_min; --n_block) {
                    fwd_step(n_block, local_mask_fn, cute::bool_constant{} /*check_inf*/);
                }
            }
            // Tell producers that smem_q is ready
            cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/);
            if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); }
            if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); }
            flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO);
            float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];
            cute::copy(softmax.finalize(v_descale), scores_scale);
            if constexpr (LargeHeadDimV) {
                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/);
                store_scales(scores_scale, smem_pipe_read.index());
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/);
            }
            warpgroup_wait<0>();
            pipeline_v.consumer_release(smem_pipe_read);  // release V, otherwise producers will hang
            softmax.rescale_o(tOrO, scores_scale);
            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }
            ++smem_pipe_read;

        } else {  // No intra-WG overlap

            warp_scheduler_barrier_sync();

            auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) {
                static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value;
                static constexpr bool Check_inf = decltype(check_inf_type)::value;
                auto smem_pipe_read_prev = smem_pipe_read;
                if constexpr (!Is_first_iter) { ++smem_pipe_read; }
                Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{}));
                consumer_wait(pipeline_k, smem_pipe_read);
                flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                if constexpr (!HasQv) {
                    warp_scheduler_barrier_arrive();
                    warpgroup_wait<0>();
                    pipeline_k.consumer_release(smem_pipe_read);  // release K
                } else {
                    if constexpr (Is_first_iter) {
                        shared_storage.pipelines.barrier_Qv.wait(work_idx % 2);
                    }
                    consumer_wait(pipeline_v, smem_pipe_read);
                    flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS);
                    warp_scheduler_barrier_arrive();
                    warpgroup_wait<1>();
                    pipeline_k.consumer_release(smem_pipe_read);  // release K
                    warpgroup_wait<0>();
                }
                scoremod_premask_fn(tSrS);
                mask_fn(tSrS, n_block);
                Tensor scores_scale = softmax.template max_get_scale(tSrS);
                if constexpr (LargeHeadDimV && !Is_first_iter) { store_scales(scores_scale, smem_pipe_read_prev.index()); }
                softmax.template online_softmax(tSrS);
                if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); }
                Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout()));
                Tensor tOrP = make_tensor_like(tOrP_acc);
                convert_type_out(tOrP_acc, tOrP);
                if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); }
                if constexpr (!MmaPV_is_RS) { write_P_to_smem(tOrP); }
                if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); }
                if constexpr (!MmaPV_is_RS && !MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); }
                if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); }
                warp_scheduler_barrier_sync();
                if constexpr (!MmaPV_use_RS_WG1) {
                    flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                } else {
                    TiledMmaPV_RS tiled_mma_pv_rs;
                    flash::gemm(tiled_mma_pv_rs, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
                }
                if constexpr (!MmaPV_is_RS && MmaPV_use_RS_WG1) { arrive_on_P_write_barrier(); }
                warpgroup_wait<0>();
                pipeline_v.consumer_release(smem_pipe_read);  // release V
            };

            auto first_iter_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); };
            fwd_step(n_block, first_iter_mask_fn, cute::true_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);

            --n_block;
            if constexpr (Is_causal || Is_local) {
                // Separate iterations with causal or local masking
                auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); };
                int const n_block_min_causal_local_mask = BlockMN_t::get_n_block_min_causal_local_mask(
                    seqlen_info, m_block, n_block_min, params.window_size_right,
                    params.attention_chunk_divmod, params.qhead_per_khead_divmod);
                #pragma unroll 1
                for (; n_block >= n_block_min_causal_local_mask; --n_block) {
                    fwd_step(n_block, mask_fn, cute::false_type{} /*is_first_iter*/, cute::true_type{} /*check_inf*/);
                }
            }
            int const n_block_min_before_local_mask = BlockMN_t::get_n_block_min_before_local_mask(
                seqlen_info, m_block, n_block_min, params.window_size_left,
                params.attention_chunk_divmod, params.qhead_per_khead_divmod);
            auto no_mask_fn = [](auto& tSrS, int n_block) { };
            #pragma unroll 1
            for (; n_block >= n_block_min_before_local_mask; --n_block) {
                fwd_step(n_block, no_mask_fn, cute::false_type{} /*is_first_iter*/, cute::false_type{} /*check_inf*/);
            }
            // Separate masking iterations on the left for local attention
            if constexpr (Is_local) {
                auto local_mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); };
                #pragma unroll 1
                for (; n_block >= n_block_min; --n_block) {
                    fwd_step(n_block, local_mask_fn, cute::false_type{} /*is_first_iter*/, cute::bool_constant{} /*check_inf*/);
                }
            }
            warp_scheduler_barrier_arrive();
            // Tell producers that smem_q is ready
            cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/);
            float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ?
                1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)];
            Tensor scores_scale = softmax.finalize(v_descale);
            if constexpr (LargeHeadDimV) {
                cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/);
                store_scales(scores_scale, smem_pipe_read.index());
                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/);
            }
            softmax.rescale_o(tOrO, scores_scale);
            if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }
            ++smem_pipe_read;
        }
        ++work_idx;
        return true;
    }

    template
    CUTLASS_DEVICE bool
    mma_pv(Params const& params,
           MainloopPipelineV pipeline_v,
           PipelineState& smem_pipe_read,
           FrgTensorO& tOrO,
           Softmax& softmax,
           int const thread_idx,
           SeqlenInfo_t const& seqlen_info,
           cute::tuple block_coord,
           SharedStorage& shared_storage
    ) {
        static_assert(is_rmem::value, "O tensor must be rmem resident.");
        // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda
        int const m_block = get<0>(block_coord);
        int const bidb = get<2>(block_coord);
        int const split_idx = get<3>(block_coord);
        auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max(
            seqlen_info, m_block, bidb, split_idx, params.num_splits,
            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,
            params.qhead_per_khead_divmod);
        // It's possible to have n_block_max <= n_block_min.
        // We don't want to load Q or change any barrier
        if constexpr (Is_causal || Is_local || Varlen || Split) {
            if (n_block_max <= n_block_min) { return false; }
        }

        Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{});
        Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{});
        Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{});
        static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup;
        Layout warp_group_thread_layout = make_layout(make_shape(Int{}),
                                                      make_stride(Int{}));

        int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
        TiledMmaPV tiled_mma_pv;
        auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx));

        // Allocate "fragments/descriptors"
        Tensor tOrV = wg_mma_pv.partition_fragment_B(sV);
        Tensor tOsP = wg_mma_pv.partition_fragment_A(sP);

        // For load scales to smem, pretend thread_idx is thread_idx % 128
        auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup);
        Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
        Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
        Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
        auto load_scales = [&](auto& scales, int stage) {
            static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row)));
            #pragma unroll
            for (int mi = 0; mi < size(taccOcO_row); ++mi) {
                scales(mi) = sScale(get<0>(taccOcO_row(mi)), stage);
            }
        };

        // clear(tOrO);
        // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero;

        typename Softmax::TensorT scores_scale;

        int n_block = n_block_max - 1;
        // If HasQv, then by the time P is ready, V must have been ready as well
        if constexpr (!HasQv) { pipeline_v.consumer_wait(smem_pipe_read); }
        cutlass::arch::NamedBarrier::sync(NumMmaThreads,
            static_cast(FwdNamedBarriers::PFull) /*id*/);
        flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/);
        pipeline_v.consumer_release(smem_pipe_read);  // release V
        --n_block;

        #pragma unroll 1
        for (; n_block >= n_block_min; --n_block) {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/);
            load_scales(scores_scale, smem_pipe_read.index());
            softmax.rescale_o(tOrO, scores_scale);
            ++smem_pipe_read;
            if constexpr (!HasQv) {
                auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read);
                pipeline_v.consumer_wait(smem_pipe_read, barrier_token);
            }
            flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/);
            pipeline_v.consumer_release(smem_pipe_read);  // release V
        };
        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/);
        load_scales(scores_scale, smem_pipe_read.index());
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/);
        softmax.rescale_o(tOrO, scores_scale);
        if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); }
        ++smem_pipe_read;
        return true;
    }

    template
    CUTLASS_DEVICE bool
    load_kv_new(Params const& params,
                MainloopPipelineKVNew pipeline_k_new,
                MainloopPipelineKVNew pipeline_v_new,
                PipelineState& smem_pipe_write,
                SharedStorage &shared_storage,
                SeqlenInfo_t const& seqlen_info,
                cute::tuple block_coord,
                int const work_idx
    ) {

        auto [m_block, bidh, bidb, split_idx] = block_coord;
        auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max(
            seqlen_info, m_block, bidb, split_idx, params.num_splits,
            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,
            params.qhead_per_khead_divmod);

        if (n_block_new_max <= n_block_new_min) { return false; }

        Tensor sK =
            make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{});
        Tensor sVt = [&] {
            if constexpr (!Transpose_V) {
                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{});
            } else {
                return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{});
            }
        }();

        // int const thread_idx = threadIdx.x % NumProducerThreads;
        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;

        // Prepare the TMA loads
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

        bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new;
        Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0);
        auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new));
        Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ?
            bidb : 0);

        Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
        Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _));  // (K, N, _)

        auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x);
        Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA));  // (TMA, k)
        Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K_new.partition_D(sK));  // (TMA, PIPE)
        auto block_tma_V_new = params.tma_load_V_new.get_slice(cluster_local_block_id.x);
        Tensor tVgVnewt_TMA = group_modes<0, 3>(block_tma_V_new.partition_S(gVnewt_TMA));  // (TMA, k)
        Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V_new.partition_D(sVt));  // (TMA, PIPE)

        uint16_t mcast_mask_kv = 0;
        if constexpr (cute::is_same_v) {
            auto block_layout = Layout{};  // (m,n) -> block_id
            for (int m = 0; m < size<0>(block_layout); ++m) {
                mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{}));
            }
        }

        auto load_K_new = [&] (int const n_block, auto const& smem_pipe_write) {
            pipeline_k_new.producer_acquire(smem_pipe_write);
            copy(params.tma_load_K_new.with(*pipeline_k_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST),
                tKgKnew_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index()));
        };

        auto load_V_new = [&] (int const n_block, auto const& smem_pipe_write) {
            pipeline_v_new.producer_acquire(smem_pipe_write);
            copy(params.tma_load_V_new.with(*pipeline_v_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST),
                tVgVnewt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index()));
        };

        int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
        // If this is true, we're guaranteed that only the first warp will execute this function
        static constexpr bool SingleProducerWarp =
            NumProducerThreads == cutlass::NumThreadsPerWarp;
        bool should_load_KV = (SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync();

        int n_block = n_block_new_max - 1;
        // Need to wait for barrier_O even before load_K_new since the pipelines for AppendKV
        // and the main attention are not the same. We want to make sure the consumers
        // have finished reading all smem_k and smem_v for the previous iteration.
        shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2);
        if (should_load_KV) { load_K_new(n_block, smem_pipe_write); }
        // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
        if (should_load_KV) { load_V_new(n_block, smem_pipe_write); }
        // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
        ++smem_pipe_write;
        --n_block;
        // if (thread_idx == 0) { printf("Producer: before for loop\n"); }
        #pragma unroll 1
        for (; n_block >= n_block_new_min; --n_block) {
            if (should_load_KV) {
                load_K_new(n_block, smem_pipe_write);
                // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
                load_V_new(n_block, smem_pipe_write);
                // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
            }
            ++smem_pipe_write;
        }
        // if (thread_idx == 0) { printf("Producer: after for loop\n"); }
        // At the end, all threads have the correct smem_pipe_write.
        return true;
    }

    template
    CUTLASS_DEVICE bool
    store_kv_new(Params const& params,
                 MainloopPipelineKVNew pipeline_k_new,
                 MainloopPipelineKVNew pipeline_v_new,
                 PipelineState& smem_pipe_read,
                 int const thread_idx,
                 SharedStorage &shared_storage,
                 SeqlenInfo_t const& seqlen_info,
                 cute::tuple block_coord
    ) {
        auto [m_block, bidh, bidb, split_idx] = block_coord;
        auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max(
            seqlen_info, m_block, bidb, split_idx, params.num_splits,
            params.window_size_left, params.window_size_right, params.attention_chunk_divmod,
            params.qhead_per_khead_divmod);
        if (n_block_new_max <= n_block_new_min) { return false; }

        // as_position_independent_swizzle_tensor makes address calculation easier
        Tensor sK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}));
        // We want to use SmemLayoutVCpAsync to have shape (kBlockN, kHeadDim) instead of (kHeadDim, kBlockN)
        Tensor sV = [&] {
            if constexpr (!Transpose_V) {
                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVCpAsync{}));
            } else {
                return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{}));
            }
        }();

        int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh;
        int const bidb_kv = params.kv_batch_idx == nullptr ? bidb : params.kv_batch_idx[bidb];

        bool const is_varlen_k = Varlen && params.cu_seqlens_k;
        Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0);
        auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K));
        Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ?
            bidb_kv : 0);

        int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og;
        Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
        Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<2, 1>(TileShape_MNK_PV{}), make_coord(_, _0{}));  // (N, K_v, _)

        static constexpr int kBlockN = get<1>(TileShape_MNK{});
        static constexpr int kHeadDim = get<2>(TileShape_MNK{});
        int const seqlen_k_new = seqlen_info.seqlen_k_new;
        using Rotary_t = Rotary;
        Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos,
                        params.ptr_rotary_sin, params.stride_rotary_sin,
                        params.is_rotary_interleaved, thread_idx, seqlen_k_new,
                        seqlen_info.seqlen_rotary);

        // This is used to index into the batch dimension of mK and mV
        int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0;

        using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>;
        PagedKVManager_t paged_kv_manager(
            params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable,
            params.ptr_K, params.shape_K, params.stride_K,
            params.ptr_V, params.headdim_v, params.stride_V,
            params.page_size_divmod, params.blockN_per_page_size_divmod,
            bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k, bidb_kv_idx
            // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position
        );

        if constexpr (UseSchedulerBarrier) {
            // WG1 already got the very first signal from mma_init(), but we'll be using the same NamedBarrier.
            // So we'll need to "cancel it out" here and then re-signal it at the end.
            if (flash::canonical_warp_group_idx_nosync() == 1) {
                cutlass::arch::NamedBarrier::sync(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/);
            }
        }

        static_assert(std::is_same_v);
        static_assert(!PagedKVNonTMA || std::is_same_v);
        GmemTiledCopyAppendKV gmem_tiled_copy_kv;
        auto gmem_thr_copy_kv = gmem_tiled_copy_kv.get_thread_slice(thread_idx);
        Tensor tKsK = gmem_thr_copy_kv.partition_S(sK);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
        Tensor tKgK = gmem_thr_copy_kv.partition_D(gK);
        Tensor tVsV = gmem_thr_copy_kv.partition_S(sV);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
        Tensor tVgV = gmem_thr_copy_kv.partition_D(gV);
        Tensor cK = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{}));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
        Tensor tKcK = gmem_thr_copy_kv.partition_D(cK);
        Tensor tKpK = make_tensor(make_shape(size<2>(tKsK)));
        #pragma unroll
        for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(params.shape_K); }
        Tensor cV = cute::make_identity_tensor(select<2, 1>(TileShape_MNK_PV{}));  // (BLK_N,BLK_K_V) -> (blk_n,blk_k_v)
        Tensor tVcV = cute::conditional_return(tKcK, gmem_thr_copy_kv.partition_D(cV));
        Tensor tVpV_ = make_tensor(make_shape(size<2>(tVsV)));
        #pragma unroll
        for (int k = 0; k < size(tVpV_); ++k) { tVpV_(k) = get<1>(tVcV(_0{}, _0{}, k)) < params.headdim_v; }
        Tensor tVpV = cute::conditional_return(tKpK, tVpV_);

        auto store_K = [&] (int const n_block, auto const& smem_pipe_read) {
            int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);
            if (get<1>(params.shape_rotary) <= 0) {
                pipeline_k_new.consumer_wait(smem_pipe_read);
                Tensor tKsK_cur = tKsK(_, _, _, smem_pipe_read.index());
                if constexpr (!PagedKVNonTMA) {
                    Tensor tKgK_cur = tKgK(_, _, _, n_block);
                    // Clear_OOB_K must be false since we don't want to write zeros to gmem
                    flash::copy(
                        gmem_tiled_copy_kv, tKsK_cur, tKgK_cur, tKcK, tKpK, std::min(seqlen_k_new - n_block * kBlockN, kBlockN)
                    );
                } else {
                    paged_kv_manager.store_K(n_block, tKsK_cur);
                }
            } else {
                Tensor gK_cur = gK(_, _, n_block);
                auto tPrKPtr = cute::conditional_return(paged_kv_manager.compute_K_ptr(), nullptr);
                if (params.is_rotary_interleaved) {
                    auto [tRrCos, tRrSin] = rotary.template load_cos_sin(n_block);
                    pipeline_k_new.consumer_wait(smem_pipe_read);
                    rotary.template apply_K_interleaved(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCos, tRrSin, tPrKPtr, n_block);
                } else {
                    auto [tRrCosCont, tRrSinCont] = rotary.template load_cos_sin(n_block);
                    pipeline_k_new.consumer_wait(smem_pipe_read);
                    rotary.template apply_K_contiguous(sK(_, _, smem_pipe_read.index()), gK_cur, tKpK, tRrCosCont, tRrSinCont, tPrKPtr, n_block, get<1>(params.shape_K));
                }
            }
            // Without this fence I'm getting race condition when seqlen_k is large
            cutlass::arch::fence_view_async_shared();
            // Very important: PipelineTmaAsync::consumer_release assumes that the warpgroup is synchronized
            // before calling.
            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);
            pipeline_k_new.consumer_release(smem_pipe_read);
            // if (thread_idx == 0) { print_tensor(tKpK); printf("\n"); printf("seqlen_limit = %d\n", seqlen_k_new - n_block * kBlockN);}
        };

        auto store_V = [&] (int const n_block, auto const& smem_pipe_read) {
            pipeline_v_new.consumer_wait(smem_pipe_read);
            int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN);
            Tensor tVsV_cur = tVsV(_, _, _, smem_pipe_read.index());
            if constexpr (!PagedKVNonTMA) {
                Tensor tVgV_cur = tVgV(_, _, _, n_block);
                // Clear_OOB_K must be false since we don't want to write zeros to gmem
                flash::copy(
                    gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tVcV, tVpV, n_limit);
            } else {
                paged_kv_manager.store_V(n_block, tVsV_cur);
            }
            cutlass::arch::fence_view_async_shared();
            cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + flash::canonical_warp_group_idx_nosync() /*id*/);
            pipeline_v_new.consumer_release(smem_pipe_read);
        };

        #pragma unroll 1
        for (int n_block = n_block_new_max - 1; n_block >= n_block_new_min; --n_block) {
            if constexpr (PagedKVNonTMA) { paged_kv_manager.template load_page_table(n_block); }
            store_K(n_block, smem_pipe_read);
            // if (thread_idx == 0) { printf("Done storing K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
            store_V(n_block, smem_pipe_read);
            // if (thread_idx == 0) { printf("Done storing V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); }
            ++smem_pipe_read;
        }
        // if (thread_idx == 0) { printf("After for loop\n"); }

        // Re-signaling the NamedBarrier that we "canceled out"
        if constexpr (UseSchedulerBarrier) {
            if (flash::canonical_warp_group_idx_nosync() == 1) {
                cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/);
            }
        }

        return true;

    }

};

} // namespace flash

================================================
FILE: hopper/mask.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include
#include "cutlass/fast_math.h"  // For cutlass::FastDivmod

#include "utils.h"

namespace flash {

using namespace cute;

template
struct Mask {

    static_assert(!(PackGQA && SwapAB), "Cannot be both PackGQA and SwapAB");

    int const thread_idx;
    int const seqlen_q, seqlen_k;
    int const window_size_left, window_size_right, sink_token_length;
    cutlass::FastDivmod const attention_chunk_divmod;
    cutlass::FastDivmod const qhead_per_khead_divmod;

    CUTLASS_DEVICE
    Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,
         const int window_size_left, const int window_size_right, const int sink_token_length,
         cutlass::FastDivmod const &attention_chunk_divmod,
         cutlass::FastDivmod const &qhead_per_khead_divmod)
        : thread_idx(thread_idx)
        , seqlen_q(seqlen_q)
        , seqlen_k(seqlen_k)
        , window_size_left(window_size_left)
        , window_size_right(window_size_right)
        , sink_token_length(sink_token_length)
        , attention_chunk_divmod(attention_chunk_divmod)
        , qhead_per_khead_divmod(qhead_per_khead_divmod)
    {
    };

    template
    CUTLASS_DEVICE
    void apply(Tensor &tSrS, const int m_block, const int n_block) const {
        static_assert(!(Causal_mask && Local_mask), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        if (!Seqlenk_mask && !Causal_mask && !Local_mask) { return; }

        auto thread_mma = TiledMma{}.get_thread_slice(thread_idx);
        auto thread0_mma = TiledMma{}.get_thread_slice(_0{});

        static constexpr int Row = !SwapAB ? 0 : 1, Col = !SwapAB ?
1 : 0; Tensor cS = cute::make_identity_tensor(Shape, Int>{}); Tensor tScS = thread_mma.partition_C(cS); Tensor tSrS_rowcol = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); Tensor tScS_rowcol = make_tensor(tScS.data(), flash::convert_layout_acc_rowcol(tScS.layout())); Tensor t0ScS = thread0_mma.partition_C(cS); Tensor t0ScS_rowcol = make_tensor(t0ScS.data(), flash::convert_layout_acc_rowcol(t0ScS.layout())); // We want to use the col indices of thread0 to compare, since that is known at compile time. // So we subtract the first col index of this thread (get(tScS_rowcol(_0{}, _0{}))) from the limit int const thread_col_offset = get(tScS_rowcol(_0{}, _0{})); int const seqlenk_col_limit = seqlen_k - n_block * kBlockN - thread_col_offset; if constexpr (!Causal_mask && !Local_mask) { if constexpr (Seqlenk_mask) { // Just masking based on col #pragma unroll for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { if (int(get(t0ScS_rowcol(_0{}, n))) >= seqlenk_col_limit) { #pragma unroll for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { tSrS_rowcol(m, n) = -INFINITY; } } } } } else { // mask based on both row and col if constexpr (!SwapAB) { // If PackGQA, we split the work of computing divmod among threads in the same row static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); static_assert(!PackGQA || CUTE_STATIC_V(size<0>(tSrS_rowcol)) <= kMmaThreadsPerRow); int mma_m_idx; // Might get OOB but it's ok since we'll check it later if constexpr (PackGQA) { mma_m_idx = qhead_per_khead_divmod.divide(m_block * kBlockM + get(tScS_rowcol(thread_idx % kMmaThreadsPerRow, _0{}))); } int const causal_row_offset = 1 + seqlen_k - n_block * kBlockN - seqlen_q - thread_col_offset; if constexpr (Causal_mask) { #pragma unroll for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { int const row_idx = !PackGQA ? 
get(tScS_rowcol(m, _0{})) + m_block * kBlockM : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); int const col_limit_right = !Seqlenk_mask ? row_idx + causal_row_offset : __viaddmin_s32(row_idx, causal_row_offset, seqlenk_col_limit); #pragma unroll for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { if (int(get(t0ScS_rowcol(_0{}, n))) >= col_limit_right) { tSrS_rowcol(m, n) = -INFINITY; } } } } else { int const local_row_offset_right = causal_row_offset + window_size_right; int const local_row_offset_left = causal_row_offset - 1 - window_size_left; int const col_limit_sink = sink_token_length - n_block * kBlockN; // TODO: subtract thread_col_offset? #pragma unroll for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { int const row_idx = !PackGQA ? get(tScS_rowcol(m, _0{})) + m_block * kBlockM : __shfl_sync(0xffffffff, mma_m_idx, m % kMmaThreadsPerRow, kMmaThreadsPerRow); int col_limit_right = !Seqlenk_mask ? row_idx + local_row_offset_right : __viaddmin_s32(row_idx, local_row_offset_right, seqlenk_col_limit); int col_limit_left = row_idx + local_row_offset_left; if (attention_chunk_divmod.divisor > 0) { int col_limit_left_chunk = flash::round_down(attention_chunk_divmod, row_idx + seqlen_k - seqlen_q) - n_block * kBlockN - thread_col_offset; col_limit_left = std::max(col_limit_left, col_limit_left_chunk); col_limit_right = std::min(col_limit_right, col_limit_left_chunk + attention_chunk_divmod.divisor); } #pragma unroll for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { int const col_idx = int(get(t0ScS_rowcol(m, n))); if (col_idx >= col_limit_right || (col_idx < col_limit_left && col_idx >= col_limit_sink)) { tSrS_rowcol(m, n) = -INFINITY; } } } } } else { // TODO: backward does not support attention_chunk yet int const thread_row_offset = get(tScS_rowcol(_0{}, _0{})); int const causal_row_offset = seqlenk_col_limit - seqlen_q + m_block * kBlockM + thread_row_offset; if constexpr (Causal_mask) { #pragma unroll for (int n = 0; n < 
size<1>(tSrS_rowcol); ++n) { int const col0 = int(get(t0ScS_rowcol(_0{}, n))); // If col0 is beyond the column limit, we want to mask out the entire column, by setting // row limit to be kBlockM. int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset; #pragma unroll for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { if (int(get(t0ScS_rowcol(m, _0{}))) < row_limit_top) { tSrS_rowcol(m, n) = -INFINITY; } } } } else { int const col_limit_sink = sink_token_length - n_block * kBlockN - thread_col_offset; #pragma unroll for (int n = 0; n < size<1>(tSrS_rowcol); ++n) { int const col0 = int(get(t0ScS_rowcol(_0{}, n))); // If col0 is beyond the column limit, we want to mask out the entire column, by setting // row limit to be kBlockM. int const row_limit_top = col0 >= seqlenk_col_limit ? kBlockM : col0 - causal_row_offset - window_size_right; int const row_limit_bot = col0 < col_limit_sink ? kBlockM : col0 - causal_row_offset + window_size_left; #pragma unroll for (int m = 0; m < size<0>(tSrS_rowcol); ++m) { int const row_idx = int(get(t0ScS_rowcol(m, _0{}))); if (row_idx < row_limit_top || row_idx > row_limit_bot) { tSrS_rowcol(m, n) = -INFINITY; } } } } } } }; }; } // namespace flash ================================================ FILE: hopper/named_barrier.hpp ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include "cutlass/arch/barrier.h" namespace flash { //////////////////////////////////////////////////////////////////////////////////////////////////// // cutlass::arch::NamedBarrier::sync/arrive are only enabled for Sm90 even though they work // for Sm80 as well. We reimplement them here, enabled for both Sm90 and Sm80.
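// Worked example of the id offset applied below (illustrative; FirstUserBarrier is whatever
// value CUTLASS defines for cutlass::arch::ReservedNamedBarriers::FirstUserBarrier, no number
// is assumed here):
//   named_barrier_sync(2 * cutlass::NumThreadsPerWarpGroup,
//                      static_cast<uint32_t>(FwdNamedBarriers::WarpSchedulerWG1));
// issues `bar.sync (1 + FirstUserBarrier), 256;`, so the 256 threads of two warpgroups wait on
// hardware barrier 1 + FirstUserBarrier, clear of the ids CUTLASS reserves for itself.
// Per the PTX ISA, a CTA has 16 named barriers (ids 0-15) and the thread count passed to
// bar.sync / bar.arrive must be a multiple of the warp size (32), so the largest value in
// FwdNamedBarriers / BwdNamedBarriers plus FirstUserBarrier has to stay below 16.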
CUTLASS_DEVICE static void named_barrier_sync(uint32_t num_threads, uint32_t barrier_id_) { static constexpr uint32_t ReservedNamedBarrierCount = static_cast(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier); uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount; asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); } CUTLASS_DEVICE static void named_barrier_sync(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) { uint32_t barrier_id = static_cast(reserved_named_barriers); asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); } CUTLASS_DEVICE static void named_barrier_arrive(uint32_t num_threads, uint32_t barrier_id_) { static constexpr uint32_t ReservedNamedBarrierCount = static_cast(cutlass::arch::ReservedNamedBarriers::FirstUserBarrier); uint32_t barrier_id = barrier_id_ + ReservedNamedBarrierCount; cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); } CUTLASS_DEVICE static void named_barrier_arrive(uint32_t num_threads, cutlass::arch::ReservedNamedBarriers reserved_named_barriers) { uint32_t barrier_id = static_cast(reserved_named_barriers); cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); } //////////////////////////////////////////////////////////////////////////////////////////////////// // Enumerates the reserved named barriers to avoid potential conflicts enum class FwdNamedBarriers { QueryEmpty = 0, WarpSchedulerWG1 = 1, WarpSchedulerWG2 = 2, WarpSchedulerWG3 = 3, AppendKV = 4, QueryRotated = 5, PFull = 6, PEmpty = 7, }; enum class BwdNamedBarriers { KVEmpty = 0, PdS = 1, 
dQEmptyWG1 = 2, dQEmptyWG2 = 3, dQEmptyWG3 = 4, dQFullWG1 = 5, dQFullWG2 = 6, dQFullWG3 = 7, }; } // flash ================================================ FILE: hopper/pack_gqa.h ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include #include "cutlass/fast_math.h" // For cutlass::FastDivmod #include "utils.h" namespace flash { using namespace cute; template struct PackGQAManager { // We use CpAsync for Q, since TMA doesn't work there static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static constexpr int kGmemElemsPerStore = kGmemElemsPerLoad; static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction. // In the case of PackGQA, this reduces the number of times we need to call divmod. static constexpr int kBytePerRow = kHeadDim * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where // these threads share the same page table entry and share the work of computing pointers to paged K and paged V.
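// Worked example of the constants above (a sketch under assumed parameters, not a configuration
// taken from the launch code: 2-byte Element such as fp16/bf16, kHeadDim = 128, NumThreads = 128,
// kBlockM = 64):
//   kGmemElemsPerLoad  = 16 bytes / 2 bytes  = 8 elements per 128-bit load
//   kBytePerRow        = 128 * 2             = 256 bytes, and 256 % 128 == 0
//   kBlockKGmem        = 128 bytes / 2 bytes = 64 elements (one 128-byte cache line)
//   kGmemThreadsPerRow = 64 / 8              = 8 threads share each row
// The 128 threads then cover 128 / 8 = 16 rows per pass, so each thread issues 64 / 16 = 4 loads
// in M and 128 / 64 = 2 vectorized loads in K, matching the hdim=128 example in the comment above.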
static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); using GmemCopyAtomCpAsync = cute::Copy_Atom, Element>; using GmemLayoutAtom = Layout, Int>, Stride, _1>>; using GmemTiledCopyQCpAsync = decltype( make_tiled_copy(GmemCopyAtomCpAsync{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per load // Was trying to have each WG loading Q to the rows in sQ that only that WG needs so that we only need // to sync within each WG, but didn't seem to be any faster. // using GmemLayoutAtomWG = Layout, Int, Int >, // Stride, _128, _1>>; // using GmemTiledCopyQCpAsyncWG = decltype( // make_tiled_copy(GmemCopyAtomCpAsync{}, // GmemLayoutAtomNew{}, // Layout>>{})); // Val layout, 8 or 16 vals per load using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>>{})); // Val layout, 8 or 16 vals per store template CUTLASS_DEVICE static auto compute_ptr(Tensor &tensor, TensorC const &tRows, cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const m_block) { // tensor of shape ((qhead_per_khead, seqlen_q)) static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size(tRows)), NumThreadsPerRow); using TensorType = typename Engine::value_type; Tensor tPrPtr = make_tensor(Shape>{}); #pragma unroll for (int i = 0; i < NumPtrPerThread; ++i) { int const row = i * NumThreads + get<0>(tRows(thread_idx % NumThreadsPerRow)); int const idx = m_block * kBlockM + row; int m_idx, h_idx; m_idx = qhead_per_khead_divmod.divmod(h_idx, idx); tPrPtr[i] = &tensor(make_coord(make_coord(h_idx, m_idx))); } return tPrPtr; } template CUTLASS_DEVICE static void load_Q(TensormQ const &mQ, // ((qhead_per_khead, seqlen_q), headdim) TensorsQ &sQ, // (kBlockM, kHeadDim) cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const seqlen_q, int const m_block ) { GmemTiledCopyQCpAsync gmem_tiled_copy_Q_cp_async; // 
GmemTiledCopyQCpAsyncNew gmem_tiled_copy_Q_cp_async; auto gmem_thr_copy_Q_cp_async = gmem_tiled_copy_Q_cp_async.get_thread_slice(thread_idx); Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor tQcQ = gmem_thr_copy_Q_cp_async.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) Tensor tQsQ = gmem_thr_copy_Q_cp_async.partition_D(sQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) // Tensor tQcQ_ = gmem_thr_copy_Q_cp_async.partition_S(cute::flat_divide(cQ, _64{})); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) // Tensor tQsQ_ = gmem_thr_copy_Q_cp_async.partition_D(cute::flat_divide(sQ, _64{})); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) // Tensor tQcQ = group_modes<1, rank(tQcQ_) - 1>(tQcQ_); // Tensor tQsQ = group_modes<1, rank(tQsQ_) - 1>(tQsQ_); Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); #pragma unroll for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < size<1>(mQ); } // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for Q. // We split the work among threads loading the same row of Q, then __shfl_sync the pointers. 
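// Worked example of the packed-row -> (head, position) mapping done in compute_ptr
// (illustrative numbers only): with qhead_per_khead = 4 and packed row idx = 10,
// qhead_per_khead_divmod.divmod(h_idx, idx) returns m_idx = 10 / 4 = 2 and sets
// h_idx = 10 % 4 = 2, so packed row 10 is query head 2 of this KV head's group at sequence
// position 2. Consecutive packed rows step through the query heads first, which is why the
// row mode of mQ is laid out as (qhead_per_khead, seqlen_q) and why the bound check below is
// idx < seqlen_q * qhead_per_khead.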
Tensor mQ_0 = mQ(_, _0{}); Tensor tQcQ_row = tQcQ(_0{}, _, _0{}); Tensor tPrQPtr = compute_ptr(mQ_0, tQcQ_row, qhead_per_khead_divmod, thread_idx, m_block); int const qhead_per_khead = qhead_per_khead_divmod.divisor; #pragma unroll for (int m = 0; m < size<1>(tQsQ); ++m) { int idx = m_block * kBlockM + get<0>(tQcQ(_0{}, m, _0{})); Element const* q_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrQPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); if (idx < seqlen_q * qhead_per_khead) { // if (thread_idx == 0) { printf("m: %d, m_idx: %d, h_idx: %d, q_ptr = %p, q_ptr_og = %p\n", m, m_idx, h_idx, q_ptr, &mQ_copy(0, make_coord(h_idx, m_idx), 0));} Tensor mQ_cur = make_tensor(make_gmem_ptr(q_ptr), Shape>{}); Tensor mQ_cur_copy = cute::tiled_divide(mQ_cur, Shape>{}); #pragma unroll for (int k = 0; k < size<2>(tQsQ); ++k) { int ki = get<1>(tQcQ(_0{}, _0{}, k)) / kGmemElemsPerLoad; // the "tiled_copy.with(tQpQ(k))" will fill in zeros for columns where tQpQ(k) is false // TODO: check this cute::copy(gmem_tiled_copy_Q_cp_async.with(tQpQ(k)), mQ_cur_copy(_, ki), tQsQ(_, m, k)); } } // Don't need to fill in 0s for sQ since we're not gonna write the output to gmem for those rows } }; template CUTLASS_DEVICE static void store_LSE(TensormLSE &mLSE, // ((qhead_per_khead, seqlen_q)) TensorsLSE const &tLSErLSE, // (kBlockM) split across threads according to tiled_mma TiledMma tiled_mma, cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const seqlen_o, int const m_block ) { Tensor caccO = cute::make_identity_tensor(Shape, Int>{}); auto thread_mma = tiled_mma.get_thread_slice(thread_idx); Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) Tensor taccOcO_row = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()))(_, _0{}); CUTE_STATIC_ASSERT_V(size(tLSErLSE) == size(taccOcO_row)); // MMA_M // If PackGQA, we split the work of computing divmod among threads in the same row
static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); static_assert(CUTE_STATIC_V(size(tLSErLSE)) <= kMmaThreadsPerRow); static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow); Tensor tPrLSEPtr = compute_ptr(mLSE, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block); static_assert(CUTE_STATIC_V(size(tPrLSEPtr)) == 1); int const qhead_per_khead = qhead_per_khead_divmod.divisor; #pragma unroll for (int mi = 0; mi < size(tLSErLSE); ++mi) { int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); float* ptr_LSE_cur = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrLSEPtr[0]), mi % kMmaThreadsPerRow, kMmaThreadsPerRow)); if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o * qhead_per_khead) { *ptr_LSE_cur = tLSErLSE(mi); } } }; template CUTLASS_DEVICE static void store_O(TensormO &mO, // ((qhead_per_khead, seqlen_o), headdim) TensorrO const &tOrO, // (kBlockM, kHeadDim) split across threads according to gmem_tiled_copy_O cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const seqlen_o, int const m_block ) { GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); Tensor cO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < size<1>(mO); } // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O. // We split the work among threads loading the same row of O, then __shfl_sync the pointers. 
Tensor mO_0 = mO(_, _0{}); Tensor tOcO_row = tOcO(_0{}, _, _0{}); Tensor tPrOPtr = compute_ptr(mO_0, tOcO_row, qhead_per_khead_divmod, thread_idx, m_block); int const qhead_per_khead = qhead_per_khead_divmod.divisor; #pragma unroll for (int m = 0; m < size<1>(tOrO); ++m) { int idx = m_block * kBlockM + get<0>(tOcO(_0{}, m, _0{})); Element* o_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrOPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); if (idx < seqlen_o * qhead_per_khead) { Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape>{}); Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape>{}); #pragma unroll for (int k = 0; k < size<2>(tOrO); ++k) { int ki = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerStore; if (tOpO(k)) { cute::copy(gmem_tiled_copy_O, tOrO(_, m, k), mO_cur_copy(_, ki)); } } } } }; template CUTLASS_DEVICE static void store_O_direct(TensormO &mO, // ((qhead_per_khead, seqlen_o), headdim) TensorrO const &tOrO, // (kBlockM, kHeadDim) split across threads according to tiled_mma TiledMma tiled_mma, cutlass::FastDivmod const &qhead_per_khead_divmod, int const thread_idx, int const seqlen_o, int const m_block ) { static constexpr int kGmemElemsPerStoreDirect = 2; cute::Copy_Atom, Element> gmem_copy_direct; // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout())); Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int>{}); Tensor caccO = cute::make_identity_tensor(Shape, Int>{}); auto thread_mma = tiled_mma.get_thread_slice(thread_idx); Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); Tensor taccOcO_col = taccOcO_rowcol(_0{}, _); // If PackGQA, we split the work of computing divmod among threads in the same row
static constexpr int kMmaThreadsPerRow = size<0, 0>(typename TiledMma::AtomLayoutC_TV{}); static_assert(cutlass::NumThreadsPerWarp % kMmaThreadsPerRow == 0); static_assert(CUTE_STATIC_V(size(taccOcO_row)) <= kMmaThreadsPerRow); // Similar to loading K and V when PagedKV, it's expensive to compute the pointers for O. // We split the work among threads loading the same row of O, then __shfl_sync the pointers. Tensor mO_0 = mO(_, _0{}); Tensor tPrOPtr = compute_ptr(mO_0, taccOcO_row, qhead_per_khead_divmod, thread_idx, m_block); static_assert(CUTE_STATIC_V(size(tPrOPtr)) == 1); int const qhead_per_khead = qhead_per_khead_divmod.divisor; #pragma unroll for (int m = 0; m < size<1>(tOrO_copy); ++m) { int row = m_block * kBlockM + get<0>(taccOcO_row(m)); Element* o_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrOPtr[0]), m % kMmaThreadsPerRow, kMmaThreadsPerRow)); if (row < seqlen_o * qhead_per_khead) { Tensor mO_cur = make_tensor(make_gmem_ptr(o_ptr), Shape>{}); Tensor mO_cur_copy = cute::tiled_divide(mO_cur, Shape>{}); #pragma unroll for (int k = 0; k < size<2>(tOrO_copy); ++k) { int col = get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)); if (col < size<1>(mO)) { cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), mO_cur_copy(_, col / kGmemElemsPerStoreDirect)); } } } } }; }; } // namespace flash ================================================ FILE: hopper/padding.py ================================================ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py import torch import torch.nn.functional as F from einops import rearrange def unpad_input(hidden_states, attention_mask, unused_mask=None): """ Arguments: hidden_states: (batch, seqlen, ...) attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. 
Return: hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. indices: (total_nnz), the indices of the selected (non-masked) tokens in the flattened input sequence. cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. max_seqlen_in_batch: int, the longest sequence length (attention_mask + unused_mask) in the batch. seqused: (batch), the number of tokens selected in attention_mask alone (tokens marked only in unused_mask are not counted). """ all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) # TD [2022-03-04] We don't want to index with a bool mask, because PyTorch will expand the # bool mask, then call nonzero to get the indices, then index with those. The indices are dim # times larger than they need to be, wasting memory. It's faster and more memory-efficient to # index with integer indices. return ( rearrange(hidden_states, "b s ... -> (b s) ...")[indices], indices, cu_seqlens, max_seqlen_in_batch, used_seqlens_in_batch, ) def pad_input(hidden_states, indices, batch, seqlen): """ Arguments: hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. batch: int, batch size for the padded sequence. seqlen: int, maximum sequence length for the padded sequence. Return: hidden_states: (batch, seqlen, ...) """ dim = hidden_states.shape[1:] output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype) output[indices] = hidden_states return rearrange(output, "(b s) ... 
-> b s ...", b=batch) ================================================ FILE: hopper/paged_kv.h ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include #include "cutlass/fast_math.h" // For cutlass::FastDivmod #include "utils.h" namespace flash { using namespace cute; template struct PagedKVManager { // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0), // load_page_table(2), load_K(2), load_V(1), etc. // So we need to compute the V pointers for the previous iteration. // LoadsPerRow_LB is the lower bound on the number of loads per row in the K direction. This is useful for // rotary where we want each thread to have at least 2 loads per row. static constexpr bool SameHeadDim = (kHeadDim == kHeadDimV); static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); // We use CpAsync for K and V if PagedKV, since TMA doesn't work there static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction. // In the case of PackGQA, this reduces the number of times we need to call divmod. static_assert(kHeadDimGCD % LoadsPerRow_LB == 0, "Headdim and HeaddimV must be a multiple of LoadsPerRow_LB"); static constexpr int kBytePerRow = kHeadDimGCD / LoadsPerRow_LB * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 
64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); // We assume threads loading the same row are in the same warp. This is for an optimization in PagedKV where // these threads share the same page table entry and share the work of computing pointers to paged K and paged V. static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp"); using GmemCopyAtomCpAsync = cute::Copy_Atom, Element>; using GmemLayoutAtomKVCpAsync = Layout, Int>, Stride, _1>>; using GmemTiledCopyKVCpAsync = decltype( make_tiled_copy(GmemCopyAtomCpAsync{}, GmemLayoutAtomKVCpAsync{}, Layout>>{})); // Val layout, 8 or 16 vals per load using GmemTiledCopyKVStore = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomKVCpAsync{}, Layout>>{})); // Val layout, 8 or 16 vals per load using ShapeKV = cute::Shape; // (seqlen, d, head, batch) using StrideKV = cute::Stride; using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) using StridePageTable = cute::Stride; using TensorPageTable = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapePageTable{}, StridePageTable{})(int(0), _)); using TensorKV = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeKV{}, StrideKV{})(_, _, int(0), _)); using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0))); using TensortKcK = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); using TensortKpK = decltype(make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{})); using TensortVcV = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); using TensortVpV = decltype(make_tensor(make_shape(size<1>(TensortVcV{}), 
size<2>(TensortVcV{})), Stride<_0, _1>{})); // For PagedKV, it's expensive to calculate the pointers to K and V for each page table entry, // since those require int64_t arithmetic. We optimize by having threads split this work. // Typically there are 8 threads loading per row (e.g. hdim 64 and 128), and there are 11 rows // that each thread needs to load for the case of hdim 128 and kBlockN = 176. // So each of those 8 threads will calculate the K_ptr and V_ptr for at most ceil(11 / 8) = 2 rows. // We then use __shfl_sync to broadcast the pointers to the other threads in the warp. static_assert(CUTE_STATIC_V(size<1>(TensortKcK{})) == CUTE_STATIC_V(size<1>(TensortVcV{}))); static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow); using TensorPageOffset = decltype(make_tensor>(Shape>{})); using TensorKVPtr = decltype(make_tensor(Shape>{})); GmemTiledCopyKVCpAsync gmem_tiled_copy_kv; cutlass::FastDivmod const &page_size_divmod; cutlass::FastDivmod const &blockN_per_page_size_divmod; int const thread_idx; int const seqlen_k; int const leftpad_k; int const* const ptr_page_table; GmemThrCopyKVCpAsync const gmem_thr_copy_kv; TensorPageTable mPageTable; TensorKV mK_paged, mV_paged; TensortKpK tKpK; TensortVpV tVpV; TensorPageOffset tPrPageOffset; TensorKVPtr tPrVPtr; int bidb_kv_idx, bidb_kv_idx_prev, n_block_idx, n_block_idx_prev; // Only used for TMA CUTLASS_DEVICE PagedKVManager(int const* const ptr_page_table_, ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, Element* const ptr_V, int const headdim_v, StrideKV const &stride_V, cutlass::FastDivmod const &page_size_divmod, cutlass::FastDivmod const &blockN_per_page_size_divmod, int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k, int bidb_kv_idx ) : page_size_divmod(page_size_divmod) , blockN_per_page_size_divmod(blockN_per_page_size_divmod) , 
thread_idx(thread_idx) , seqlen_k(seqlen_k) , leftpad_k(leftpad_k) , ptr_page_table(ptr_page_table_) , gmem_thr_copy_kv(gmem_tiled_copy_kv.get_thread_slice(thread_idx)) , bidb_kv_idx(bidb_kv_idx) , bidb_kv_idx_prev(bidb_kv_idx) { mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _); auto shape_V = make_shape(get<0>(shape_K), headdim_v, get<2>(shape_K), get<3>(shape_K)); mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_V, stride_V)(_, _, bidh, _); tKpK = make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}); Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); #pragma unroll for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); } Tensor tVpV_ = make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{}); Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); #pragma unroll for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_V); } tVpV = cute::conditional_return(tKpK, tVpV_); }; template CUTLASS_DEVICE void load_page_table(const int n_block) { // The uncoalesced gmem load is intentional. This is so that each thread only loads the page table entries // it needs, and we don't need any sync between warps. // Assuming 8 threads per row, and 176 rows, then the rows from 0 to 175 are loaded by // threads 0, 8, 16, ..., 120, 1, 9, ..., 121, 2, 10, ..., 122, etc. 
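// Worked example of the row assignment and page lookup below (illustrative; assumes
// NumThreads = 128, kGmemThreadsPerRow = 8, kBlockN = 176, page_size = 256, leftpad_k = 0,
// so each thread covers 176 / 16 = 11 rows and kPageEntryPerThread = ceil_div(11, 8) = 2):
//   row(i, t) = i * 128 + (t % 8) * 16 + t / 8
// Thread 80 with i = 0 handles row 10. For n_block = 2, row_idx = 2 * 176 + 10 = 362, and
// page_size_divmod.divmod(page_offset, 362) returns page_idx = 1 with page_offset = 106, so the
// entry is {mPageTable[1], 106} (provided row_idx < seqlen_k when Seqlenk_mask), and
// compute_K_ptr later yields &mK_paged(106, _0{}, mPageTable[1]).
// With i = 1, rows span 128..255; since (1 + 1) * 128 > kBlockN, rows >= 176 fail the
// row < kBlockN check and get page 0 instead of reading past the page table.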
#pragma unroll for (int i = 0; i < kPageEntryPerThread; ++i) { int const row = i * NumThreads + (thread_idx % kGmemThreadsPerRow) * (NumThreads / kGmemThreadsPerRow) + (thread_idx / kGmemThreadsPerRow); int const row_idx = n_block * kBlockN + row; int page_idx, page_offset; page_idx = page_size_divmod.divmod(page_offset, row_idx + leftpad_k); // Add the condition (i + 1) * NumThreads <= kBlockN since that is an upper bound of row // and is known at compile time. It avoids branching when e.g., kBlockN = 176 and i = 0. int const page = ((i + 1) * NumThreads <= kBlockN || row < kBlockN) && (!Seqlenk_mask || row_idx < seqlen_k) ? mPageTable[page_idx] : 0; tPrPageOffset[i] = {page, page_offset}; // if (cute::thread0()) { printf("row = %d, page_idx = %d, page_offset = %d, page = %d, leftpad_k = %d, seqlen_k = %d\n", row, page_idx, page_offset, page, leftpad_k, seqlen_k); } } if constexpr (First_iter && !KV_Same_Iter) { compute_V_ptr(); } }; template CUTLASS_DEVICE void load_page_table_TMA(const int n_block) { // We require that page size is a multiple of kBlockN, and there's no leftpad_k if (ptr_page_table) { bidb_kv_idx = mPageTable[blockN_per_page_size_divmod.divmod(n_block_idx, n_block)]; } else { n_block_idx = n_block; } if constexpr (First_iter && !KV_Same_Iter) { bidb_kv_idx_prev = bidb_kv_idx; n_block_idx_prev = n_block_idx; } }; CUTLASS_DEVICE cute::tuple get_indices_for_K_TMA() { return {n_block_idx, bidb_kv_idx}; }; CUTLASS_DEVICE cute::tuple get_indices_for_V_TMA() { if constexpr (KV_Same_Iter) { return {n_block_idx, bidb_kv_idx}; } else { cute::tuple const indices = {n_block_idx_prev, bidb_kv_idx_prev}; bidb_kv_idx_prev = bidb_kv_idx; n_block_idx_prev = n_block_idx; return indices; } }; CUTLASS_DEVICE TensorKVPtr compute_K_ptr() { Tensor tPrKPtr = make_tensor(Shape>{}); #pragma unroll for (int i = 0; i < kPageEntryPerThread; ++i) { auto [page, page_offset] = tPrPageOffset[i]; tPrKPtr[i] = &mK_paged(page_offset, _0{}, page); } return tPrKPtr; }; CUTLASS_DEVICE 
    void compute_V_ptr() {
        #pragma unroll
        for (int i = 0; i < kPageEntryPerThread; ++i) {
            auto [page, page_offset] = tPrPageOffset[i];
            tPrVPtr[i] = &mV_paged(page_offset, _0{}, page);
        }
    };

    template
    CUTLASS_DEVICE
    void load_K(const int n_block, TensorK &&sK) {
        // Do we need a bounds check to make sure the row doesn't go above kBlockN?
        static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0;

        Tensor tPrKPtr = compute_K_ptr();

        // Only for index calculation, since all the indices of thread 0 are known at compile time
        auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});
        Tensor tKsK = gmem_thr_copy_kv.partition_D(sK);
        Tensor cK = cute::make_identity_tensor(Shape, Int>{});  // (BLK_N,BLK_K) -> (blk_n,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tKcK = gmem_thr_copy_kv.partition_S(cK);
        Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK);

        // We want to use the row indices of thread0 to compare, since that is known at compile time.
        // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{})))
        int const seqlenk_row_limit = -int(get<0>(tKcK(_0{}, _0{}, _0{}))) + (EvenN
            ? seqlen_k - n_block * kBlockN
            : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k - n_block * kBlockN, kBlockN)));
        #pragma unroll
        for (int m = 0; m < size<1>(tKsK); ++m) {
            bool const should_load = EvenN ?
                (!Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit)
                : get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit;
            Element const* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));
            Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{});
            Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape>{});
            if (should_load) {
                #pragma unroll
                for (int k = 0; k < size<2>(tKsK); ++k) {
                    int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k)), mK_paged_cur_copy(_, ki), tKsK(_, m, k));
                }
            }
            // Don't need to clear out the rest of the smem since we'll mask out the scores anyway
        }
    };

    template
    CUTLASS_DEVICE
    void load_V(const int n_block, TensorV &&sV) {
        // Do we need a bounds check to make sure the row doesn't go above kBlockN?
        static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtomKVCpAsync{})) == 0;

        if constexpr (KV_Same_Iter) { compute_V_ptr(); }
        // Only for index calculation, since all the indices of thread 0 are known at compile time
        auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});
        Tensor tVsV = gmem_thr_copy_kv.partition_D(sV);
        Tensor cV = cute::make_identity_tensor(Shape, Int>{});  // (BLK_N,BLK_K) -> (blk_n,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tVcV = gmem_thr_copy_kv.partition_S(cV);
        Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV);

        int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tVcV(_0{}, _0{}, _0{}));
        #pragma unroll
        for (int m = 0; m < size<1>(tVsV); ++m) {
            // Faster to rely on the cp.async to clear smem that is out of bounds,
            // rather than calling cute::clear directly.
            // We have to be careful not to write to smem past `kBlockN` if !EvenN.
            // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked
            if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tVcV(_0{}, m, _0{})) < kBlockN) {
                bool const should_load = !Seqlenk_mask || get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit;
                Element const* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
                Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{});
                Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tVsV); ++k) {
                    int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    cute::copy(gmem_tiled_copy_kv.with(tVpV(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k));
                }
            }
        }
        if constexpr (!KV_Same_Iter) { compute_V_ptr(); }
    };

    template
    CUTLASS_DEVICE
    void store_K(const int n_block, TensorK &&tKrK) {
        Tensor tPrKPtr = compute_K_ptr();
        // We're using the same partitioning as GmemTiledCopyKVCpAsync (used for loading)
        // Only for index calculation, since all the indices of thread 0 are known at compile time
        auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});
        Tensor cK = cute::make_identity_tensor(Shape, Int>{});  // (BLK_N,BLK_K) -> (blk_n,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tKcK = gmem_thr_copy_kv.partition_S(cK);
        Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK);

        GmemTiledCopyKVStore gmem_tiled_copy_kv_store;
        // We want to use the row indices of thread0 to compare, since that is known at compile time.
        // So we subtract the limit by the first row index of this thread (get<0>(tKcK(_0{}, _0{}, _0{})))
        // int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{}));
        int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{}));
        // if (threadIdx.x == 128) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_k = %d, seqlenk_row_limit = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_k, seqlenk_row_limit); }
        #pragma unroll
        for (int m = 0; m < size<1>(tKrK); ++m) {
            bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit;
            Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));
            Tensor mK_paged_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{});
            Tensor mK_paged_cur_copy = cute::tiled_divide(mK_paged_cur, Shape>{});
            if (should_load) {
                #pragma unroll
                for (int k = 0; k < size<2>(tKrK); ++k) {
                    int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    if (tKpK(_0{}, k)) {
                        cute::copy(gmem_tiled_copy_kv_store, tKrK(_, m, k), mK_paged_cur_copy(_, ki));
                    }
                }
            }
        }
    };

    template
    CUTLASS_DEVICE
    void store_V(const int n_block, TensorV &&tVrV) {
        if constexpr (KV_Same_Iter) { compute_V_ptr(); }
        // Only for index calculation, since all the indices of thread 0 are known at compile time
        auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{});
        Tensor cV = cute::make_identity_tensor(Shape, Int>{});  // (BLK_N,BLK_K) -> (blk_n,blk_k)
        // Repeat the partitioning with identity layouts
        Tensor tVcV = gmem_thr_copy_kv.partition_S(cV);
        Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV);

        GmemTiledCopyKVStore gmem_tiled_copy_kv_store;
        int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tVcV(_0{}, _0{}, _0{}));
        #pragma unroll
        for (int m = 0; m < size<1>(tVrV); ++m) {
            bool const should_load = get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit;
            Element* v_ptr =
                reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
            Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{});
            Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{});
            if (should_load) {
                #pragma unroll
                for (int k = 0; k < size<2>(tVrV); ++k) {
                    int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                    if (tVpV(_0{}, k)) {
                        cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki));
                    }
                }
            }
        }
        if constexpr (!KV_Same_Iter) { compute_V_ptr(); }
    };

};

} // namespace flash

================================================
FILE: hopper/rotary.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include

#include "utils.h"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template
CUTLASS_DEVICE void
apply_rotary_interleaved(Tensor &rK,
                         Tensor const &rCos,
                         Tensor const &rSin) {
    CUTE_STATIC_ASSERT_V(rank(rK) == _1{});
    CUTE_STATIC_ASSERT_V(rank(rCos) == _1{});
    CUTE_STATIC_ASSERT_V(rank(rSin) == _1{});
    CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin));
    static_assert(decltype(size<0>(rK))::value == decltype(size<0>(rCos))::value * 2);
    static_assert(decltype(size<0>(rCos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
    Tensor K_fp32 = make_tensor_like(rK);
    convert_type_out(rK, K_fp32);
    Tensor cos_fp32 = make_tensor_like(rCos);
    convert_type_out(rCos, cos_fp32);
    Tensor sin_fp32 = make_tensor_like(rSin);
    convert_type_out(rSin, sin_fp32);
    #pragma unroll
    for (int i = 0; i < size<0>(K_fp32) / 2; ++i) {
        float real = K_fp32[2 * i] * cos_fp32[i] - K_fp32[2 * i + 1] *
            sin_fp32[i];
        float imag = K_fp32[2 * i] * sin_fp32[i] + K_fp32[2 * i + 1] * cos_fp32[i];
        K_fp32[2 * i] = real;
        K_fp32[2 * i + 1] = imag;
    }
    convert_type_out(K_fp32, rK);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template
CUTLASS_DEVICE void
apply_rotary_contiguous(Tensor &rK_left,
                        Tensor &rK_right,
                        Tensor const &rCos,
                        Tensor const &rSin) {
    CUTE_STATIC_ASSERT_V(rank(rK_left) == _1{});
    CUTE_STATIC_ASSERT_V(rank(rK_right) == _1{});
    CUTE_STATIC_ASSERT_V(rank(rCos) == _1{});
    CUTE_STATIC_ASSERT_V(rank(rSin) == _1{});
    CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rK_right));
    CUTE_STATIC_ASSERT_V(size<0>(rK_left) == size<0>(rCos));
    CUTE_STATIC_ASSERT_V(size<0>(rCos) == size<0>(rSin));
    static_assert(decltype(size<0>(rCos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
    Tensor K_left_fp32 = make_tensor_like(rK_left);
    convert_type_out(rK_left, K_left_fp32);
    Tensor K_right_fp32 = make_tensor_like(rK_right);
    convert_type_out(rK_right, K_right_fp32);
    Tensor cos_fp32 = make_tensor_like(rCos);
    convert_type_out(rCos, cos_fp32);
    Tensor sin_fp32 = make_tensor_like(rSin);
    convert_type_out(rSin, sin_fp32);
    #pragma unroll
    for (int i = 0; i < size<0>(K_left_fp32); ++i) {
        float real = K_left_fp32[i] * cos_fp32[i] - K_right_fp32[i] * sin_fp32[i];
        float imag = K_left_fp32[i] * sin_fp32[i] + K_right_fp32[i] * cos_fp32[i];
        K_left_fp32[i] = real;
        K_right_fp32[i] = imag;
    }
    convert_type_out(K_left_fp32, rK_left);
    convert_type_out(K_right_fp32, rK_right);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template
struct Rotary {

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
    // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each
    // thread to have 4 loads in the M direction and 2 vectorized loads in the K direction.
    // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved
    // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will
    // load twice from the same row.
    static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element);
    static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
    static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
    static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow");
    // We assume threads loading the same row are in the same warp.
    static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0, "kGmemThreadsPerRow must divide NumThreadsPerWarp");

    using LayoutAtom = Layout, Int>, Stride, _1>>;
    using TiledCopyQK = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        LayoutAtom{},
                        Layout>>{}));  // Val layout, 8 or 16 vals per store
    using GmemTiledCopyRotary = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        LayoutAtom{},
                        Layout>>{}));  // Val layout, 4 or 8 vals per store
    using GmemTiledCopyRotaryCont = decltype(
        make_tiled_copy(Copy_Atom, Element>{},
                        LayoutAtom{},
                        Layout>>{}));  // Val layout, 8 or 16 vals per store

    using ShapeRotary = cute::Shape;  // (seqlen_ro, rotary_dim // 2)
    using StrideRotary = cute::Stride;

    using GmemThrCopyRotary = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0)));
    using GmemThrCopyRotaryCont = decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)));
    using TensortRcR = decltype(GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{})));
    using TensortRpR = decltype(make_tensor(make_shape(size<2>(TensortRcR{}))));
    using TensortRcRCont =
        decltype(GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{})));
    using TensortRpRCont = decltype(make_tensor(make_shape(size<2>(TensortRcRCont{}))));
    using TensormR = decltype(make_tensor(
        make_gmem_ptr((Element const*)nullptr),
        ShapeRotary{},
        make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{})));
    using TensortRgR = decltype(
        GmemTiledCopyRotary{}.get_thread_slice(int(0)).partition_S(make_tensor(
            make_gmem_ptr((Element const*)nullptr),
            make_shape(Int{}, Int{}, int(0)),
            make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}, cute::conditional_return(_0{}, int64_t(0))))));
    using TensortRgRCont = decltype(
        GmemTiledCopyRotaryCont{}.get_thread_slice(int(0)).partition_S(make_tensor(
            make_gmem_ptr((Element const*)nullptr),
            make_shape(Int{}, Int{}, int(0)),
            make_stride(cute::conditional_return(_0{}, int64_t(0)), _1{}, cute::conditional_return(_0{}, int64_t(0))))));

    GmemTiledCopyRotary gmem_tiled_copy_rotary;
    GmemTiledCopyRotaryCont gmem_tiled_copy_rotary_cont;
    bool const is_rotary_interleaved;
    int const rotary_dim;
    int const thread_idx;
    int const max_seqlen;
    GmemThrCopyRotary const gmem_thr_copy_rotary;
    GmemThrCopyRotaryCont const gmem_thr_copy_rotary_cont;
    TensortRpR tRpR;
    TensortRpRCont tRpRCont;
    TensormR mCos, mSin;
    TensortRgR tRgCos, tRgSin;
    TensortRgRCont tRgCosCont, tRgSinCont;

    CUTLASS_DEVICE
    Rotary(Element const* const ptr_rotary_cos, ShapeRotary const &shape_rotary, StrideRotary const &stride_rotary_cos_,
           Element const* const ptr_rotary_sin, StrideRotary const &stride_rotary_sin_,
           bool const is_rotary_interleaved, int const thread_idx, int const max_seqlen, int const start_idx)
        : is_rotary_interleaved(is_rotary_interleaved)
        , rotary_dim(get<1>(shape_rotary) * 2)
        , thread_idx(thread_idx)
        , max_seqlen(max_seqlen)
        , gmem_thr_copy_rotary(gmem_tiled_copy_rotary.get_thread_slice(thread_idx))
        , gmem_thr_copy_rotary_cont(gmem_tiled_copy_rotary_cont.get_thread_slice(thread_idx))
    {
        auto stride_rotary_cos =
            make_stride(cute::conditional_return(get<0>(stride_rotary_cos_), _0{}), get<1>(stride_rotary_cos_));
        auto stride_rotary_sin = make_stride(cute::conditional_return(get<0>(stride_rotary_sin_), _0{}), get<1>(stride_rotary_sin_));
        mCos = make_tensor(make_gmem_ptr(ptr_rotary_cos + start_idx * get<0>(stride_rotary_cos_)), shape_rotary, stride_rotary_cos);
        mSin = make_tensor(make_gmem_ptr(ptr_rotary_sin + start_idx * get<0>(stride_rotary_sin_)), shape_rotary, stride_rotary_sin);
        Tensor gCos = local_tile(mCos, Shape, Int>{}, make_coord(_, _0{}));  // (MN, K / 2, _)
        Tensor gSin = local_tile(mSin, Shape, Int>{}, make_coord(_, _0{}));  // (MN, K / 2, _)
        tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
        tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
        tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCos);
        tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSin);
        Tensor cR = cute::make_identity_tensor(Shape, Int>{});  // (BLK_N,BLK_K / 2)
        Tensor tRcR = gmem_thr_copy_rotary.partition_D(cR);
        tRpR = make_tensor(make_shape(size<2>(tRcR)));
        #pragma unroll
        for (int k = 0; k < size(tRpR); ++k) { tRpR(k) = get<1>(tRcR(_0{}, _0{}, k)) < get<1>(shape_rotary); }
        Tensor tRcRCont = gmem_thr_copy_rotary_cont.partition_D(cR);
        tRpRCont = make_tensor(make_shape(size<2>(tRcRCont)));
        #pragma unroll
        for (int k = 0; k < size(tRpRCont); ++k) { tRpRCont(k) = get<1>(tRcRCont(_0{}, _0{}, k)) < get<1>(shape_rotary); }
    };

    template
    CUTLASS_DEVICE
    auto load_cos_sin(int const block) {
        using GmemTiledCopyRo = std::conditional_t;
        auto gmem_thr_copy_ro = cute::conditional_return(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont);
        Tensor tRpRCur = cute::conditional_return(tRpR, tRpRCont);
        Tensor tRgCosCur = cute::conditional_return(tRgCos, tRgCosCont)(_, _, _, block);
        Tensor tRgSinCur = cute::conditional_return(tRgSin, tRgSinCont)(_, _, _, block);
        // make_tensor_like, not make_fragment_like. If the row_stride is _0{} we want to keep it that way
        Tensor tRrCos = make_tensor_like(tRgCosCur);
        Tensor tRrSin = make_tensor_like(tRgSinCur);
        Tensor cR = cute::make_identity_tensor(Shape, Int>{});  // (BLK_N,BLK_K / 2)
        Tensor tRcR = gmem_thr_copy_ro.partition_D(cR);
        // If FixedPosition, only copy the first row as we only need the cos/sin for position cache_seqlens
        #pragma unroll
        for (int m = 0; m < (!FixedPosition ? size<1>(tRrCos) : 1); ++m) {
            if (get<0>(tRcR(_0{}, m, _0{})) < std::min(max_seqlen - block * kBlockMN, kBlockMN)) {
                #pragma unroll
                for (int k = 0; k < size<2>(tRrCos); ++k) {
                    if (tRpRCur(k)) {
                        cute::copy(GmemTiledCopyRo{}, tRgCosCur(_, m, k), tRrCos(_, m, k));
                        cute::copy(GmemTiledCopyRo{}, tRgSinCur(_, m, k), tRrSin(_, m, k));
                    }
                }
            }
        }
        return cute::make_tuple(tRrCos, tRrSin);
    }

    template
    CUTLASS_DEVICE
    auto load_cos_sin_packgqa(int const block, cutlass::FastDivmod const &qhead_per_khead_divmod) {
        static constexpr int kGmemElemsPerLoadCur = kInterleaved ? kGmemElemsPerLoad / 2 : kGmemElemsPerLoad;
        using GmemTiledCopyRo = std::conditional_t;
        auto gmem_thr_copy_ro = cute::conditional_return(gmem_thr_copy_rotary, gmem_thr_copy_rotary_cont);
        Tensor tRpRCur = cute::conditional_return(tRpR, tRpRCont);
        // make_tensor_like, not make_fragment_like. If the row_stride is _0{} we want to keep it that way
        Tensor tRrCos = make_tensor_like(cute::conditional_return(tRgCos, tRgCosCont)(_, _, _, _0{}));
        Tensor tRrSin = make_tensor_like(cute::conditional_return(tRgSin, tRgSinCont)(_, _, _, _0{}));
        int const qhead_per_khead = qhead_per_khead_divmod.divisor;
        Tensor cR = cute::make_identity_tensor(Shape, Int>{});  // (BLK_N,BLK_K / 2)
        Tensor tRcR = gmem_thr_copy_ro.partition_D(cR);

        // The main bottleneck here is actually instruction cache misses.

        // Similar to PagedKVNonTMA, it's expensive to compute the pointers.
        // We split the work among threads loading the same row, then __shfl_sync the pointers.
        static constexpr int NumPtrPerThread = cute::ceil_div(CUTE_STATIC_V(cute::size<1>(tRrCos)), kGmemThreadsPerRow);
        Tensor tPrCosPtr = make_tensor(Shape>{});
        Tensor tPrSinPtr = make_tensor(Shape>{});
        #pragma unroll
        for (int i = 0; i < NumPtrPerThread; ++i) {
            int const row = i * NumThreads + get<0>(tRcR(_0{}, thread_idx % kGmemThreadsPerRow, _0{}));
            int const idx = block * kBlockMN + row;
            int row_actual = qhead_per_khead_divmod.divide(idx);
            tPrCosPtr[i] = &mCos(row_actual, _0{});
            tPrSinPtr[i] = &mSin(row_actual, _0{});
        }

        #pragma unroll
        for (int m = 0; m < (!FixedPosition ? size<1>(tRgCos) : 1); ++m) {
            int const idx = block * kBlockMN + get<0>(tRcR(_0{}, m, _0{}));
            Element const* cos_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrCosPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
            Element const* sin_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrSinPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow));
            if (idx < max_seqlen * qhead_per_khead) {
                Tensor mCos_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(cos_ptr), Shape>{}),
                                                      Shape>{});
                Tensor mSin_copy = cute::tiled_divide(make_tensor(make_gmem_ptr(sin_ptr), Shape>{}),
                                                      Shape>{});
                #pragma unroll
                for (int k = 0; k < size<2>(tRgCos); ++k) {
                    int const ki = get<1>(tRcR(_0{}, _0{}, k)) / (kGmemElemsPerLoadCur);
                    if (tRpRCur(k)) {
                        cute::copy(GmemTiledCopyRo{}, mCos_copy(_, ki), tRrCos(_, m, k));
                        cute::copy(GmemTiledCopyRo{}, mSin_copy(_, ki), tRrSin(_, m, k));
                    }
                }
            }
        }
        return cute::make_tuple(tRrCos, tRrSin);
    }

    template
    CUTLASS_DEVICE
    void
    apply_Q_interleaved(TensorsQ &sQ,  // (kBlockM, kHeadDim)
                        TensortRrR const &tRrCos,  // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary
                        TensortRrR const &tRrSin,  // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotary
                        int const m_block, int const qhead_per_khead=1)
    {
        TiledCopyQK tiled_copy_q;
        auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx);
        Tensor tQsQ =
            gmem_thr_copy_q.partition_S(sQ);
        Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{}));
        CUTE_STATIC_ASSERT_V(rank(tQsQ) == _3{});
        CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{});
        CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{});
        CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrCos));
        CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrCos));
        CUTE_STATIC_ASSERT_V(size<1>(tQsQ) == size<1>(tRrSin));
        CUTE_STATIC_ASSERT_V(size<2>(tQsQ) == size<2>(tRrSin));
        CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin));
        static_assert(decltype(size<0>(tQsQ))::value == decltype(size<0>(tRrCos))::value * 2);
        static_assert(decltype(size<0>(tRrCos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
        #pragma unroll
        for (int m = 0; m < size<1>(tQsQ); ++m) {
            if (get<0>(tQcQ(_0{}, m, _0{})) < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) {
                #pragma unroll
                for (int k = 0; k < size<2>(tQsQ); ++k) {
                    if (tRpR(k)) {
                        Tensor rQ = make_fragment_like(tQsQ(_, m, k));
                        cute::copy(tiled_copy_q, tQsQ(_, m, k), rQ);
                        apply_rotary_interleaved(rQ, tRrCos(_, m, k), tRrSin(_, m, k));
                        cute::copy(tiled_copy_q, rQ, tQsQ(_, m, k));
                    }
                }
            }
        }
    };

    template
    CUTLASS_DEVICE
    void
    apply_Q_contiguous(TensorsQ &sQ,  // (kBlockM, kHeadDim)
                       TensortRrR const &tRrCosCont,  // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont
                       TensortRrR const &tRrSinCont,  // (kBlockM, kHeadDim / 2) split according to GmemThrCopyRotaryCont
                       int const m_block, int const qhead_per_khead=1)
    {
        TiledCopyQK tiled_copy_q;
        auto gmem_thr_copy_q = tiled_copy_q.get_thread_slice(thread_idx);
        Tensor sQ_copy = cute::tiled_divide(sQ, Shape<_1, Int>{});
        Tensor tQcQ = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{}));
        CUTE_STATIC_ASSERT_V(rank(tQcQ) == _3{});
        CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{});
        CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{});
        CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrCosCont));
        CUTE_STATIC_ASSERT_V(size<2>(tQcQ) ==
            size<2>(tRrCosCont));
        CUTE_STATIC_ASSERT_V(size<1>(tQcQ) == size<1>(tRrSinCont));
        CUTE_STATIC_ASSERT_V(size<2>(tQcQ) == size<2>(tRrSinCont));
        CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont));
        CUTE_STATIC_ASSERT_V(size<0>(tQcQ) == size<0>(tRrCosCont));
        static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
        #pragma unroll
        for (int m = 0; m < size<1>(tQcQ); ++m) {
            int const row = get<0>(tQcQ(_0{}, m, _0{}));
            if (row < std::min(max_seqlen * qhead_per_khead - m_block * kBlockMN, kBlockMN)) {
                #pragma unroll
                for (int k = 0; k < size<2>(tQcQ); ++k) {
                    int const col = get<1>(tQcQ(_0{}, _0{}, k));
                    if (col < rotary_dim / 2) {
                        int const col_idx_left = col / kGmemElemsPerLoad;
                        int const col_idx_right = col / kGmemElemsPerLoad + rotary_dim / (2 * kGmemElemsPerLoad);
                        Tensor rQ_left = make_fragment_like(sQ_copy(_, row, col_idx_left));
                        cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_left), rQ_left);
                        Tensor rQ_right = make_fragment_like(rQ_left);
                        cute::copy(tiled_copy_q, sQ_copy(_, row, col_idx_right), rQ_right);
                        apply_rotary_contiguous(rQ_left, rQ_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k));
                        cute::copy(tiled_copy_q, rQ_left, sQ_copy(_, row, col_idx_left));
                        cute::copy(tiled_copy_q, rQ_right, sQ_copy(_, row, col_idx_right));
                    }
                }
            }
        }
    };

    template
    CUTLASS_DEVICE
    void
    apply_K_interleaved(TensorsK const &sK,  // (kBlockN, kHeadDim)
                        TensorgK &gK,  // (kBlockN, kHeadDim)
                        TensorpK const &tKpK,  // (kBlockN, kHeadDim) split according to ThrCopyKV
                        TensortRrR const &tRrCos,  // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary
                        TensortRrR const &tRrSin,  // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotary
                        TensorKPtr const &tPrKPtr,
                        int const n_block)
    {
        TiledCopyQK tiled_copy_k;
        auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx);
        Tensor tKsK = gmem_thr_copy_q.partition_S(sK);
        Tensor tKgK = gmem_thr_copy_q.partition_S(gK);
        Tensor tKcK =
            gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{}));
        CUTE_STATIC_ASSERT_V(rank(tKsK) == _3{});
        CUTE_STATIC_ASSERT_V(rank(tRrCos) == _3{});
        CUTE_STATIC_ASSERT_V(rank(tRrSin) == _3{});
        CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrCos));
        CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrCos));
        CUTE_STATIC_ASSERT_V(size<1>(tKsK) == size<1>(tRrSin));
        CUTE_STATIC_ASSERT_V(size<2>(tKsK) == size<2>(tRrSin));
        CUTE_STATIC_ASSERT_V(size<0>(tRrCos) == size<0>(tRrSin));
        static_assert(decltype(size<0>(tKsK))::value == decltype(size<0>(tRrCos))::value * 2);
        static_assert(decltype(size<0>(tRrCos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
        if constexpr (PagedKVNonTMA) {
            static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow));
        }

        #pragma unroll
        for (int m = 0; m < size<1>(tKsK); ++m) {
            int const row = get<0>(tKcK(_0{}, m, _0{}));
            auto mK_cur_copy = [&] {
                if constexpr (PagedKVNonTMA) {
                    Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));
                    Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{});
                    return cute::tiled_divide(mK_cur, Shape>{});
                } else {
                    return nullptr;
                }
            }();
            if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) {
                #pragma unroll
                for (int k = 0; k < size<2>(tKsK); ++k) {
                    if (tKpK(k)) {
                        Tensor rK = make_fragment_like(tKsK(_, m, k));
                        cute::copy(tiled_copy_k, tKsK(_, m, k), rK);
                        if (tRpR(k)) { apply_rotary_interleaved(rK, tRrCos(_, m, k), tRrSin(_, m, k)); }
                        if constexpr (!PagedKVNonTMA) {
                            cute::copy(tiled_copy_k, rK, tKgK(_, m, k));
                        } else {
                            int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad;
                            cute::copy(tiled_copy_k, rK, mK_cur_copy(_, ki));
                        }
                    }
                }
            }
        }
    };

    template
    CUTLASS_DEVICE
    void
    apply_K_contiguous(TensorsK const &sK,  // (kBlockN, kHeadDim)
                       TensorgK &gK,  // (kBlockN, kHeadDim)
                       TensorpK const &tKpK,  // (kBlockN, kHeadDim) split according to ThrCopyKV
                       TensortRrR const &tRrCosCont,  // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont
                       TensortRrR const &tRrSinCont,  // (kBlockN, kHeadDim/2) split according to GmemThrCopyRotaryCont
                       TensorKPtr const &tPrKPtr,
                       int const n_block, int const max_k)
    {
        TiledCopyQK tiled_copy_k;
        auto gmem_thr_copy_q = tiled_copy_k.get_thread_slice(thread_idx);
        Tensor sK_copy = cute::tiled_divide(sK, Shape<_1, Int>{});
        Tensor gK_copy = cute::tiled_divide(gK, Shape<_1, Int>{});
        Tensor tKcK = gmem_thr_copy_q.partition_S(cute::make_identity_tensor(Shape, Int>{}));
        CUTE_STATIC_ASSERT_V(rank(tKcK) == _3{});
        CUTE_STATIC_ASSERT_V(rank(tRrCosCont) == _3{});
        CUTE_STATIC_ASSERT_V(rank(tRrSinCont) == _3{});
        CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrCosCont));
        CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrCosCont));
        CUTE_STATIC_ASSERT_V(size<1>(tKcK) == size<1>(tRrSinCont));
        CUTE_STATIC_ASSERT_V(size<2>(tKcK) == size<2>(tRrSinCont));
        CUTE_STATIC_ASSERT_V(size<0>(tRrCosCont) == size<0>(tRrSinCont));
        CUTE_STATIC_ASSERT_V(size<0>(tKcK) == size<0>(tRrCosCont));
        static_assert(decltype(size<0>(tRrCosCont))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
        if constexpr (PagedKVNonTMA) {
            static_assert(decltype(size(tPrKPtr))::value == cute::ceil_div(size<1>(tKcK), kGmemThreadsPerRow));
        }

        const int ro_dim_vec = rotary_dim / kGmemElemsPerLoad;
        const int non_ro_dim_vec = (max_k - rotary_dim) / kGmemElemsPerLoad;
        #pragma unroll
        for (int m = 0; m < size<1>(tKcK); ++m) {
            int const row = get<0>(tKcK(_0{}, m, _0{}));
            Tensor gK_cur_copy = [&] {
                if constexpr (PagedKVNonTMA) {
                    Element* k_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrKPtr(m / kGmemThreadsPerRow)), (m % kGmemThreadsPerRow), kGmemThreadsPerRow));
                    Tensor mK_cur = make_tensor(make_gmem_ptr(k_ptr), Shape>{});
                    return cute::tiled_divide(mK_cur, Shape>{});
                } else {
                    return gK_copy(_, row, _);
                }
            }();
            if (row < std::min(max_seqlen - n_block * kBlockMN, kBlockMN)) {
                #pragma unroll
                for (int k = 0; k <
                    size<2>(tKcK); ++k) {
                    if (tKpK(k)) {
                        int const col = get<1>(tKcK(_0{}, _0{}, k));
                        bool rotate = col < rotary_dim / 2;
                        int const col_idx_left = rotate ? col / kGmemElemsPerLoad : (col + rotary_dim / 2) / kGmemElemsPerLoad;
                        int const col_idx_right = col_idx_left + (rotate ? ro_dim_vec / 2 : non_ro_dim_vec / 2);
                        Tensor rK_left = make_fragment_like(sK_copy(_, row, col_idx_left));
                        cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_left), rK_left);
                        Tensor rK_right = make_fragment_like(rK_left);
                        cute::copy(tiled_copy_k, sK_copy(_, row, col_idx_right), rK_right);
                        if (rotate) {
                            apply_rotary_contiguous(rK_left, rK_right, tRrCosCont(_, m, k), tRrSinCont(_, m, k));
                        }
                        cute::copy(tiled_copy_k, rK_left, gK_cur_copy(_, col_idx_left));
                        if (col_idx_right * kGmemElemsPerLoad < max_k) {
                            cute::copy(tiled_copy_k, rK_right, gK_cur_copy(_, col_idx_right));
                        }
                    }
                }
            }
        }
    };

};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash

================================================
FILE: hopper/seqlen.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

namespace flash {

// We consolidate all the info related to sequence length here. This is so that we can do all
// the gmem reads once at the beginning of each tile, rather than having to repeat these reads
// to compute various things like n_block_min, n_block_max, etc.

template
struct SeqlenInfo {

    int const offset, offset_padded;
    int const seqlen;

    CUTLASS_DEVICE
    SeqlenInfo(int const bidb, int const seqlen_static, int const* const cu_seqlens, int const* const seqused)
        : offset(!Varlen || cu_seqlens == nullptr ? 0 : cu_seqlens[bidb])
        , offset_padded(!Varlen || cu_seqlens == nullptr ?
            0 : (cu_seqlens[bidb] + bidb * kBlock) / kBlock * kBlock)
        , seqlen(!Varlen
                 ? seqlen_static
                 : (seqused ? seqused[bidb] : (cu_seqlens ? cu_seqlens[bidb + 1] - cu_seqlens[bidb] : seqlen_static)))
    {
    }

};

template
struct SeqlenInfoQK {

    int const offset_q, offset_k, offset_q_padded;
    int const seqlen_q, seqlen_k;

    CUTLASS_DEVICE
    SeqlenInfoQK(int const bidb, int const seqlen_q_static, int const seqlen_k_static,
                 int const* const cu_seqlens_q, int const* const cu_seqlens_k,
                 int const* const seqused_q, int const* const seqused_k
                 )
        : offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb])
        , offset_k(!Varlen || cu_seqlens_k == nullptr ? 0 : cu_seqlens_k[bidb])
        // If varlen, the layout for dPSum, LSE_log2, and dQaccum is that we pad each sequence in the batch
        // by an extra kBlockM, so that the write for each sequence doesn't touch the next sequence.
        // Sequence i starts at cu_seqlens[i] + i * kBlockM and ends at cu_seqlens[i + 1] + i * kBlockM
        // However, the start must align to multiples of kBlockM.
        , offset_q_padded(!Varlen || cu_seqlens_q == nullptr ? 0 : (cu_seqlens_q[bidb] + bidb * kBlockM) / kBlockM * kBlockM)
        , seqlen_q(!Varlen
                   ? seqlen_q_static
                   : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static)))
        , seqlen_k(!Varlen
                   ? seqlen_k_static
                   : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ?
                     cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)))
    {
    }

};

template
struct SeqlenInfoQKNewK {

    static_assert(!(AppendKV && !Varlen), "AppendKV is only supported with Varlen");

    int const leftpad_k;
    int const offset_q, offset_k, offset_k_new;
    int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k, seqlen_rotary;

    CUTLASS_DEVICE
    SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0,
                     int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,
                     int const* const seqused_q, int const* const seqused_k, int const* const ptr_leftpad_k,
                     int const* const seqlens_rotary
                     )
        : leftpad_k(ptr_leftpad_k ? ptr_leftpad_k[bidb] : 0)
        , offset_q(!Varlen || cu_seqlens_q == nullptr ? 0 : cu_seqlens_q[bidb])
        , offset_k(!Varlen ? 0 : (cu_seqlens_k ? cu_seqlens_k[bidb] : 0) + leftpad_k)
        , offset_k_new(!AppendKV || cu_seqlens_k_new == nullptr ? 0 : cu_seqlens_k_new[bidb])
        , seqlen_q(!Varlen
                   ? seqlen_q_static
                   : (seqused_q ? seqused_q[bidb] : (cu_seqlens_q ? cu_seqlens_q[bidb + 1] - cu_seqlens_q[bidb] : seqlen_q_static)))
        , seqlen_k_og(!Varlen
                      ? seqlen_k_static
                      : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)) - leftpad_k)
        , seqlen_k_new(!AppendKV ? 0 : (cu_seqlens_k_new ? cu_seqlens_k_new[bidb + 1] - cu_seqlens_k_new[bidb] : shape_K_new_0))
        , seqlen_k(!AppendKV ? seqlen_k_og : seqlen_k_og + seqlen_k_new)
        , seqlen_rotary(!AppendKV || !seqlens_rotary ? seqlen_k_og + leftpad_k : seqlens_rotary[bidb])
    {
    }

};

} // namespace flash

================================================
FILE: hopper/setup.py
================================================
# Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.

import sys
import warnings
import os
import stat
import re
import shutil
import ast
from pathlib import Path
from packaging.version import parse, Version
import platform
import sysconfig
import tarfile
import itertools

from setuptools import setup, find_packages
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME


# with open("../README.md", "r", encoding="utf-8") as fh:
with open("../README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

PACKAGE_NAME = "flash_attn_3"

BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"

# ROCm specific settings
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
    SKIP_CUDA_BUILD = True

DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE"
DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE"
DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE"
DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE"
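
# Illustrative sketch only. The helper below is hypothetical and is not part of
# the upstream setup.py. Every FLASH_ATTENTION_* toggle above is enabled only
# when the variable equals the exact, case-sensitive string "TRUE", so values
# such as "1" or "true" leave the toggle off. By contrast, check_env_flag (used
# for FLASH_ATTENTION_OFFLINE_BUILD) accepts ON/1/YES/TRUE/Y in any case.
def _sketch_is_strict_true(name, environ=None):
    # Mirrors the `os.getenv(name, "FALSE") == "TRUE"` pattern used above;
    # `environ` lets the check run against a plain dict instead of os.environ.
    env = os.environ if environ is None else environ
    return env.get(name, "FALSE") == "TRUE"
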
DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE"
DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE"
DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE"
DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE"
DISABLE_VARLEN = os.getenv("FLASH_ATTENTION_DISABLE_VARLEN", "FALSE") == "TRUE"
DISABLE_CLUSTER = os.getenv("FLASH_ATTENTION_DISABLE_CLUSTER", "FALSE") == "TRUE"
DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE"
DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE"
DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE"
DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE"
DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE"
DISABLE_SM8x = os.getenv("FLASH_ATTENTION_DISABLE_SM80", "FALSE") == "TRUE"

ENABLE_VCOLMAJOR = os.getenv("FLASH_ATTENTION_ENABLE_VCOLMAJOR", "FALSE") == "TRUE"

DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE"
DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE"

# HACK: we monkey patch pytorch's _write_ninja_file to pass
# "-gencode arch=compute_90a,code=sm_90a" to files ending in '_sm90.cu',
# "-gencode arch=compute_80,code=sm_80" to files ending in '_sm80.cu',
# "-gencode arch=compute_100a,code=sm_100a" to files ending in '_sm100.cu',
# and both the sm_90a and sm_80 targets to all other .cu files.
from torch.utils.cpp_extension import (
    IS_HIP_EXTENSION,
    COMMON_HIP_FLAGS,
    SUBPROCESS_DECODE_ARGS,
    IS_WINDOWS,
    get_cxx_compiler,
    _join_rocm_home,
    _join_cuda_home,
    _is_cuda_file,
    _maybe_write,
)

def create_build_config_file():
    CONFIG = {
        "build_flags": {
            "FLASHATTENTION_DISABLE_BACKWARD": DISABLE_BACKWARD,
            "FLASHATTENTION_DISABLE_SPLIT": DISABLE_SPLIT,
            "FLASHATTENTION_DISABLE_PAGEDKV": DISABLE_PAGEDKV,
            "FLASHATTENTION_DISABLE_APPENDKV": DISABLE_APPENDKV,
            "FLASHATTENTION_DISABLE_LOCAL": DISABLE_LOCAL,
            "FLASHATTENTION_DISABLE_SOFTCAP": DISABLE_SOFTCAP,
            "FLASHATTENTION_DISABLE_PACKGQA": DISABLE_PACKGQA,
            "FLASHATTENTION_DISABLE_FP16": DISABLE_FP16,
            "FLASHATTENTION_DISABLE_FP8": DISABLE_FP8,
            "FLASHATTENTION_DISABLE_VARLEN": DISABLE_VARLEN,
            "FLASHATTENTION_DISABLE_CLUSTER": DISABLE_CLUSTER,
            "FLASHATTENTION_DISABLE_HDIM64": DISABLE_HDIM64,
            "FLASHATTENTION_DISABLE_HDIM96": DISABLE_HDIM96,
            "FLASHATTENTION_DISABLE_HDIM128": DISABLE_HDIM128,
            "FLASHATTENTION_DISABLE_HDIM192": DISABLE_HDIM192,
            "FLASHATTENTION_DISABLE_HDIM256": DISABLE_HDIM256,
            "FLASHATTENTION_DISABLE_SM8x": DISABLE_SM8x,
            "FLASHATTENTION_ENABLE_VCOLMAJOR": ENABLE_VCOLMAJOR,
            "FLASH_ATTENTION_DISABLE_HDIMDIFF64": DISABLE_HDIMDIFF64,
            "FLASH_ATTENTION_DISABLE_HDIMDIFF192": DISABLE_HDIMDIFF192,
        }
    }

    with open("flash_attn_config.py", "w") as f:
        f.write("# Auto-generated by flash attention 3 setup.py\n")
        f.write(f"CONFIG = {repr(CONFIG)}\n")
        f.write("\n")
        f.write("def show():\n")
        f.write("    from pprint import pprint\n")
        f.write("    pprint(CONFIG)\n")
        f.write("\n")

def _write_ninja_file(path,
                      cflags,
                      post_cflags,
                      cuda_cflags,
                      cuda_post_cflags,
                      cuda_dlink_post_cflags,
                      sources,
                      objects,
                      ldflags,
                      library_target,
                      with_cuda,
                      **kwargs,  # kwargs (ignored) to absorb new flags in torch.utils.cpp_extension
                      ) -> None:
    r"""Write a ninja file that does the desired compiling and linking.

    `path`: Where to write this file
    `cflags`: list of flags to pass to $cxx. Can be None.
    `post_cflags`: list of flags to append to the $cxx invocation. Can be None.
    `cuda_cflags`: list of flags to pass to $nvcc. Can be None.
    `cuda_post_cflags`: list of flags to append to the $nvcc invocation. Can be None.
    `cuda_dlink_post_cflags`: list of flags to append to the $nvcc device-link invocation. Can be None.
    `sources`: list of paths to source files
    `objects`: list of desired paths to objects, one per source.
    `ldflags`: list of flags to pass to linker. Can be None.
    `library_target`: Name of the output library. Can be None; in that case,
                      we do no linking.
    `with_cuda`: If we should be compiling with CUDA.
    """
    def sanitize_flags(flags):
        if flags is None:
            return []
        else:
            return [flag.strip() for flag in flags]

    cflags = sanitize_flags(cflags)
    post_cflags = sanitize_flags(post_cflags)
    cuda_cflags = sanitize_flags(cuda_cflags)
    cuda_post_cflags = sanitize_flags(cuda_post_cflags)
    cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags)
    ldflags = sanitize_flags(ldflags)

    # Sanity checks...
    assert len(sources) == len(objects)
    assert len(sources) > 0

    compiler = get_cxx_compiler()

    # Version 1.3 is required for the `deps` directive.
    config = ['ninja_required_version = 1.3']
    config.append(f'cxx = {compiler}')
    if with_cuda or cuda_dlink_post_cflags:
        if IS_HIP_EXTENSION:
            nvcc = _join_rocm_home('bin', 'hipcc')
        else:
            nvcc = _join_cuda_home('bin', 'nvcc')
        if "PYTORCH_NVCC" in os.environ:
            nvcc_from_env = os.getenv("PYTORCH_NVCC")  # user can set nvcc compiler with ccache using the environment variable here
        else:
            nvcc_from_env = nvcc
        config.append(f'nvcc_from_env = {nvcc_from_env}')
        config.append(f'nvcc = {nvcc}')

    if IS_HIP_EXTENSION:
        post_cflags = COMMON_HIP_FLAGS + post_cflags
    flags = [f'cflags = {" ".join(cflags)}']
    flags.append(f'post_cflags = {" ".join(post_cflags)}')
    if with_cuda:
        flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
        flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
        cuda_post_cflags_sm80 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_80,code=sm_80' for s in cuda_post_cflags]
        flags.append(f'cuda_post_cflags_sm80 = {" ".join(cuda_post_cflags_sm80)}')
        cuda_post_cflags_sm80_sm90 = cuda_post_cflags + ['-gencode', 'arch=compute_80,code=sm_80']
        flags.append(f'cuda_post_cflags_sm80_sm90 = {" ".join(cuda_post_cflags_sm80_sm90)}')
        cuda_post_cflags_sm100 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_100a,code=sm_100a' for s in cuda_post_cflags]
        flags.append(f'cuda_post_cflags_sm100 = {" ".join(cuda_post_cflags_sm100)}')
    flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}')
    flags.append(f'ldflags = {" ".join(ldflags)}')

    # Turn into absolute paths so we can emit them into the ninja build
    # file wherever it is.
    sources = [os.path.abspath(file) for file in sources]

    # See https://ninja-build.org/build.ninja.html for reference.
    compile_rule = ['rule compile']
    if IS_WINDOWS:
        compile_rule.append(
            '  command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags')
        compile_rule.append('  deps = msvc')
    else:
        compile_rule.append(
            '  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags')
        compile_rule.append('  depfile = $out.d')
        compile_rule.append('  deps = gcc')

    if with_cuda:
        cuda_compile_rule = ['rule cuda_compile']
        nvcc_gendeps = ''
        # --generate-dependencies-with-compile is not supported by ROCm
        # Nvcc flag `--generate-dependencies-with-compile` is not supported by sccache, which may increase build time.
        if torch.version.cuda is not None and os.getenv('TORCH_EXTENSION_SKIP_NVCC_GEN_DEPENDENCIES', '0') != '1':
            cuda_compile_rule.append('  depfile = $out.d')
            cuda_compile_rule.append('  deps = gcc')
            # Note: non-system deps with nvcc are only supported
            # on Linux so use --generate-dependencies-with-compile
            # to make this work on Windows too.
            nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d'
        cuda_compile_rule_sm80 = ['rule cuda_compile_sm80'] + cuda_compile_rule[1:] + [
            f'  command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80'
        ]
        cuda_compile_rule_sm80_sm90 = ['rule cuda_compile_sm80_sm90'] + cuda_compile_rule[1:] + [
            f'  command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90'
        ]
        cuda_compile_rule_sm100 = ['rule cuda_compile_sm100'] + cuda_compile_rule[1:] + [
            f'  command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm100'
        ]
        cuda_compile_rule.append(
            f'  command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags')

    # Emit one build rule per source to enable incremental build.
    build = []
    for source_file, object_file in zip(sources, objects):
        is_cuda_source = _is_cuda_file(source_file) and with_cuda
        if is_cuda_source:
            if source_file.endswith('_sm90.cu'):
                rule = 'cuda_compile'
            elif source_file.endswith('_sm80.cu'):
                rule = 'cuda_compile_sm80'
            elif source_file.endswith('_sm100.cu'):
                rule = 'cuda_compile_sm100'
            else:
                rule = 'cuda_compile_sm80_sm90'
        else:
            rule = 'compile'
        if IS_WINDOWS:
            source_file = source_file.replace(':', '$:')
            object_file = object_file.replace(':', '$:')
        source_file = source_file.replace(" ", "$ ")
        object_file = object_file.replace(" ", "$ ")
        build.append(f'build {object_file}: {rule} {source_file}')

    if cuda_dlink_post_cflags:
        devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o')
        devlink_rule = ['rule cuda_devlink']
        devlink_rule.append('  command = $nvcc $in -o $out $cuda_dlink_post_cflags')
        devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}']
        objects += [devlink_out]
    else:
        devlink_rule, devlink = [], []

    if library_target is not None:
        link_rule = ['rule link']
        if IS_WINDOWS:
            cl_paths = subprocess.check_output(['where',
                                                'cl']).decode(*SUBPROCESS_DECODE_ARGS).split('\r\n')
            if len(cl_paths) >= 1:
                cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:')
            else:
                raise RuntimeError("MSVC is required to load C++ extensions")
            link_rule.append(f'  command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out')
        else:
            link_rule.append('  command = $cxx $in $ldflags -o $out')

        link = [f'build {library_target}: link {" ".join(objects)}']

        default = [f'default {library_target}']
    else:
        link_rule, link, default = [], [], []

    # 'Blocks' should be separated by newlines, for visual benefit.
    blocks = [config, flags, compile_rule]
    if with_cuda:
        blocks.append(cuda_compile_rule)  # type: ignore[possibly-undefined]
        blocks.append(cuda_compile_rule_sm80)  # type: ignore[possibly-undefined]
        blocks.append(cuda_compile_rule_sm80_sm90)  # type: ignore[possibly-undefined]
        blocks.append(cuda_compile_rule_sm100)  # type: ignore[possibly-undefined]
    blocks += [devlink_rule, link_rule, build, devlink, link, default]
    content = "\n\n".join("\n".join(b) for b in blocks)
    # Ninja requires a newline at the end of the .ninja file
    content += "\n"
    _maybe_write(path, content)


# Monkey patching
torch.utils.cpp_extension._write_ninja_file = _write_ninja_file


def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
    if sys.platform.startswith("linux"):
        return "linux_x86_64"
    elif sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    elif sys.platform == "win32":
        return "win_amd64"
    else:
        raise ValueError("Unsupported platform: {}".format(sys.platform))


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
def check_env_flag(name: str, default: str = "") -> bool:
    return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]


# Copied from https://github.com/triton-lang/triton/blob/main/python/setup.py
def is_offline_build() -> bool:
    """
    Downstream projects and distributions which bootstrap their own dependencies from scratch
    and run builds in offline sandboxes
    may set `FLASH_ATTENTION_OFFLINE_BUILD` in the build environment to prevent any attempts at downloading
    pinned dependencies from the internet or at using dependencies vendored in-tree.

    Dependencies must be defined using respective search paths (cf. `syspath_var_name` in `Package`).
    Missing dependencies lead to an early abort.
    Dependencies' compatibility is not verified.

    Note that this flag isn't tested by the CI and does not provide any guarantees.
    """
    return check_env_flag("FLASH_ATTENTION_OFFLINE_BUILD", "")


# Copied from https://github.com/triton-lang/triton/blob/main/python/setup.py
def get_flashattn_cache_path():
    user_home = os.getenv("FLASH_ATTENTION_HOME")
    if not user_home:
        user_home = os.getenv("HOME") or os.getenv("USERPROFILE") or os.getenv("HOMEPATH") or None
    if not user_home:
        raise RuntimeError("Could not find user home directory")
    return os.path.join(user_home, ".flashattn")


def open_url(url):
    user_agent = 'Mozilla/5.0 (X11; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/119.0'
    headers = {
        'User-Agent': user_agent,
    }
    request = urllib.request.Request(url, None, headers)
    # Set timeout to 300 seconds to prevent the request from hanging forever.
    return urllib.request.urlopen(request, timeout=300)


def download_and_copy(name, src_func, dst_path, version, url_func):
    if is_offline_build():
        return
    flashattn_cache_path = get_flashattn_cache_path()
    base_dir = os.path.dirname(__file__)
    system = platform.system()
    arch = platform.machine()
    arch = {"arm64": "aarch64"}.get(arch, arch)
    supported = {"Linux": "linux", "Darwin": "linux"}
    url = url_func(supported[system], arch, version)
    src_path = src_func(supported[system], arch, version)
    tmp_path = os.path.join(flashattn_cache_path, "nvidia", name)  # path to cache the download
    dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path)  # final binary path
    src_path = os.path.join(tmp_path, src_path)
    download = not os.path.exists(src_path)
    if download:
        print(f'downloading and extracting {url} ...')
        file = tarfile.open(fileobj=open_url(url), mode="r|*")
        file.extractall(path=tmp_path)
    os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
    print(f'copy {src_path} to {dst_path} ...')
    if os.path.isdir(src_path):
        shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
    else:
        shutil.copy(src_path, dst_path)


def nvcc_threads_args():
    nvcc_threads = os.getenv("NVCC_THREADS") or "2"
    return ["--threads", nvcc_threads]


# NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"}
NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.93"}

exe_extension = sysconfig.get_config_var("EXE")


cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
if not USE_TRITON_ROCM:
    subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"])

if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    create_build_config_file()
    check_if_cuda_home_none(PACKAGE_NAME)
    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
    if bare_metal_version < Version("12.3"):
        raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above")
    elif bare_metal_version >= Version("13.0"):
        # CUDA 13.0+ uses system nvcc and CCCL headers are in /usr/local/cuda/include/cccl/
        cccl_include = os.path.join(CUDA_HOME, "include", "cccl")
        for env_var in ["CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"]:
            current = os.environ.get(env_var, "")
            os.environ[env_var] = cccl_include + (":" + current if current else "")

    # ptxas 12.8 gives the best perf currently
    # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8
    # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have.
    # For CUDA 13.0+, use system nvcc instead of downloading CUDA 12.x toolchain
    if bare_metal_version >= Version("12.3") and bare_metal_version < Version("13.0") and bare_metal_version != Version("12.8"):
        download_and_copy(
            name="nvcc",
            src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin",
            dst_path="bin",
            version=NVIDIA_TOOLCHAIN_VERSION["nvcc"],
            url_func=lambda system, arch, version: f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz",
        )
        download_and_copy(
            name="ptxas",
            src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas",
            dst_path="bin",
            version=NVIDIA_TOOLCHAIN_VERSION["ptxas"],
            url_func=lambda system, arch, version: f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz",
        )
        download_and_copy(
            name="ptxas",
            src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin",
            dst_path="nvvm/bin",
            version=NVIDIA_TOOLCHAIN_VERSION["ptxas"],
            url_func=lambda system, arch, version: f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz",
        )
        base_dir = os.path.dirname(__file__)
        ctk_path_new = os.path.abspath(os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin"))
        nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}")
        # Need to prepend to PATH, otherwise nvcc can't find cicc in nvvm/bin/cicc
        # nvcc 12.8 seems to hard-code looking for cicc in ../nvvm/bin/cicc
        os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"]
        os.environ["PYTORCH_NVCC"] = nvcc_path_new
        # Make nvcc executable, sometimes after the copy it loses its permissions
        os.chmod(nvcc_path_new, os.stat(nvcc_path_new).st_mode | stat.S_IEXEC)

    cc_flag = []
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_90a,code=sm_90a")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True
    repo_dir = Path(this_dir).parent
    cutlass_dir = repo_dir / "csrc" / "cutlass"

    feature_args = (
        []
        + (["-DFLASHATTENTION_DISABLE_BACKWARD"] if DISABLE_BACKWARD else [])
        + (["-DFLASHATTENTION_DISABLE_PAGEDKV"] if DISABLE_PAGEDKV else [])
        + (["-DFLASHATTENTION_DISABLE_SPLIT"] if DISABLE_SPLIT else [])
        + (["-DFLASHATTENTION_DISABLE_APPENDKV"] if DISABLE_APPENDKV else [])
        + (["-DFLASHATTENTION_DISABLE_LOCAL"] if DISABLE_LOCAL else [])
        + (["-DFLASHATTENTION_DISABLE_SOFTCAP"] if DISABLE_SOFTCAP else [])
        + (["-DFLASHATTENTION_DISABLE_PACKGQA"] if DISABLE_PACKGQA else [])
        + (["-DFLASHATTENTION_DISABLE_FP16"] if DISABLE_FP16 else [])
        + (["-DFLASHATTENTION_DISABLE_FP8"] if DISABLE_FP8 else [])
        + (["-DFLASHATTENTION_DISABLE_VARLEN"] if DISABLE_VARLEN else [])
        + (["-DFLASHATTENTION_DISABLE_CLUSTER"] if DISABLE_CLUSTER else [])
        + (["-DFLASHATTENTION_DISABLE_HDIM64"] if DISABLE_HDIM64 else [])
        + (["-DFLASHATTENTION_DISABLE_HDIM96"] if DISABLE_HDIM96 else [])
        + (["-DFLASHATTENTION_DISABLE_HDIM128"] if DISABLE_HDIM128 else [])
        + (["-DFLASHATTENTION_DISABLE_HDIM192"] if DISABLE_HDIM192 else [])
        + (["-DFLASHATTENTION_DISABLE_HDIM256"] if DISABLE_HDIM256 else [])
        + (["-DFLASHATTENTION_DISABLE_SM8x"] if DISABLE_SM8x else [])
        + (["-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else [])
        + (["-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else [])
        + (["-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else [])
    )

    DTYPE_FWD_SM80 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else [])
    DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + (["e4m3"] if not DISABLE_FP8 else [])
    HALF_DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else [])
    DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else [])
    HEAD_DIMENSIONS_BWD = (
        []
        + ([64] if not DISABLE_HDIM64 else [])
        + ([96] if not DISABLE_HDIM96 else [])
        + ([128] if not DISABLE_HDIM128 else [])
        + ([192] if not DISABLE_HDIM192 else [])
        + ([256] if not DISABLE_HDIM256 else [])
    )
    # build will now explode with this compilation grouping given all our templating
    # HEAD_DIMENSIONS_FWD = ["all", "diff"]
    HEAD_DIMENSIONS_FWD = HEAD_DIMENSIONS_BWD
    HEAD_DIMENSIONS_DIFF64_FWD = (
        []
        + (["64_256"] if not DISABLE_HDIMDIFF64 else [])
        + (["64_512"] if not DISABLE_HDIMDIFF64 else [])
    )
    HEAD_DIMENSIONS_DIFF192_FWD = (
        []
        + (["192_128"] if not DISABLE_HDIMDIFF192 else [])
    )
    HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD
    SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else [])
    PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else [])
    SOFTCAP = [""] + (["_softcap"] if not DISABLE_SOFTCAP else [])
    SOFTCAP_ALL = [""] if DISABLE_SOFTCAP else ["_softcapall"]
    PACKGQA = [""] + (["_packgqa"] if not DISABLE_PACKGQA else [])
    # We already always hard-code PackGQA=true for Sm8x
    sources_fwd_sm80 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}_sm80.cu"
                        for hdim, dtype, split, paged, softcap in itertools.product(HEAD_DIMENSIONS_FWD_SM80, DTYPE_FWD_SM80, SPLIT, PAGEDKV, SOFTCAP_ALL)]
    # We already always hard-code PackGQA=true for Sm9x if PagedKV or Split
    sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu"
                        for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA)
                        if not (packgqa and (paged or split))]
    if not DISABLE_HDIMDIFF64:
        sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu"
                             for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF64_FWD, HALF_DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA)
                             if not (packgqa and (paged or split))]
    if not DISABLE_HDIMDIFF192:
        sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu"
                             for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF192_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA)
                             if not (packgqa and (paged or split))]
    sources_bwd_sm80 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm80.cu"
                        for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP)]
    sources_bwd_sm90 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm90.cu"
                        for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP_ALL)]
    if DISABLE_BACKWARD:
        sources_bwd_sm90 = []
        sources_bwd_sm80 = []

    # Choose between flash_api.cpp and flash_api_stable.cpp based on torch version
    torch_version = parse(torch.__version__)
    target_version = parse("2.9.0.dev20250830")
    stable_args = []

    if torch_version >= target_version:
        flash_api_source = "flash_api_stable.cpp"
        stable_args = ["-DTORCH_TARGET_VERSION=0x0209000000000000"]  # Targets minimum runtime version torch 2.9.0
    else:
        flash_api_source = "flash_api.cpp"

    sources = (
        [flash_api_source]
        + (sources_fwd_sm80 if not DISABLE_SM8x else []) + sources_fwd_sm90
        + (sources_bwd_sm80 if not DISABLE_SM8x else []) + sources_bwd_sm90
    )
    if not DISABLE_SPLIT:
        sources += ["flash_fwd_combine.cu"]
    sources += ["flash_prepare_scheduler.cu"]
    nvcc_flags = [
        "-O3",
        "-std=c++17",
        "--ftemplate-backtrace-limit=0",  # To debug template code
        "--use_fast_math",
        # "--keep",
        # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage",  # printing out number of registers
        "--resource-usage",  # printing out number of registers
        # f"--split-compile={os.getenv('NVCC_THREADS', '4')}",  # split-compile is faster
        "-lineinfo",  # TODO: disable this for release to reduce binary size
        "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",  # Necessary for the WGMMA shapes that we use
        "-DCUTLASS_ENABLE_GDC_FOR_SM90",  # For PDL
        "-DCUTLASS_DEBUG_TRACE_LEVEL=0",  # Can toggle for debugging
        "-DNDEBUG",  # Important, otherwise performance is severely impacted
    ]
    if get_platform() == "win_amd64":
        nvcc_flags.extend(
            [
                "-D_USE_MATH_DEFINES",  # for M_LN2
                "-Xcompiler=/Zc:__cplusplus",  # sets __cplusplus correctly, CUTLASS_CONSTEXPR_IF_CXX17 needed for cutlass::gcd
            ]
        )
    include_dirs = [
        Path(this_dir),
        cutlass_dir / "include",
    ]

    ext_modules.append(
        CUDAExtension(
            name=f"{PACKAGE_NAME}._C",
            sources=sources,
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"] + stable_args + feature_args,
                "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args,
            },
            include_dirs=include_dirs,
            py_limited_api=True,
        )
    )


def get_package_version():
    with open(Path(this_dir) / "__init__.py", "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    else:
        return str(public_version)


def get_wheel_url():
    # Determine the version numbers that will be used to determine the correct wheel
    # We're using the CUDA version used to build torch, not the one currently installed
    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
    torch_cuda_version = parse(torch.version.cuda)
    torch_version_raw = parse(torch.__version__)
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
    # to save CI time. Minor versions should be compatible.
    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    package_version = get_package_version()
    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = f"{PACKAGE_NAME}-{package_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{package_version}", wheel_name=wheel_filename)
    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuit the standard full build pipeline.
    """

    def run(self):
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            shutil.move(wheel_filename, wheel_path)
        except urllib.error.HTTPError:
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()


setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "benchmarks",
        )
    ),
    py_modules=["flash_attn_interface", "flash_attn_config"],
    description="FlashAttention-3",
    long_description=long_description,
    long_description_content_type="text/markdown",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
    if ext_modules
    else {
        "bdist_wheel": CachedWheelsCommand,
    },
    python_requires=">=3.8",
    install_requires=[
        "torch",
        "einops",
        "packaging",
        "ninja",
    ],
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)

================================================
FILE: hopper/sm90_pipeline_no_cluster.hpp
================================================
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/ #pragma once #include namespace cutlass { using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// // As of Cutlass v3.6.0, if size(ClusterShape) == 1, PipelineTmaAsync has all threads // signaling the barrier during consumer_release. This causes a perf regression in FA3 // forward pass (especially hdim 128 causal). We instead reimplement the version of // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. // // Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 template > class PipelineTmaAsyncNoCluster: public Base { public: using FullBarrier = typename Base::FullBarrier; using EmptyBarrier = typename Base::EmptyBarrier; static constexpr uint32_t Stages = Stages_; using PipelineState = typename Base::PipelineState; using SharedStorage = typename Base::SharedStorage; using ThreadCategory = typename Base::ThreadCategory; using Params = typename Base::Params; static CUTLASS_DEVICE void init_barriers(SharedStorage& storage, Params params) { int warp_idx = canonical_warp_idx_sync(); bool is_initializing_warp = (warp_idx == 0); if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; uint32_t const num_consumer_warpgroups_per_cluster = (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); } cutlass::arch::fence_barrier_init(); } template CUTLASS_DEVICE PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) : Base(storage, params, make_shape(_1{}, _1{}, _1{}) /*cluster_shape*/, cute::false_type{} 
/*init_barriers*/, cute::false_type{} /*init_masks*/) , empty_barrier_ptr_(&storage.empty_barrier_[0]) { int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); static_assert(cute::is_same_v || cute::is_same_v); static_assert(cute::is_same_v || cute::is_same_v); if constexpr (cute::is_same_v) { init_barriers(storage, params); } } // Constructor template CUTLASS_DEVICE PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape) : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, cute::true_type{}, cute::true_type{}) { } template CUTLASS_DEVICE PipelineTmaAsyncNoCluster(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}) : PipelineTmaAsyncNoCluster(storage, params, cluster_shape, InitBarriers{}, cute::true_type{}) { } CUTLASS_DEVICE void consumer_release(PipelineState state) { consumer_release(state.index()); } private: EmptyBarrier* const empty_barrier_ptr_ = nullptr; // Consumer signalling Producer of completion // Ensures all blocks in the same row and column get notified. CUTLASS_DEVICE void consumer_release(uint32_t stage, uint32_t skip = false) { empty_barrier_ptr_[stage].arrive(0 /*dst_blockid_*/, uint32_t(threadIdx.x % cutlass::NumThreadsPerWarpGroup == 0) & (!skip) /*is_signaling_thread*/); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // end namespace cutlass ================================================ FILE: hopper/softmax.h ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/ #pragma once #include #include #include #include "utils.h" namespace flash { using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// template __device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); #pragma unroll for (int ni = 0; ni < size<1>(tensor); ni++) { #pragma unroll for (int mi = 0; mi < size<0>(tensor); mi++) { summary(mi) = zero_init && ni == 0 ? tensor(mi, ni) : op(summary(mi), tensor(mi, ni)); } } } template __device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { CUTE_STATIC_ASSERT_V(size(dst) == size(src)); #pragma unroll for (int i = 0; i < size(dst); i++) { dst(i) = Allreduce<4>::run(src(i), op); } } template __device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { thread_reduce_(tensor, summary, op); quad_allreduce_(summary, summary, op); } template __device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ MaxOp max_op; reduce_(tensor, max, max_op); } template __device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ SumOp sum_op; thread_reduce_(tensor, sum, sum_op); if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); } } // Apply the exp to all the elements. template __forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { // For FP8, we can subtract max by 8.0 so that the value after exp2 is in the range of [0, 256]. // This lets us use more of the FP8 range (instead of just [0, 1]) to reduce underflow. 
static constexpr float max_offset = float(Max_offset); // We can only template on int, not float static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. const float max_scaled = Check_inf ? (max(mi) == -INFINITY ? 0.f : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset) : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset; #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - // max * log_2(e)). This allows the compiler to use the ffma // instruction instead of fadd and fmul separately. tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template struct Softmax { using TensorT = decltype(make_tensor(Shape>{})); TensorT row_max, row_sum; float const softmax_scale_log2; CUTLASS_DEVICE Softmax(float const softmax_scale_log2_) : softmax_scale_log2(softmax_scale_log2_) {}; template __forceinline__ __device__ TensorT max_get_scale(Tensor0 &acc_s) { // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows); TensorT scores_scale; if constexpr (Is_first) { flash::template reduce_max(scores, row_max); cute::fill(scores_scale, 1.f); } else { Tensor scores_max_prev = make_fragment_like(row_max); cute::copy(row_max, scores_max_prev); flash::template reduce_max(scores, row_max); #pragma unroll for (int mi = 0; mi < size(row_max); ++mi) { float scores_max_cur = !Check_inf ? 
row_max(mi) : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); row_sum(mi) *= scores_scale(mi); } } return scores_scale; }; template __forceinline__ __device__ void online_softmax(Tensor0 &acc_s) { // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); static_assert(CUTE_STATIC_V(size<0>(scores)) == kNRows); flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2); // We don't do the reduce across threads here since we don't need to use the row_sum. // We do that reduce at the end when we need to normalize the softmax. flash::reduce_sum(scores, row_sum); }; __forceinline__ __device__ TensorT finalize(float const final_scale=1.f) { SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT scores_scale; #pragma unroll for (int mi = 0; mi < size(row_sum); ++mi) { float sum = row_sum(mi); float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum; scores_scale(mi) = inv_sum * final_scale; // For FP8, we might have scaled the output of exp by 2**8 so we need to divide sum by that amount. if constexpr (Max_offset != 0) { static constexpr float sum_scale = 1.f / float(1 << Max_offset); sum *= sum_scale; } row_sum(mi) = (sum == 0.f || sum != sum) ? 
-INFINITY : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); } return scores_scale; }; template __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) { // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); static_assert(CUTE_STATIC_V(size<0>(acc_o_rowcol)) == kNRows); #pragma unroll for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); } } }; }; } // namespace flash ================================================ FILE: hopper/static_switch.h ================================================ // Inspired by // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h #pragma once /// @param COND - a boolean expression to switch by /// @param CONST_NAME - a name given for the constexpr bool variable. /// @param ... - code to execute for true and false /// /// Usage: /// ``` /// BOOL_SWITCH(flag, BoolConst, [&] { /// some_function(...); /// }); /// ``` // #define BOOL_SWITCH(COND, CONST_NAME, ...) \ [&] { \ if (COND) { \ constexpr static bool CONST_NAME = true; \ return __VA_ARGS__(); \ } else { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ } \ }() #ifdef FLASHATTENTION_DISABLE_LOCAL #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \ [&] { \ constexpr static bool LOCAL_CONST_NAME = false; \ if (CAUSAL_COND) { \ constexpr static bool CAUSAL_CONST_NAME = true; \ return __VA_ARGS__(); \ } else { \ constexpr static bool CAUSAL_CONST_NAME = false; \ return __VA_ARGS__(); \ } \ }() #else #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) 
\ [&] { \ if (CAUSAL_COND) { \ constexpr static bool CAUSAL_CONST_NAME = true; \ constexpr static bool LOCAL_CONST_NAME = false; \ return __VA_ARGS__(); \ } else if (LOCAL_COND) { \ constexpr static bool CAUSAL_CONST_NAME = false; \ constexpr static bool LOCAL_CONST_NAME = true; \ return __VA_ARGS__(); \ } else { \ constexpr static bool CAUSAL_CONST_NAME = false; \ constexpr static bool LOCAL_CONST_NAME = false; \ return __VA_ARGS__(); \ } \ }() #endif #ifdef FLASHATTENTION_DISABLE_SOFTCAP #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ [&] { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ }() #else #define SOFTCAP_SWITCH BOOL_SWITCH #endif #ifdef FLASHATTENTION_DISABLE_PAGEDKV #define PAGEDKV_SWITCH(COND, CONST_NAME, ...) \ [&] { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ }() #else #define PAGEDKV_SWITCH BOOL_SWITCH #endif #ifdef FLASHATTENTION_DISABLE_SPLIT #define SPLIT_SWITCH(COND, CONST_NAME, ...) \ [&] { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ }() #else #define SPLIT_SWITCH BOOL_SWITCH #endif #ifdef FLASHATTENTION_DISABLE_APPENDKV #define APPENDKV_SWITCH(COND, CONST_NAME, ...) \ [&] { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ }() #else #define APPENDKV_SWITCH BOOL_SWITCH #endif #ifdef FLASHATTENTION_DISABLE_PACKGQA #define PACKGQA_SWITCH(COND, CONST_NAME, ...) \ [&] { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ }() #else #define PACKGQA_SWITCH BOOL_SWITCH #endif #ifdef FLASHATTENTION_DISABLE_VARLEN #define VARLEN_SWITCH(COND, CONST_NAME, ...) \ [&] { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ }() #else #define VARLEN_SWITCH BOOL_SWITCH #endif #ifdef FLASHATTENTION_DISABLE_CLUSTER #define CLUSTER_SWITCH(COND, CONST_NAME, ...) 
\ [&] { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ }() #else #define CLUSTER_SWITCH BOOL_SWITCH #endif #ifdef FLASHATTENTION_DISABLE_SM8x #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \ [&] { \ constexpr static int ARCH_NAME = 90; \ return __VA_ARGS__(); \ }() #else #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \ [&] { \ if (ARCH == 86 || ARCH == 89) { \ constexpr static int ARCH_NAME = 86; \ return __VA_ARGS__(); \ } else if (ARCH < 90) { \ constexpr static int ARCH_NAME = 80; \ return __VA_ARGS__(); \ } else { \ constexpr static int ARCH_NAME = 90; \ return __VA_ARGS__(); \ } \ }() #endif #ifndef FLASHATTENTION_ENABLE_VCOLMAJOR #define VCOLMAJOR_SWITCH(COND, CONST_NAME, ...) \ [&] { \ constexpr static bool CONST_NAME = false; \ return __VA_ARGS__(); \ }() #else #define VCOLMAJOR_SWITCH BOOL_SWITCH #endif #define HEADDIM_SWITCH(HEADDIM, ...) \ [&] { \ if (HEADDIM == 64) { \ constexpr static int kHeadSize = 64; \ return __VA_ARGS__(); \ } else if (HEADDIM == 96) { \ constexpr static int kHeadSize = 96; \ return __VA_ARGS__(); \ } else if (HEADDIM == 128) { \ constexpr static int kHeadSize = 128; \ return __VA_ARGS__(); \ } else if (HEADDIM == 192) { \ constexpr static int kHeadSize = 192; \ return __VA_ARGS__(); \ } else if (HEADDIM == 256) { \ constexpr static int kHeadSize = 256; \ return __VA_ARGS__(); \ } \ }() #define NUM_WARP_SWITCH(VALUE, CONST_NAME, ...)
\ [&] { \ if (VALUE <= 1) { \ constexpr static int CONST_NAME = 1; \ return __VA_ARGS__(); \ } else if (VALUE <= 2) { \ constexpr static int CONST_NAME = 2; \ return __VA_ARGS__(); \ } else if (VALUE <= 4) { \ constexpr static int CONST_NAME = 4; \ return __VA_ARGS__(); \ } else if (VALUE <= 8) { \ constexpr static int CONST_NAME = 8; \ return __VA_ARGS__(); \ } else if (VALUE <= 16) { \ constexpr static int CONST_NAME = 16; \ return __VA_ARGS__(); \ } else { \ constexpr static int CONST_NAME = 32; \ return __VA_ARGS__(); \ } \ }() ================================================ FILE: hopper/test_attn_kvcache.py ================================================ import pytest from einops import rearrange, repeat import torch import flash_attn import flash_attn_interface import itertools import math import time def construct_local_mask( seqlen_q, seqlen_k, window_size=(-1, -1), # -1 means infinite window size query_padding_mask=None, key_padding_mask=None, device=None, key_leftpad=None, ): row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) if key_leftpad is not None: key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) sk = ( seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") ) sq = ( seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") ) if window_size[0] < 0: return col_idx > row_idx + sk - sq + window_size[1] else: sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk return torch.logical_or( col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), col_idx < row_idx + sk - sq - window_size[0], ) def attention_ref( q, k, v, query_padding_mask=None, key_padding_mask=None, attn_bias=None, dropout_p=0.0, 
dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size softcap=0.0, upcast=True, reorder_ops=False, key_leftpad=None, ): """ Arguments: q: (batch_size, seqlen_q, nheads, head_dim) k: (batch_size, seqlen_k, nheads_k, head_dim) v: (batch_size, seqlen_k, nheads_k, head_dim) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) dropout_p: float dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) causal: whether to apply causal masking window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16. reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) without changing the math. This is to estimate the numerical error from operation reordering. Output: output: (batch_size, seqlen_q, nheads, head_dim) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ if causal: window_size = (window_size[0], 0) dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() seqlen_q, seqlen_k = q.shape[1], k.shape[1] k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) d = q.shape[-1] if not reorder_ops: scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) else: scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) if softcap > 0: scores = scores / softcap scores = scores.tanh() scores = scores * softcap if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, q.device, key_leftpad=key_leftpad, ) scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = 
scores + attn_bias attention = torch.softmax(scores, dim=-1).to(v.dtype) # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: attention_drop = attention output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("num_requests", [1, 4]) @pytest.mark.parametrize("query_seqlen", [1, 8, 120]) @pytest.mark.parametrize("context_seqlen", [1024, 3131, 4224]) @pytest.mark.parametrize("headdim", [64, 128, 256]) @pytest.mark.parametrize("gqa_parallel", [False, True]) @pytest.mark.parametrize( "nheads_kv, gqa_ratio", [ (1, 1), (2, 5), (3, 3), (1, 32), (5, 7), (8, 1), (1, 16), (12, 4), (8, 2), ], ) def test_flash_attn_kvcache_nosplit(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, gqa_parallel): device = "cuda" num_caches = num_requests cache_seqlen = context_seqlen nheads_q = nheads_kv * gqa_ratio k_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 ) v_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 ) q = 
torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16) # cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests] cache_seqlens = torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda") torch.cuda.synchronize() out_ref, _ = attention_ref( q, k_cache, v_cache, causal=causal, ) out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache( q=q, k_cache=k_cache, v_cache=v_cache, cache_seqlens=cache_seqlens, # cache_batch_idx=cache_idxs, causal=causal, num_splits=1, return_softmax_lse=True, gqa_parallel=gqa_parallel ) torch.cuda.synchronize() assert ((out_ref - out_fa3).abs().max().item() <= 4e-3) assert ((out_ref - out_fa3).abs().mean().item() <= 2e-4) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("num_requests", [1, 3]) @pytest.mark.parametrize("query_seqlen", [1, 8, 120]) @pytest.mark.parametrize("context_seqlen", [1600, 4000, 5555]) @pytest.mark.parametrize("headdim", [64, 128, 256]) @pytest.mark.parametrize("gqa_parallel", [True, False]) @pytest.mark.parametrize( "nheads_kv, gqa_ratio", [ (1, 1), (2, 5), (3, 3), (1, 32), (5, 7), (8, 1), (1, 16), (12, 4), (8, 2), ], ) def test_flash_attn_kvcache_nosplit_fp8(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, gqa_parallel): device = "cuda" num_caches = num_requests cache_seqlen = context_seqlen nheads_q = nheads_kv * gqa_ratio k_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 ) v_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 ) q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16) q = q.to(torch.float8_e4m3fn) k_cache = k_cache.to(torch.float8_e4m3fn) v_cache = v_cache.to(torch.float8_e4m3fn) # cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests] cache_seqlens = 
torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda") torch.cuda.synchronize() out_ref, _ = attention_ref( q, k_cache, v_cache, causal=causal, ) descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda') descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda') descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda') out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache( q=q, k_cache=k_cache, v_cache=v_cache, cache_seqlens=cache_seqlens, # cache_batch_idx=cache_idxs, causal=causal, num_splits=1, return_softmax_lse=True, gqa_parallel=gqa_parallel, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v ) torch.cuda.synchronize() assert ((out_ref - out_fa3).abs().max().item() <= 4e-2) assert ((out_ref - out_fa3).abs().mean().item() <= 2e-3) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("use_heuristic_only", [True]) # @pytest.mark.parametrize("use_heuristic_only", [False]) @pytest.mark.parametrize("causal", [True, False]) # @pytest.mark.parametrize("num_requests", [1, 4, 16]) @pytest.mark.parametrize("num_requests", [1, 3]) # @pytest.mark.parametrize("query_seqlen", [1, 16, 32, 128]) @pytest.mark.parametrize("query_seqlen", [1, 8, 25]) # @pytest.mark.parametrize("context_seqlen", [4096, 16384, 65536]) @pytest.mark.parametrize("context_seqlen", [1600, 4000, 5555]) @pytest.mark.parametrize("headdim", [64, 128, 256]) @pytest.mark.parametrize("cache_seqlen_rand", [True, False]) @pytest.mark.parametrize("gqa_parallel", [True, False]) @pytest.mark.parametrize( "nheads_kv, gqa_ratio", [ (1, 1), (4, 1), (2, 2), (3, 3), (4, 4), (2, 5), (3, 9), (1, 16), (1, 32), ], ) def test_flash_attn_kvcache_output(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, use_heuristic_only, cache_seqlen_rand, gqa_parallel, dtype): device = "cuda" num_caches = 16 if context_seqlen <= 65536: cache_seqlen = 65536 else: cache_seqlen = context_seqlen nheads_q = 
nheads_kv * gqa_ratio if use_heuristic_only: max_splits = 1 else: max_splits = 128 k_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 ) v_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 ) q = torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16) q = q.to(dtype) k_cache = k_cache.to(dtype) v_cache = v_cache.to(dtype) cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests] cache_seqlens = torch.randint(1, context_seqlen-1, (num_requests,), dtype=torch.int32).to(device) if cache_seqlen_rand else torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda") torch.cuda.synchronize() out_ref, lse_ref = flash_attn_interface.flash_attn_with_kvcache( q=q, k_cache=k_cache, v_cache=v_cache, cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, num_splits=1, return_softmax_lse=True, gqa_parallel=False ) # i=0 case is with num splits heuristic for i in range(0, max_splits+1): out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache( q=q, k_cache=k_cache, v_cache=v_cache, cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, num_splits=i, return_softmax_lse=True, gqa_parallel=gqa_parallel, max_seqlen_k_hint=context_seqlen ) torch.cuda.synchronize() print ('output-ref', i, out_ref) print ('output-fa3',i, out_fa3) print ('output-max-diff', i, context_seqlen, (out_ref - out_fa3).abs().max().item()) print ('output-mean-diff',i, context_seqlen, (out_ref - out_fa3).abs().mean().item()) print ('lse-max-diff',i, context_seqlen, (lse_ref - lse_fa3).abs().max().item()) print ('lse-mean-diff',i, context_seqlen, (lse_ref - lse_fa3).abs().mean().item()) if cache_seqlen_rand: assert ((out_ref - out_fa3).abs().max().item() <= 1e-2) assert ((out_ref - out_fa3).abs().mean().item() <= 1e-3) else: assert ((out_ref - out_fa3).abs().max().item() <= 
2e-3) assert ((out_ref - out_fa3).abs().mean().item() <= 1e-4) lse_max_ref = lse_ref.abs().max().item() lse_mean_ref = lse_ref.abs().mean().item() lse_max_fa3 = lse_fa3.abs().max().item() lse_mean_fa3 = lse_fa3.abs().mean().item() lse_max_diff = (lse_ref - lse_fa3).abs().max().item() lse_mean_diff = (lse_ref - lse_fa3).abs().mean().item() assert ((lse_max_ref == math.inf and lse_max_fa3 == math.inf) or lse_max_diff <= 1e-3) assert ((lse_mean_ref == math.inf and lse_mean_fa3 == math.inf) or lse_mean_diff <= 1e-4) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("use_heuristic_only", [True]) # @pytest.mark.parametrize("use_heuristic_only", [False]) @pytest.mark.parametrize("causal", [True, False]) # @pytest.mark.parametrize("num_requests", [1, 4, 16]) @pytest.mark.parametrize("num_requests", [1, 3]) # @pytest.mark.parametrize("query_seqlen", [1, 16, 32, 128]) @pytest.mark.parametrize("query_seqlen", [1, 8, 25]) # @pytest.mark.parametrize("context_seqlen", [4096, 16384, 65536]) @pytest.mark.parametrize("context_seqlen", [1600, 4000, 5555]) @pytest.mark.parametrize("headdim", [64, 128, 256]) @pytest.mark.parametrize("cache_seqlen_rand", [True, False]) @pytest.mark.parametrize("gqa_parallel", [True, False]) @pytest.mark.parametrize( "nheads_kv, gqa_ratio", [ (1, 1), (4, 1), (2, 2), (3, 3), (4, 4), (2, 5), (3, 9), (1, 16), (1, 32), ], ) def test_flash_attn_kvcache_output_fp8(nheads_kv, gqa_ratio, num_requests, query_seqlen, context_seqlen, headdim, causal, use_heuristic_only, cache_seqlen_rand, gqa_parallel, dtype): device = "cuda" num_caches = 16 if context_seqlen <= 65536: cache_seqlen = 65536 else: cache_seqlen = context_seqlen nheads_q = nheads_kv * gqa_ratio if use_heuristic_only: max_splits = 1 else: max_splits = 128 k_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 ) v_cache = torch.randn( (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=torch.bfloat16 ) q 
= torch.randn((num_requests, query_seqlen, nheads_q, headdim), device="cuda", dtype=torch.bfloat16) q = q.to(dtype) k_cache = k_cache.to(dtype) v_cache = v_cache.to(dtype) cache_idxs = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[:num_requests] cache_seqlens = torch.randint(1, context_seqlen-1, (num_requests,), dtype=torch.int32).to(device) if cache_seqlen_rand else torch.tensor([context_seqlen] * num_requests, dtype=torch.int32, device="cuda") torch.cuda.synchronize() descale_q = torch.tensor([1.0], dtype=torch.float32, device='cuda') descale_k = torch.tensor([1.0], dtype=torch.float32, device='cuda') descale_v = torch.tensor([1.0], dtype=torch.float32, device='cuda') out_ref, lse_ref = flash_attn_interface.flash_attn_with_kvcache( q=q, k_cache=k_cache, v_cache=v_cache, cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, num_splits=1, return_softmax_lse=True, gqa_parallel=False, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v ) # i=0 case is with num splits heuristic for i in range(0, max_splits+1): out_fa3, lse_fa3 = flash_attn_interface.flash_attn_with_kvcache( q=q, k_cache=k_cache, v_cache=v_cache, cache_seqlens=cache_seqlens, cache_batch_idx=cache_idxs, causal=causal, num_splits=i, return_softmax_lse=True, gqa_parallel=gqa_parallel, max_seqlen_k_hint=context_seqlen, descale_q=descale_q, descale_k=descale_k, descale_v=descale_v ) torch.cuda.synchronize() print ('output-ref', i, out_ref) print ('output-fa3',i, out_fa3) print ('output-max-diff', i, context_seqlen, (out_ref - out_fa3).abs().max().item()) print ('output-mean-diff',i, context_seqlen, (out_ref - out_fa3).abs().mean().item()) print ('lse-max-diff',i, context_seqlen, (lse_ref - lse_fa3).abs().max().item()) print ('lse-mean-diff',i, context_seqlen, (lse_ref - lse_fa3).abs().mean().item()) if cache_seqlen_rand: assert ((out_ref - out_fa3).abs().max().item() <= 1e-1) assert ((out_ref - out_fa3).abs().mean().item() <= 1e-2) else: assert ((out_ref - 
out_fa3).abs().max().item() <= 2e-2) assert ((out_ref - out_fa3).abs().mean().item() <= 2e-3) lse_max_ref = lse_ref.abs().max().item() lse_mean_ref = lse_ref.abs().mean().item() lse_max_fa3 = lse_fa3.abs().max().item() lse_mean_fa3 = lse_fa3.abs().mean().item() lse_max_diff = (lse_ref - lse_fa3).abs().max().item() lse_mean_diff = (lse_ref - lse_fa3).abs().mean().item() assert ((lse_max_ref == math.inf and lse_max_fa3 == math.inf) or lse_max_diff <= 1e-3) assert ((lse_mean_ref == math.inf and lse_mean_fa3 == math.inf) or lse_mean_diff <= 1e-4) if __name__ == "__main__": pytest.main([__file__]) ================================================ FILE: hopper/test_flash_attn.py ================================================ import os import math import itertools import pytest import torch import torch.nn.functional as F from torch._C import parse_schema from torch.testing._internal.optests.generate_tests import ( safe_fake_check, safe_schema_check, safe_aot_autograd_check, ) from einops import rearrange, repeat try: from flash_attn.layers.rotary import apply_rotary_emb except ImportError: apply_rotary_emb = None from padding import pad_input, unpad_input from test_util import ( attention_ref, generate_qkv, generate_random_padding_mask, ) from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" DISABLE_FP16 =
os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" ENABLE_OPCHECK = os.getenv("FLASH_ATTENTION_ENABLE_OPCHECK", "FALSE") == "TRUE" ENABLE_AUTOGRAD_CHECK = os.getenv("FLASH_ATTENTION_ENABLE_AUTOGRAD_CHECK", "FALSE") == "TRUE" COMPILED_HDIMS = ( [] + ([64] if not DISABLE_HDIM64 else []) + ([96] if not DISABLE_HDIM96 else []) + ([128] if not DISABLE_HDIM128 else []) + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) def should_test_backward(args, kwargs): v = args[2] num_splits = kwargs.get("num_splits", 1) dtype = v.dtype has_qv = V_colmajor = False # no test runs this with V_colmajor or has_qv == True attention_chunk = kwargs.get("attention_chunk") dv = v.size(-1) if ( ENABLE_AUTOGRAD_CHECK and not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv and not dv > 256 and not attention_chunk != 0 and num_splits > 0 # we don't support num_split == 0 on torch.compile yet ): return True return False def should_run_schema_check(args, kwargs): v = args[2] if v.dtype == torch.float8_e4m3fn: return False return True def should_run_fake_check(args, kwargs): if 'num_splits' in kwargs: return kwargs['num_splits'] > 0 return True def run_opcheck(fn): def wrapper(*args, **kwargs): if should_run_schema_check(args, kwargs): safe_schema_check(fn, args, kwargs) if should_run_fake_check(args, kwargs): safe_fake_check(fn, args, kwargs) if should_test_backward(args, kwargs): # Expensive check 
            safe_aot_autograd_check(fn, args, kwargs, dynamic=False)
            safe_aot_autograd_check(fn, args, kwargs, dynamic=True)
        return fn(*args, **kwargs)
    return wrapper


if ENABLE_OPCHECK:
    flash_attn_func = run_opcheck(flash_attn_func)
    flash_attn_varlen_func = run_opcheck(flash_attn_varlen_func)


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("has_qv", [False, True])
# @pytest.mark.parametrize("has_qv", [True])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
# @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else []))
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("V_colmajor", [False, True])
@pytest.mark.parametrize("V_colmajor", [False])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128, 192])
@pytest.mark.parametrize("d", COMPILED_HDIMS)
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (64, 128),
        (128, 192),
        (256, 256),
        (239, 1),
        (799, 3),
        (113, 203),
        (113, 128),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (4096, 4096),
        (4224, 4224),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
def test_flash_attn_output(
        seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype
):
    if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn):
        pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn")
    if has_qv and (d != 64 or dtype == torch.float8_e4m3fn):
        pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)")
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 40
    # nheads = 16
    batch_size = 9 if seqlen_k <= 2048 else 2
    # batch_size = 1
    nheads = 6
    # nheads = 1
    nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    if has_qv:
        dv_vals = [256, 512]
    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        print(f"{dv = }, {attention_chunk = }")
        q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
        if softcap > 0.0:
            # Ensure the values of qk are at least within softcap range.
            q_ref = (q_ref * softcap / 4)
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
        v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
        if has_qv:
            qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
        else:
            qv_ref = None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
        # window_size = (-1, -1) if not local else (16, 0)
        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]
        else:
            q_descale, k_descale, v_descale = None, None, None
        q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None
        if V_colmajor:
            v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_()
        out_ref, attn_ref = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            None,
            None,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap
        )
        out_pt, attn_pt = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            None,
            None,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )

        # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float()
        # if qv is not None:
        #     qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float()
        # m = qk.amax(-1, keepdim=True)
        # s_tmp = torch.exp((qk - m) / math.sqrt(d))
        # exp_sum = s_tmp.sum(-1)
        # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())
        # lse_ref = torch.logsumexp(qk, dim=-1)

        # Numerical error if we just do any arithmetic on out_ref
        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
        rtol = 2 if softcap == 0.0 else 3

        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
        pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
        num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
            print(f"{pack_gqa = }, {num_splits = }")
            out = flash_attn_func(
                q,
                k,
                v,
                causal=causal,
                qv=qv,
                q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
                window_size=window_size,
                attention_chunk=attention_chunk,
                softcap=softcap,
                pack_gqa=pack_gqa,
                num_splits=num_splits
            )
            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
            # if not causal:
            #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
            # breakpoint()

            # Check that FlashAttention's numerical error is at most rtol times (2x, or 3x with softcap)
            # the numerical error of a PyTorch implementation, plus fwd_atol.
            assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol

        if (
            not DISABLE_BACKWARD
            and dtype != torch.float8_e4m3fn
            and not V_colmajor
            and not has_qv
            and not dv > 256
            and not attention_chunk != 0
        ):
            g = torch.randn_like(out)
            do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
            # import flash_attn_3_cuda
            # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd(
            #     g,
            #     q,
            #     k,
            #     v,
            #     out,
            #     lse,
            #     None,
            #     None,
            #     None,
            #     d ** (-0.5),
            #     causal,
            #     window_size[0], window_size[1],
            #     softcap,
            #     deterministic,
            #     0,  # sm_margin
            # )
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
            # assert (softmax_d - do_o).abs().max().item() <= 1e-5
            # assert dq_accum.abs().max().item() == 0.0

            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
            # P = torch.softmax(qk, -1)
            # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1))
            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())

            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
            # breakpoint()
            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol
            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol
            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("has_qv", [False, True])
# @pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
# @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else []))
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("add_unused_qkv", [False, True])
# @pytest.mark.parametrize("add_unused_qkv", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
@pytest.mark.parametrize("d", COMPILED_HDIMS)
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (1, 3),
        (2, 1),
        (511, 1),
        (3, 513),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (307, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
    ],
)
def test_flash_attn_varlen_output(
    seqlen_q,
    seqlen_k,
    d,
    add_unused_qkv,
    causal,
    local,
    softcap,
    deterministic,
    has_qv,
    mha_type,
    dtype,
):
    if has_qv and (d != 64 or dtype == torch.float8_e4m3fn):
        pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)")
    device = "cuda"
    # set seed
    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
    # batch_size = 40
    # nheads = 16
    batch_size = 9 if seqlen_q <= 2048 else 2
    # batch_size = 32
    nheads = 6
    nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    # batch_size = 2
    # nheads = 1
    # nheads_kv = nheads
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    if has_qv:
        dv_vals = [256, 512]
    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        print(f"{dv = }, {attention_chunk = }")
        q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
        if softcap > 0.0:
            # Ensure the values of qk are at least within softcap range.
            q_ref = (q_ref * softcap / 4).detach().requires_grad_()
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
        v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_()
        if has_qv:
            qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
        else:
            qv_ref = None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]
        else:
            q_descale, k_descale, v_descale = None, None, None
        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach() if has_qv else None
        query_padding_mask = generate_random_padding_mask(
            seqlen_q, batch_size, device, mode="random", zero_lengths=False
        )
        key_padding_mask = generate_random_padding_mask(
            seqlen_k, batch_size, device, mode="random", zero_lengths=True
        )

        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
            if add_unused:
                another_mask = generate_random_padding_mask(max_seq_len, bs, device)
                attn_mask = torch.logical_and(padding_mask, another_mask)
                unused_mask = torch.logical_xor(
                    torch.logical_or(padding_mask, another_mask), attn_mask
                )
            else:
                attn_mask = padding_mask
                unused_mask = None
            return attn_mask, unused_mask

        query_padding_mask, query_unused_mask = _gen_unused_masks(
            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
        )
        key_padding_mask, key_unused_mask = _gen_unused_masks(
            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
        )

        (
            q_unpad,
            k_unpad,
            v_unpad,
            qv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            qv,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False,
                         query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask)
        q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)]
        out_ref, attn_ref = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap
        )
        out_pt, attn_pt = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )

        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

        if query_unused_mask is not None:
            q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")

        # Numerical error if we just do any arithmetic on out_ref
        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
        rtol = 2 if softcap == 0.0 else 3

        pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
        # pack_gqa_vals = [False]
        num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1]
        # num_splits_vals = [1]
        # print("cu_seqlens_q: ", cu_seqlens_q)
        # print("cu_seqlens_k: ", cu_seqlens_k)
        # print("seqused_q: ", seqused_q)
        # print("seqused_k: ", seqused_k)
        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
            print(f"{pack_gqa = }, {num_splits = }")
            out_unpad = flash_attn_varlen_func(
                q_unpad,
                k_unpad,
                v_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                seqused_q=seqused_q,
                seqused_k=seqused_k,
                causal=causal,
                qv=qv_unpad,
                q_descale=q_descale,
                k_descale=k_descale,
                v_descale=v_descale,
                window_size=window_size,
                attention_chunk=attention_chunk,
                softcap=softcap,
                pack_gqa=pack_gqa,
                num_splits=num_splits,
            )
            out = output_pad_fn(out_unpad)
            if query_unused_mask is not None:
                out.masked_fill_(q_zero_masking, 0.0)
            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
            # if not causal:
            #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
            # breakpoint()

            # Check that FlashAttention's numerical error is at most rtol times (2x, or 3x with softcap)
            # the numerical error of a PyTorch implementation, plus fwd_atol.
            assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol

        if (
            not DISABLE_BACKWARD
            and dtype != torch.float8_e4m3fn
            and not has_qv
            and not dv > 256
            and not attention_chunk != 0
        ):
            g_unpad = torch.randn_like(out_unpad)
            do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
            # import flash_attn_3_cuda
            # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(
            #     g_unpad,
            #     q_unpad,
            #     k_unpad,
            #     v_unpad,
            #     out_unpad,
            #     lse,
            #     None,
            #     None,
            #     None,
            #     cu_seqlens_q,
            #     cu_seqlens_k,
            #     None, None,
            #     max_seqlen_q,
            #     max_seqlen_k,
            #     d ** (-0.5),
            #     causal,
            #     window_size[0], window_size[1],
            #     softcap,
            #     deterministic,
            #     0,  # sm_margin
            # )
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad)
            dq = dq_pad_fn(dq_unpad)
            dk = dk_pad_fn(dk_unpad)
            # V shares K's padding mask, so dk_pad_fn also re-pads dV.
            dv = dk_pad_fn(dv_unpad)
            if key_unused_mask is not None:
                k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
                dk.masked_fill_(k_zero_masking, 0.0)
                dv.masked_fill_(k_zero_masking, 0.0)
            if query_unused_mask is not None:
                dq.masked_fill_(q_zero_masking, 0.0)
            # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
            # assert (softmax_d - do_o).abs().max().item() <= 1e-5
            # assert dq_accum.abs().max().item() == 0.0
            g = output_pad_fn(g_unpad)

            # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()
            # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
            # P = torch.softmax(qk, -1)
            # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))
            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())

            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
            # breakpoint()
            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol
            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol
            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
            assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else []))
# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)])
# @pytest.mark.parametrize("causal,local", [(True, False)])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False])
# @pytest.mark.parametrize("has_rotary_seqlens", [False, True])
@pytest.mark.parametrize("has_rotary_seqlens", [False])
@pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else []))
# @pytest.mark.parametrize("page_size", [None])
@pytest.mark.parametrize("has_leftpad", [False, True])
# @pytest.mark.parametrize("has_leftpad", [False])
@pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [True])
@pytest.mark.parametrize("varlen_q", [False, True])
# @pytest.mark.parametrize("varlen_q", [True])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
@pytest.mark.parametrize("d", [128])
# @pytest.mark.parametrize("d", [192])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        # (1, 128 * 1024),
        # (16, 128 * 1024),
        (128, 128),
        (256, 512),  # To test appending KV with more than 1 block
        (2048, 3577),  # Enough tiles to test the persistent scheduler
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    varlen_q,
    has_batch_idx,
    has_leftpad,
    page_size,
    rotary_fraction,
    rotary_interleaved,
    has_rotary_seqlens,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    new_kv,
    mha_type,
    dtype,
):
    if page_size is not None and seqlen_k % page_size != 0:
        pytest.skip()
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if rotary_fraction == 0.0 and has_rotary_seqlens:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    # batch_size = 1
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # nheads = 1
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        print(f"{dv = }, {attention_chunk = }")
        has_qv = d == 64 and dv >= 256
        q = torch.randn(batch_size, seqlen_q, nheads, d, device=device,
                        dtype=dtype_ref).to(dtype).to(dtype_ref)
        if has_qv:
            qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
        else:
            qv = None
        if varlen_q:
            query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
            q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask)
            output_pad_fn = lambda output_unpad: pad_input(
                output_unpad, indices_q, batch_size, seqlen_q
            )
            qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None
        else:
            query_padding_mask = None
            q_unpad = q
            qv_unpad = qv
            cu_seqlens_q, max_seqlen_q = None, None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))

        seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
        cu_seqlens_k_new = None
        key_new_padding_mask = None
        if new_kv:
            k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
            v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
            if varlen_q:  # k & v are also varlen
                key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random")
                k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask)
                v_unpad, *rest = unpad_input(v, key_new_padding_mask)
            else:
                k_unpad, v_unpad = k, v
        else:
            k, v, k_unpad, v_unpad = None, None, None, None
        if page_size is None:
            k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
            v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref)
            page_table = None
        else:
            (
                k_cache,
                v_cache,
                page_table,
                k_cache_paged,
                v_cache_paged,
                num_blocks,
            ) = _generate_block_kvcache(
                seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype,
                dtype_ref
            )
        cache_seqlens = torch.randint(
            0 if new_kv else 1,
            # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
            (
                (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
                if new_kv
                else (seqlen_k + 1)
            ),
            (batch_size,),
            dtype=torch.int32,
            device=device,
        )
        if has_leftpad:
            cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
                                       if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
                                       for i in range(batch_size)])
        else:
            cache_leftpad = None
        if has_batch_idx:
            cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
                :batch_size
            ]
        else:
            cache_batch_idx = None
        arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
        cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
        if not new_kv:
            key_padding_mask = arange < cache_seqlens_expanded
        else:
            k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new
            key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens
        if has_leftpad:
            key_padding_mask = torch.logical_and(
                key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
            )
        # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
        rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2
        if rotary_dim > 0:
            angle = (
                torch.rand(
                    seqlen_k if page_size is None else num_blocks * page_size,
                    rotary_dim // 2,
                    device=device,
                )
                * 2
                * math.pi
            )
            cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
            sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
            if causal or local:
                q_ro = apply_rotary_emb(
                    q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved
                )
            else:
                q_ro = rearrange(
                    apply_rotary_emb(
                        rearrange(q, "b s h d -> b 1 (s h) d"),
                        cos,
                        sin,
                        seqlen_offsets=rotary_seqlens,
                        interleaved=rotary_interleaved,
                    ),
                    "b 1 (s h) d -> b s h d",
                    s=seqlen_q,
                )
            # q_ro = q
            k_ro = apply_rotary_emb(
                k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved
            )
        else:
            cos, sin = None, None
            q_ro, k_ro = q, k
        # k_cache[:, 64:] = -1
        k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
        v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
        if new_kv:
            update_mask = torch.logical_and(
                cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens
            )
            k_to_update = rearrange(k_ro, "b s ... -> (b s) ...")
            v_to_update = rearrange(v, "b s ... -> (b s) ...")
            if varlen_q:
                k_to_update = k_to_update[indices_k]
                v_to_update = v_to_update[indices_k]
            k_cache_ref[update_mask] = k_to_update
            v_cache_ref[update_mask] = v_to_update
        k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        out_ref, _ = attention_ref(
            q_ro,
            k_cache_rep,
            v_cache_rep,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv,
            window_size=window_size,
            attention_chunk=attention_chunk,
            key_leftpad=cache_leftpad,
        )
        out_pt, _ = attention_ref(
            q_ro,
            k_cache_rep,
            v_cache_rep,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv,
            window_size=window_size,
            attention_chunk=attention_chunk,
            upcast=False,
            reorder_ops=True,
            key_leftpad=cache_leftpad,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None
        )
        q = q.to(dtype)
        q_unpad = q_unpad.to(dtype) if varlen_q else None
        k_cache = k_cache.to(dtype)
        v_cache = v_cache.to(dtype)
        k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None
        v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None
        k = k.to(dtype) if k is not None else None
        v = v.to(dtype) if v is not None else None
        k_unpad = k_unpad.to(dtype) if k_unpad is not None else None
        v_unpad = v_unpad.to(dtype) if v_unpad is not None else None
        qv = qv.to(dtype) if qv is not None else None
        qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None
        cos = cos.to(dtype) if cos is not None else None
        sin = sin.to(dtype) if sin is not None else None
        k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()
        v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()
        num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1]
        precompute_metadata_vals = [False, True]
        for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals):
            print(f"{num_splits = }, {precompute_metadata = }")
            if precompute_metadata:
                scheduler_metadata = get_scheduler_metadata(
                    batch_size,
                    max_seqlen_q if varlen_q else seqlen_q,
                    seqlen_k if page_size is None else page_table.shape[1] * page_size,
                    nheads, nheads_k, d,
                    cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,
                    max_seqlen_k_new=seqlen_new, page_size=page_size,
                    causal=causal, window_size=window_size, attention_chunk=attention_chunk,
                    num_splits=num_splits,
                )
            else:
                scheduler_metadata = None
            # Repeat to test metadata reuse
            for _ in range(1 if not precompute_metadata else 2):
                if page_size is None:
                    k_cache.copy_(k_cache_saved)
                    v_cache.copy_(v_cache_saved)
                else:
                    k_cache_paged.copy_(k_cache_saved)
                    v_cache_paged.copy_(v_cache_saved)
                out, lse, *rest = flash_attn_with_kvcache(
                    q if not varlen_q else q_unpad,
                    k_cache if page_size is None else k_cache_paged,
                    v_cache if page_size is None else v_cache_paged,
                    k if not new_kv or not varlen_q else k_unpad,
                    v if not new_kv or not varlen_q else v_unpad,
                    qv=qv if not varlen_q else qv_unpad,
                    rotary_cos=cos,
                    rotary_sin=sin,
                    cache_seqlens=cache_seqlens,
                    cache_batch_idx=cache_batch_idx,
                    cache_leftpad=cache_leftpad,
                    page_table=page_table,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k_new,
                    max_seqlen_q=max_seqlen_q,
                    rotary_seqlens=rotary_seqlens,
                    causal=causal,
                    window_size=window_size,
                    attention_chunk=attention_chunk,
                    rotary_interleaved=rotary_interleaved,
                    scheduler_metadata=scheduler_metadata,
num_splits=num_splits, return_softmax_lse=True, ) if varlen_q: out = output_pad_fn(out) # out = flash_attn_with_kvcache( # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size # ) # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) # m = qk.amax(-1, keepdim=True) # s_tmp = torch.exp((qk - m) / math.sqrt(d)) # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) # probs = torch.softmax(qk, dim=-1) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. if new_kv: if page_size is None: k_cache_select = ( k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] ) v_cache_select = ( v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] ) else: k_cache_select = rearrange( k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k].to(dtype_ref) v_cache_select = rearrange( v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k].to(dtype_ref) k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) if dtype is not torch.float8_e4m3fn: assert torch.equal(v_cache_select, v_cache_ref) else: assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) # breakpoint() # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: if rotary_dim == 0: assert torch.equal(k_cache_select, k_cache_ref) else: # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): # breakpoint() if dtype is not torch.float8_e4m3fn: assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) else: assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) mult = 4 if dtype == torch.float8_e4m3fn else 2 assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 k_cache_paged = torch.randn( num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref ).to(dtype).to(dtype_ref) v_cache_paged = torch.randn( num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref ).to(dtype).to(dtype_ref) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", b=batch_size, ) k_cache = rearrange( k_cache_paged[page_table.flatten()], "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k] v_cache = rearrange( v_cache_paged[page_table.flatten()], "(b nblocks) block_size ... 
-> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k] return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize('d', [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (64, 8192), ], ) def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype): device = "cuda" torch.random.manual_seed(0) batch_size = 2 nheads = 16 nheads_kv = 4 # There was a bug where this would cause "unspecified launch failure" due to Cluster q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) for _ in range(100): flash_attn_func(q, k, v, causal=causal) # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128]) # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [80]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1, 239), (239, 1), (3, 799), (799, 3), (1024, 128), (97, 97), (128, 128), (200, 200), (256, 256), (257, 257), (384, 384), (512, 512), (768, 768), (1024, 1024), (2048, 2048), ], ) def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): device = "cuda" # set seed torch.random.manual_seed(0) # Simulate under memory load dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device) batch_size = 60 # Sometimes we need large batch size 
for the race conditions to trigger nheads = 4 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) torch.random.manual_seed(42) out0 = flash_attn_func(q, k, v, causal=causal) g = torch.randn_like(out0) dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) # Numerical error if we just do any arithmetic on dq dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() for i in range(1000): torch.random.manual_seed(42) out = flash_attn_func(q, k, v, causal=causal) assert torch.equal(out, out0) # assert torch.equal(lse, lse0) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_equal = torch.allclose(dq, dq0, atol=dq_atol) if not dq_equal: print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") # breakpoint() assert torch.equal(dv, dv0) assert torch.equal(dk, dk0) assert dq_equal def attention_combine_ref(out_partial, lse_partial): """ out_partial: (num_splits, batch_size, seqlen, nheads, d) lse_partial: (num_splits, batch_size, nheads, seqlen) """ lse = torch.logsumexp(lse_partial, dim=0) scale = torch.exp(lse_partial - lse) scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) out = (scale.unsqueeze(-1) * out_partial).sum(0) return out, lse @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float32]) # @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) # @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) # @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) # @pytest.mark.parametrize("seqlen", [15]) 
@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) # @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) # @pytest.mark.parametrize("num_splits", [128]) def test_flash_attn_combine(num_splits, seqlen, d, dtype): if DISABLE_SPLIT: pytest.skip() device = "cuda" # set seed torch.random.manual_seed(1) batch_size = 5 nheads = 16 # batch_size = 1 # nheads = 1 out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor # To test short-circuiting based on num_splits lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype) out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) out_pt = out_ref.to(dtype) print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) multiple = 2 assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) # from flash_attn.utils.benchmark import pytorch_profiler # # pytorch_profiler(torch.sum, lse_partial) # pytorch_profiler(flash_attn_combine, out_partial, lse_partial) # pytorch_profiler(torch.sum, out_partial) def test_flash3_bw_compatibility() -> None: # Let's try to always stay backward compatible! 
This will make life easier # for downstream libraries, users, and exported models. # 1/ Instead of removing arguments, error out if their value is no longer supported # 2/ When adding arguments, add them at the end with a default value assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema( "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, " "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, " "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, " "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, " "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, " "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, " "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, " "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, " "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) " "-> (Tensor(out!), Tensor, Tensor, Tensor)" )) assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema( "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, " "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, " "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, " "int? max_seqlen_k=None, float?
softmax_scale=None, bool is_causal=False, int window_size_left=-1, " "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) " "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)" )) assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema( "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, " "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)" )) assert torch.ops.flash_attn_3.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema( "flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, " "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, " "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, " "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, " "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? 
pack_gqa=None, " "int sm_margin=0) -> Tensor" )) ================================================ FILE: hopper/test_flash_attn_bwd_determinism.py ================================================ import os import math import itertools import pytest import torch import torch.nn.functional as F from torch._C import parse_schema from einops import rearrange, repeat try: from flash_attn.layers.rotary import apply_rotary_emb except ImportError: apply_rotary_emb = None from padding import pad_input, unpad_input from test_util import ( attention_ref, generate_qkv, generate_random_padding_mask, ) from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata from flash_attn_interface import _flash_attn_backward DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "FALSE") == "TRUE" DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE" DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE" DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" 
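# Illustrative sketch. This is not part of flash_attn_interface, and the helper
# names _sketch_attention_row and _sketch_combine_splits are hypothetical.
# With num_splits > 1, the forward pass produces one partial output and one
# log-sum-exp (LSE) per chunk of keys. attention_combine_ref in
# hopper/test_flash_attn.py merges them using weights exp(lse_i - logsumexp(lse)).
# This pure-Python version checks that identity for a single query row.
# It assumes every split contains at least one key, so no lse_i is -inf.
import math


def _sketch_attention_row(scores, values):
    """Softmax attention for one query row.

    scores: already-scaled q.k logits, one per key.
    values: one value vector (list of floats) per key.
    Returns (output vector, LSE of the scores).
    """
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    lse = m + math.log(denom)
    dim = len(values[0])
    out = [sum(w * v[j] for w, v in zip(weights, values)) / denom for j in range(dim)]
    return out, lse


def _sketch_combine_splits(partials):
    """Merge [(out_i, lse_i), ...] from disjoint key splits into one (out, lse)."""
    lses = [lse_i for _, lse_i in partials]
    m = max(lses)
    lse = m + math.log(sum(math.exp(lse_i - m) for lse_i in lses))  # logsumexp over splits
    dim = len(partials[0][0])
    # Each split's output is normalized by its own Z_i. Rescaling by Z_i / Z, which
    # equals exp(lse_i - lse), turns it back into its share of the full softmax.
    out = [sum(math.exp(lse_i - lse) * o[j] for o, lse_i in partials) for j in range(dim)]
    return out, lse

# Splitting the keys into disjoint chunks, running _sketch_attention_row on each
# chunk, and merging the results with _sketch_combine_splits should reproduce the
# unsplit result up to floating-point rounding.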
# deterministic mode not supported for hdim 256 DISABLE_HDIM256 = True COMPILED_HDIMS = ( [] + ([64] if not DISABLE_HDIM64 else []) + ([96] if not DISABLE_HDIM96 else []) + ([128] if not DISABLE_HDIM128 else []) + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) # @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mqa"]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) @pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) # @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("V_colmajor", [False, True]) @pytest.mark.parametrize("V_colmajor", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) @pytest.mark.parametrize("d", COMPILED_HDIMS) # @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1, 1), (64, 128), (128, 192), (256, 
256), (239, 1), (799, 3), (113, 203), (113, 128), (128, 217), (113, 211), (108, 256), (256, 512), (384, 256), (640, 128), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (4096, 4096), # (4224, 4224), # (8192, 8192), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype ): if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") if deterministic and d == 256: pytest.skip("Deterministic mode not supported for hdim 256") device = "cuda" # set seed torch.random.manual_seed(0) # batch_size = 40 # nheads = 16 batch_size = 9 if seqlen_k <= 2048 else 2 # batch_size = 1 nheads = 6 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) # if dtype == torch.float8_e4m3fn: # dv_vals = [d] # if has_qv: # dv_vals = [256, 512] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] dv_vals = [d] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. 
q_ref = (q_ref * softcap / 4) q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() if has_qv: qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() # window_size = (-1, -1) if not local else (16, 0) if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None if V_colmajor: v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() out_ref, attn_ref = attention_ref( q_ref, k_ref, v_ref, None, None, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap ) out_pt, attn_pt = attention_ref( q_ref, k_ref, v_ref, None, None, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, upcast=False, reorder_ops=True, intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() # if qv is not None: # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() # m = qk.amax(-1, keepdim=True) # s_tmp = torch.exp((qk - m) / math.sqrt(d)) # exp_sum = 
s_tmp.sum(-1) # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) # lse_ref = torch.logsumexp(qk, dim=-1) # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] # num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] pack_gqa_vals = [False] num_splits_vals = [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): print(f"{pack_gqa = }, {num_splits = }") out, softmax_lse = flash_attn_func( q, k, v, causal=causal, qv=qv, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, pack_gqa=pack_gqa, num_splits=num_splits, return_attn_probs=True, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") # if not causal: # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most rtol times the numerical error # of a Pytorch implementation, plus fwd_atol.
assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol if ( not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv and not dv > 256 and not attention_chunk != 0 ): g = torch.randn_like(out) do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) dq, dk, dv, softmax_d = _flash_attn_backward( g, q, k, v, out, softmax_lse, None, None, # cu_seqlens_q, cu_seqlens_k, None, None, # seqused_q, seqused_k, None, None, # max_seqlen_q, max_seqlen_k, dq, dk, dv, d ** (-0.5), causal, window_size=window_size, softcap=softcap, deterministic=deterministic, ) # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") # assert (softmax_d - do_o).abs().max().item() <= 1e-5 # assert dq_accum.abs().max().item() == 0.0 # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) # P = torch.softmax(qk, -1) # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") print(f"dQ Pytorch mean diff: {(dq_pt - 
dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol if deterministic: iterations = 1000 for i in range(iterations): dq2 = torch.empty_like(dq) dk2 = torch.empty_like(dk) dv2 = torch.empty_like(dv) dq2, dk2, dv2, softmax_d = _flash_attn_backward( g, q, k, v, out, softmax_lse, None, None, # cu_seqlens_q, cu_seqlens_k, None, None, # seqused_q, seqused_k, None, None, # max_seqlen_q, max_seqlen_k, dq2, dk2, dv2, d ** (-0.5), causal, window_size=window_size, softcap=softcap, deterministic=deterministic, ) print(f'dq max diff with myself: {(dq2 - dq).abs().max().item()}') print(f'dk max diff with myself: {(dk2 - dk).abs().max().item()}') print(f'dv max diff with myself: {(dv2 - dv).abs().max().item()}') assert torch.equal(dq, dq2), "dq not deterministic" assert torch.equal(dk, dk2), "dk not deterministic" assert torch.equal(dv, dv2), "dv not deterministic" print(f"✅ Iteration {i} passed!") # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) # @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type",
["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) @pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) # @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("add_unused_qkv", [False, True]) # @pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", COMPILED_HDIMS) # @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1, 1), (1, 3), (2, 1), (511, 1), (3, 513), (64, 128), (128, 128), (256, 256), (113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (307, 256), (640, 128), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (1024, 1024), (2048, 2048), (4096, 4096), ], ) def test_flash_attn_varlen_output( seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, ): if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") if deterministic and d == 256: pytest.skip("Deterministic mode not supported for hdim 256") device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) # batch_size 
= 40 # nheads = 16 batch_size = 9 if seqlen_q <= 2048 else 2 # batch_size = 32 nheads = 6 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) # batch_size = 2 # nheads = 1 # nheads_kv = nheads dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) # if dtype == torch.float8_e4m3fn: # dv_vals = [d] # if has_qv: # dv_vals = [256, 512] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] dv_vals = [d] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. q_ref = (q_ref * softcap / 4).detach().requires_grad_() q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() if has_qv: qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] qv = qv_ref.detach() if has_qv else None query_padding_mask = generate_random_padding_mask( seqlen_q, batch_size, device,
mode="random", zero_lengths=False ) key_padding_mask = generate_random_padding_mask( seqlen_k, batch_size, device, mode="random", zero_lengths=True ) def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): if add_unused: another_mask = generate_random_padding_mask(max_seq_len, bs, device) attn_mask = torch.logical_and(padding_mask, another_mask) unused_mask = torch.logical_xor( torch.logical_or(padding_mask, another_mask), attn_mask ) else: attn_mask = padding_mask unused_mask = None return attn_mask, unused_mask query_padding_mask, query_unused_mask = _gen_unused_masks( query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device ) key_padding_mask, key_unused_mask = _gen_unused_masks( key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device ) ( q_unpad, k_unpad, v_unpad, qv_unpad, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, q, k, v, qv, output_pad_fn, dq_pad_fn, dk_pad_fn, ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] out_ref, attn_ref = attention_ref( q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap ) out_pt, attn_pt = attention_ref( q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, upcast=False, reorder_ops=True, intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") if query_unused_mask is not 
None: q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 # pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] # num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] pack_gqa_vals = [False] num_splits_vals = [1] print("cu_seqlens_q: ", cu_seqlens_q) print("cu_seqlens_k: ", cu_seqlens_k) print("seqused_q: ", seqused_q) print("seqused_k: ", seqused_k) for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): print(f"{pack_gqa = }, {num_splits = }") out_unpad, softmax_lse = flash_attn_varlen_func( q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, seqused_q=seqused_q, seqused_k=seqused_k, causal=causal, qv=qv_unpad, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, pack_gqa=pack_gqa, num_splits=num_splits, deterministic=deterministic, return_attn_probs=True, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: out.masked_fill_(q_zero_masking, 0.0) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") # if not causal: # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most rtol times the numerical error # of a Pytorch implementation, plus fwd_atol.
assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol if ( not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv and not dv > 256 and not attention_chunk != 0 ): g_unpad = torch.randn_like(out_unpad) do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) dq_unpad = torch.empty_like(q_unpad) dk_unpad = torch.empty_like(k_unpad) dv_unpad = torch.empty_like(v_unpad) dq_unpad, dk_unpad, dv_unpad, softmax_d = _flash_attn_backward( g_unpad, q_unpad, k_unpad, v_unpad, out_unpad, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, dq_unpad, dk_unpad, dv_unpad, d ** (-0.5), causal, window_size=window_size, softcap=softcap, deterministic=deterministic, ) dq = dq_pad_fn(dq_unpad) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) if key_unused_mask is not None: k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") dk.masked_fill_(k_zero_masking, 0.0) dv.masked_fill_(k_zero_masking, 0.0) if query_unused_mask is not None: dq.masked_fill_(q_zero_masking, 0.0) # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") # assert (softmax_d - do_o).abs().max().item() <= 1e-5 # assert dq_accum.abs().max().item() == 0.0 g = output_pad_fn(g_unpad) # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) # P = torch.softmax(qk, -1) # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: 
{(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol print(dq_unpad.shape) print(dk_unpad.shape) print(dv_unpad.shape) print(dq.shape) print(dk.shape) print(dv.shape) if deterministic: iterations = 1000 for i in range(iterations): dq_unpad2 = torch.empty_like(q_unpad) dk_unpad2 = torch.empty_like(k_unpad) dv_unpad2 = torch.empty_like(v_unpad) dq_unpad2, dk_unpad2, dv_unpad2, softmax_d = _flash_attn_backward( g_unpad, q_unpad, k_unpad, v_unpad, out_unpad, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, dq_unpad2, dk_unpad2, dv_unpad2, d ** (-0.5), causal, window_size=window_size, softcap=softcap, deterministic=deterministic, ) dq2 = dq_pad_fn(dq_unpad2) dk2 = 
dk_pad_fn(dk_unpad2) dv2 = dk_pad_fn(dv_unpad2) if key_unused_mask is not None: k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") dk2.masked_fill_(k_zero_masking, 0.0) dv2.masked_fill_(k_zero_masking, 0.0) if query_unused_mask is not None: dq2.masked_fill_(q_zero_masking, 0.0) print(f'dq max diff with myself: {(dq2 - dq).abs().max().item()}') print(f'dk max diff with myself: {(dk2 - dk).abs().max().item()}') print(f'dv max diff with myself: {(dv2 - dv).abs().max().item()}') assert torch.equal(dq, dq2), f"dq not deterministic" assert torch.equal(dk, dk2), f"dk not deterministic" assert torch.equal(dv, dv2), f"dv not deterministic" print(f"✅ Iteration {i} passed!") ================================================ FILE: hopper/test_flash_attn_triton_amd.py ================================================ import os import math import itertools import pytest import torch import torch.nn.functional as F from torch._C import parse_schema from einops import rearrange, repeat try: from flash_attn.layers.rotary import apply_rotary_emb except ImportError: apply_rotary_emb = None from padding import pad_input, unpad_input from test_util import ( attention_ref, generate_qkv, generate_random_padding_mask, ) from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "TRUE") == "TRUE" DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "FALSE") == "TRUE" DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "TRUE") == "TRUE" DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "TRUE") == "TRUE" DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "TRUE") == "TRUE" DISABLE_FP16 = 
os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" COMPILED_HDIMS = ( [] + ([64] if not DISABLE_HDIM64 else []) + ([96] if not DISABLE_HDIM96 else []) + ([128] if not DISABLE_HDIM128 else []) + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) # @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("V_colmajor", [False, True]) @pytest.mark.parametrize("V_colmajor", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 
192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64, 128, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) @pytest.mark.parametrize("d", COMPILED_HDIMS) # @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1, 1), (64, 128), (128, 192), (256, 256), (239, 1), (799, 3), (113, 203), (113, 128), (128, 217), (113, 211), (108, 256), (256, 512), (384, 256), (640, 128), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (4096, 4096), (4224, 4224), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype ): if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") device = "cuda" # set seed torch.random.manual_seed(0) # batch_size = 40 # nheads = 16 batch_size = 9 if seqlen_k <= 2048 else 2 # batch_size = 1 nheads = 6 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. 
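# Rough reasoning behind the scaling below, assuming softcap is applied after the 1/sqrt(d) scaling:
# softcapping maps each scaled score s to softcap * tanh(s / softcap), which is almost the identity
# while |s| << softcap. For standard-normal q and k with head dim d, q . k has variance d, so
# s = q . k / sqrt(d) has standard deviation about 1. Multiplying q by softcap / 4 raises that to
# about softcap / 4 (3.75 for softcap = 15.0), so s / softcap has standard deviation about 0.25 and
# its tails reach the nonlinear part of tanh, which means the softcap path really changes the output.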
q_ref = (q_ref * softcap / 4) q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() if has_qv: qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() # window_size = (-1, -1) if not local else (16, 0) if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None if V_colmajor: v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() out_ref, attn_ref = attention_ref( q_ref, k_ref, v_ref, None, None, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap ) out_pt, attn_pt = attention_ref( q_ref, k_ref, v_ref, None, None, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, upcast=False, reorder_ops=True, intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() # if qv is not None: # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() # m = qk.amax(-1, keepdim=True) # s_tmp = torch.exp((qk - m) / math.sqrt(d)) # exp_sum = 
s_tmp.sum(-1) # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) # lse_ref = torch.logsumexp(qk, dim=-1) # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out = flash_attn_func( q, k, v, causal=causal, qv=qv, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, pack_gqa=pack_gqa, num_splits=num_splits ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") # if not causal: # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most rtol times (2x, or 3x with softcap) the numerical error # of a Pytorch implementation, plus fwd_atol.
assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol if ( not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv and not dv > 256 and not attention_chunk != 0 ): g = torch.randn_like(out) do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) # import flash_attn_3_cuda # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( # g, # q, # k, # v, # out, # lse, # None, # None, # None, # d ** (-0.5), # causal, # window_size[0], window_size[1], # softcap, # deterministic, # 0, # sm_margin # ) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") # assert (softmax_d - do_o).abs().max().item() <= 1e-5 # assert dq_accum.abs().max().item() == 0.0 # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) # P = torch.softmax(qk, -1) # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt 
- dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) # @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("add_unused_qkv", [False, True]) # @pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # 
@pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", COMPILED_HDIMS) # @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1, 1), (1, 3), (2, 1), (511, 1), (3, 513), (64, 128), (128, 128), (256, 256), (113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (307, 256), (640, 128), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048), ], ) def test_flash_attn_varlen_output( seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype ): device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) # batch_size = 40 # nheads = 16 batch_size = 9 if seqlen_q <= 2048 else 2 nheads = 6 # batch_size = 2 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. 
q_ref = (q_ref * softcap / 4).detach().requires_grad_() q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() if has_qv: qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] qv = qv_ref.detach() if has_qv else None query_padding_mask = generate_random_padding_mask( seqlen_q, batch_size, device, mode="random", zero_lengths=False ) key_padding_mask = generate_random_padding_mask( seqlen_k, batch_size, device, mode="random", zero_lengths=True ) def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): if add_unused: another_mask = generate_random_padding_mask(max_seq_len, bs, device) attn_mask = torch.logical_and(padding_mask, another_mask) unused_mask = torch.logical_xor( torch.logical_or(padding_mask, another_mask), attn_mask ) else: attn_mask = padding_mask unused_mask = None return attn_mask, unused_mask query_padding_mask, query_unused_mask = _gen_unused_masks( query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device ) key_padding_mask, key_unused_mask = _gen_unused_masks( key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device ) ( q_unpad, k_unpad, v_unpad, qv_unpad, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, q, k, v, qv, 
output_pad_fn, dq_pad_fn, dk_pad_fn, ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] out_ref, attn_ref = attention_ref( q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap ) out_pt, attn_pt = attention_ref( q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, upcast=False, reorder_ops=True, intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") if query_unused_mask is not None: q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out_unpad = flash_attn_varlen_func( q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, seqused_q=seqused_q, seqused_k=seqused_k, causal=causal, qv=qv_unpad, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, pack_gqa=pack_gqa, num_splits=num_splits, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: out.masked_fill_(q_zero_masking, 0.0) print(f"Output max diff: {(out -
out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") # if not causal: # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most 3x the numerical error # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol if ( not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv and not dv > 256 and not attention_chunk != 0 ): g_unpad = torch.randn_like(out_unpad) do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) # import flash_attn_3_cuda # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( # g_unpad, # q_unpad, # k_unpad, # v_unpad, # out_unpad, # lse, # None, # None, # None, # cu_seqlens_q, # cu_seqlens_k, # None, None, # max_seqlen_q, # max_seqlen_k, # d ** (-0.5), # causal, # window_size[0], window_size[1], # softcap, # deterministic, # 0, # sm_margin # ) dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) dq = dq_pad_fn(dq_unpad) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) if key_unused_mask is not None: k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") dk.masked_fill_(k_zero_masking, 0.0) dv.masked_fill_(k_zero_masking, 0.0) if query_unused_mask is not None: dq.masked_fill_(q_zero_masking, 0.0) # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") # assert (softmax_d - do_o).abs().max().item() <= 1e-5 # assert dq_accum.abs().max().item() == 0.0 g = output_pad_fn(g_unpad) # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) # P = torch.softmax(qk, -1) # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) # dQ = 
torch.einsum('bhts,bshd->bthd', dP, k.float()) # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) # 
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) # @pytest.mark.parametrize("new_kv", [True]) @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) # @pytest.mark.parametrize("causal,local", [(False, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) # @pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) # @pytest.mark.parametrize("rotary_interleaved", [True]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) # @pytest.mark.parametrize("page_size", [None]) @pytest.mark.parametrize("has_leftpad", [False]) # @pytest.mark.parametrize("has_leftpad", [False]) @pytest.mark.parametrize("has_batch_idx", [False]) # @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("varlen_q", [False]) # @pytest.mark.parametrize("varlen_q", [False]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) @pytest.mark.parametrize("d", [128]) # @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1, 128), 
(1, 339), (3, 1024), (64, 800), (64, 256), (3, 799), (64, 2048), (16, 20000), # (1, 128 * 1024), # (16, 128 * 1024), (128, 128), (256, 512), # To test appending KV with more than 1 block (2048, 3577), # Enough tiles to test the persistent scheduler ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def test_flash_attn_kvcache( seqlen_q, seqlen_k, d, varlen_q, has_batch_idx, has_leftpad, page_size, rotary_fraction, rotary_interleaved, has_rotary_seqlens, seqlen_new_eq_seqlen_q, causal, local, new_kv, mha_type, dtype, ): if page_size is not None and seqlen_k % page_size != 0: pytest.skip() if seqlen_q > seqlen_k and new_kv: pytest.skip() if not new_kv and rotary_fraction > 0.0: pytest.skip() if rotary_fraction == 0.0 and has_rotary_seqlens: pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 5 # batch_size = 1 batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 nheads = 6 # nheads = 1 # rotary_dim must be a multiple of 16, and must be <= d rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): has_qv = d == 64 and dv >= 256 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) else: qv = None if varlen_q: query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") q_unpad, indices_q, cu_seqlens_q,
max_seqlen_q, *rest = unpad_input(q, query_padding_mask) output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None else: query_padding_mask = None q_unpad = q qv_unpad = qv cu_seqlens_q, max_seqlen_q = None, None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() cu_seqlens_k_new = None key_new_padding_mask = None if new_kv: k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if varlen_q: # k & v are also varlen key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) v_unpad, *rest = unpad_input(v, key_new_padding_mask) else: k_unpad, v_unpad = k, v else: k, v, k_unpad, v_unpad = None, None, None, None if page_size is None: k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) page_table = None else: ( k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks, ) = _generate_block_kvcache( seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref ) cache_seqlens = torch.randint( 0 if new_kv else 1, # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough ( (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) if new_kv else (seqlen_k + 1) ), (batch_size,), dtype=torch.int32, 
            device=device,
        )
        if has_leftpad:
            cache_leftpad = torch.cat([
                torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
                if cache_seqlens[i].item() > 0
                else torch.zeros(1, dtype=torch.int32, device=device)
                for i in range(batch_size)
            ])
        else:
            cache_leftpad = None
        if has_batch_idx:
            cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
                :batch_size
            ]
        else:
            cache_batch_idx = None
        arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
        cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
        if not new_kv:
            key_padding_mask = arange < cache_seqlens_expanded
        else:
            k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new
            key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens
        if has_leftpad:
            key_padding_mask = torch.logical_and(
                key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
            )
        # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
        rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2
        if rotary_dim > 0:
            angle = (
                torch.rand(
                    seqlen_k if page_size is None else num_blocks * page_size,
                    rotary_dim // 2,
                    device=device,
                )
                * 2
                * math.pi
            )
            cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
            sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
            if causal or local:
                q_ro = apply_rotary_emb(
                    q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved
                )
            else:
                q_ro = rearrange(
                    apply_rotary_emb(
                        rearrange(q, "b s h d -> b 1 (s h) d"),
                        cos,
                        sin,
                        seqlen_offsets=rotary_seqlens,
                        interleaved=rotary_interleaved,
                    ),
                    "b 1 (s h) d -> b s h d",
                    s=seqlen_q,
                )
            # q_ro = q
            k_ro = apply_rotary_emb(
                k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved
            )
        else:
            cos, sin = None, None
            q_ro, k_ro = q, k
        # k_cache[:, 64:] = -1
        k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
        v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
        if new_kv:
            update_mask = torch.logical_and(
                cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens
            )
            k_to_update = rearrange(k_ro, "b s ... -> (b s) ...")
            v_to_update = rearrange(v, "b s ... -> (b s) ...")
            if varlen_q:
                k_to_update = k_to_update[indices_k]
                v_to_update = v_to_update[indices_k]
            k_cache_ref[update_mask] = k_to_update
            v_cache_ref[update_mask] = v_to_update
        k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        out_ref, _ = attention_ref(
            q_ro,
            k_cache_rep,
            v_cache_rep,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv,
            window_size=window_size,
            attention_chunk=attention_chunk,
            key_leftpad=cache_leftpad,
        )
        out_pt, _ = attention_ref(
            q_ro,
            k_cache_rep,
            v_cache_rep,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv,
            window_size=window_size,
            attention_chunk=attention_chunk,
            upcast=False,
            reorder_ops=True,
            key_leftpad=cache_leftpad,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None
        )
        q = q.to(dtype)
        q_unpad = q_unpad.to(dtype) if varlen_q else None
        k_cache = k_cache.to(dtype)
        v_cache = v_cache.to(dtype)
        k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None
        v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None
        k = k.to(dtype) if k is not None else None
        v = v.to(dtype) if v is not None else None
        k_unpad = k_unpad.to(dtype) if k_unpad is not None else None
        v_unpad = v_unpad.to(dtype) if v_unpad is not None else None
        qv = qv.to(dtype) if qv is not None else None
        qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None
        cos = cos.to(dtype) if cos is not None else None
        sin = sin.to(dtype) if sin is not None else None
        k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()
        v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()
        num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1]
        precompute_metadata_vals = [False]
        for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals):
            if precompute_metadata:
                scheduler_metadata = get_scheduler_metadata(
                    batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
                    cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,
                    max_seqlen_k_new=seqlen_new, page_size=page_size,
                    causal=causal, window_size=window_size, attention_chunk=attention_chunk,
                    num_splits=num_splits
                )
            else:
                scheduler_metadata = None
            # Repeat to test metadata reuse
            for _ in range(1 if not precompute_metadata else 2):
                if page_size is None:
                    k_cache.copy_(k_cache_saved)
                    v_cache.copy_(v_cache_saved)
                else:
                    k_cache_paged.copy_(k_cache_saved)
                    v_cache_paged.copy_(v_cache_saved)
                out, lse, *rest = flash_attn_with_kvcache(
                    q if not varlen_q else q_unpad,
                    k_cache if page_size is None else k_cache_paged,
                    v_cache if page_size is None else v_cache_paged,
                    k if not new_kv or not varlen_q else k_unpad,
                    v if not new_kv or not varlen_q else v_unpad,
                    qv=qv if not varlen_q else qv_unpad,
                    rotary_cos=cos,
                    rotary_sin=sin,
                    cache_seqlens=cache_seqlens,
                    cache_batch_idx=cache_batch_idx,
                    cache_leftpad=cache_leftpad,
                    page_table=page_table,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k_new=cu_seqlens_k_new,
                    max_seqlen_q=max_seqlen_q,
                    rotary_seqlens=rotary_seqlens,
                    causal=causal,
                    window_size=window_size,
                    attention_chunk=attention_chunk,
                    rotary_interleaved=rotary_interleaved,
                    scheduler_metadata=scheduler_metadata,
                    num_splits=num_splits,
                    return_softmax_lse=True
                )
                if varlen_q:
                    out = output_pad_fn(out)
                # out = flash_attn_with_kvcache(
                #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
                # )
                # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
                # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
                # m = qk.amax(-1, keepdim=True)
                # s_tmp = torch.exp((qk - m) / math.sqrt(d))
                # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
                # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
                # probs = torch.softmax(qk, dim=-1)
                print(f"Output max diff: {(out - out_ref).abs().max().item()}")
                print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
                print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
                print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
                # breakpoint()

                # Check that FlashAttention's numerical error is at most twice (4x for FP8) the
                # numerical error of a Pytorch implementation.
                if new_kv:
                    if page_size is None:
                        k_cache_select = (
                            k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx]
                        )
                        v_cache_select = (
                            v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx]
                        )
                    else:
                        k_cache_select = rearrange(
                            k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()],
                            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                            b=batch_size,
                        )[:, :seqlen_k].to(dtype_ref)
                        v_cache_select = rearrange(
                            v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()],
                            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                            b=batch_size,
                        )[:, :seqlen_k].to(dtype_ref)
                    k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
                    v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
                    if dtype is not torch.float8_e4m3fn:
                        assert torch.equal(v_cache_select, v_cache_ref)
                    else:
                        assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3)
                    # breakpoint()
                    # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
                    if rotary_dim == 0:
                        assert torch.equal(k_cache_select, k_cache_ref)
                    else:
                        # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
                        #     breakpoint()
                        if dtype is not torch.float8_e4m3fn:
                            assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
                        else:
                            assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1)
                mult = 4 if dtype == torch.float8_e4m3fn else 2
                assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
                mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5
                assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item()


def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref):
    num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3
    k_cache_paged = torch.randn(
        num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref
    ).to(dtype).to(dtype_ref)
    v_cache_paged = torch.randn(
        num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref
    ).to(dtype).to(dtype_ref)
    page_table = rearrange(
        torch.randperm(num_blocks, dtype=torch.int32, device=device),
        "(b nblocks) -> b nblocks",
        b=batch_size,
    )
    k_cache = rearrange(
        k_cache_paged[page_table.flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    v_cache = rearrange(
        v_cache_paged[page_table.flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (64, 8192),
    ],
)
def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype):
    device = "cuda"
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 16
    nheads_kv = 4
    # There was a bug where this would cause "unspecified launch failure" due to Cluster
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype)
    for _ in range(100):
        flash_attn_func(q, k, v, causal=causal)


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [80])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (239, 1),
        (3, 799),
        (799, 3),
        (1024, 128),
        (97, 97),
        (128, 128),
        (200, 200),
        (256, 256),
        (257, 257),
        (384, 384),
        (512, 512),
        (768, 768),
        (1024, 1024),
        (2048, 2048),
    ],
)
@pytest.mark.skip(reason="Cannot be run in parallel with other tests due to memory usage")
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    # Simulate running under memory pressure (allocates 70 GiB)
    dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device)
    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    torch.random.manual_seed(42)
    out0 = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out0)
    dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
    # Numerical error if we just do any arithmetic on dq
    dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(1000):
        torch.random.manual_seed(42)
        out = flash_attn_func(q, k, v, causal=causal)
        assert torch.equal(out, out0)
        # assert torch.equal(lse, lse0)

        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
        dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
        if not dq_equal:
            print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            # breakpoint()
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert dq_equal


def attention_combine_ref(out_partial, lse_partial):
    """
    out_partial: (num_splits, batch_size, seqlen, nheads, d)
    lse_partial: (num_splits, batch_size, seqlen, nheads)
    """
    lse = torch.logsumexp(lse_partial, dim=0)
    scale = torch.exp(lse_partial - lse)
    scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale)
    out = (scale.unsqueeze(-1) * out_partial).sum(0)
    return out, lse


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float32])
# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024])
# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192])
# @pytest.mark.parametrize("seqlen", [15])
@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133])
# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11])
# @pytest.mark.parametrize("num_splits", [128])
def test_flash_attn_combine(num_splits, seqlen, d, dtype):
    if DISABLE_SPLIT:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(1)
    batch_size = 5
    nheads = 16
    # batch_size = 1
    # nheads = 1
    out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits]  # To test non-contiguous tensor
    lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads]  # To test non-contiguous tensor
    # To test short-circuiting based on num_splits
    lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf")
    out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype)
    out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial)
    out_pt = out_ref.to(dtype)

    print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
    print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}")
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    # breakpoint()

    assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5)
    multiple = 2
    assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5)

    # from flash_attn.utils.benchmark import pytorch_profiler
    # # pytorch_profiler(torch.sum, lse_partial)
    # pytorch_profiler(flash_attn_combine, out_partial, lse_partial)
    # pytorch_profiler(torch.sum, out_partial)


@pytest.mark.skip(reason="AMD Triton backend doesn't use torch ops registration")
def test_flash3_bw_compatibility() -> None:
    # Let's try to always stay backward compatible! This will make life easier
    # for downstream libraries, users, and exported models.
    # 1/ Instead of removing arguments, error out if their value is no longer supported
    # 2/ When adding arguments, add them at the end with a default value
    assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema(
        "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, "
        "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, "
        "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, "
        "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, "
        "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, "
        "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, "
        "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, "
        "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, "
        "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, "
        "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) "
        "-> (Tensor(out!), Tensor, Tensor, Tensor)"
    ))
    assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema(
        "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, "
        "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, "
        "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, "
        "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, "
        "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) "
        "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)"
    ))
    assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema(
        "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, "
        "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)"
    ))
    assert torch.ops.flash_attn_3.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema(
        "flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, "
        "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, "
        "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, "
        "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, "
        "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, "
        "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, "
        "int sm_margin=0) -> Tensor"
    ))

================================================
FILE: hopper/test_kvcache.py
================================================
import torch
#from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache
import flash_attn_interface as fa3
import flash_attn as fa2
import torch.utils.benchmark as benchmark
import time

import argparse
import math

parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--causal', action='store_true')
parser.add_argument('--splits', type=int, default=1)
parser.add_argument('--repeats', type=int, default=10)
parser.add_argument('--validate', action='store_true')
parser.add_argument('--gqa', action='store_true')

args = parser.parse_args()


def benchmark_fa_kv_old(fn, repeats=10, desc='', verbose=True, **kwinputs):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, '- Forward pass')
    t = benchmark.Timer(
        stmt='fn(**kwinputs)',
        globals={'fn': fn, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(desc, m)
    return t, m


def benchmark_fa_kv(fn, repeats=10, *args, **kwargs):
    # warmup
    for _ in range(5):
        fn(*args, **kwargs)
    niters = repeats
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(niters):
        fn(*args, **kwargs)
    torch.cuda.synchronize()
    end = time.time()
    return (end - start) / niters


def main():
    # *SAMPLE CONFIG*
    # Model arch params:
    nheads_q = 64
    nheads_kv = 8
    headdim = 128
    #dtype = torch.bfloat16
    dtype = torch.float16

    # Cache settings:
    num_caches = 8
    cache_seqlen = 1024 * 16

    # Batching settings
    ntokens = 1024
    max_queries_per_batch = 4
    small_request_ntokens = 16

    # Input settings
    query_seqlens = [900, 12, 1]
    num_queries = len(query_seqlens)
    # Need to add empty queries to fill out `max_queries_per_batch`
    num_padding_queries = max_queries_per_batch - num_queries
    context_seqlens = [4096, 5120*2, 6145*2]
    #context_seqlens = [4096, 5120*2, 6152*2]

    # Validation
    assert sum(query_seqlens) <= ntokens
    assert all(s < small_request_ntokens for s in query_seqlens[1:])
    assert num_queries <= max_queries_per_batch
    assert all(s < cache_seqlen for s in context_seqlens)

    torch.manual_seed(5434)

    # Allocate some tensors
    k_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
    )
    v_cache = torch.randn(
        (num_caches, cache_seqlen, nheads_kv, headdim), device="cuda", dtype=dtype
    )

    q_buf_large = torch.randn(
        (1, ntokens, nheads_q, headdim), device="cuda", dtype=dtype
    )
    cache_seqlen_large = torch.tensor(
        [context_seqlens[0]], dtype=torch.int32, device="cuda"
    )
    cache_idx_large = torch.tensor([1], dtype=torch.int32, device="cuda")

    q_buf_small = torch.randn(
        (max_queries_per_batch - 1, small_request_ntokens, nheads_q, headdim),
        device="cuda",
        dtype=dtype,
    )
    cache_seqlens_small = torch.tensor(
        context_seqlens[1:] + [0] * num_padding_queries, dtype=torch.int32, device="cuda"
    )
    cache_idxs_small = torch.randperm(num_caches, dtype=torch.int32, device="cuda")[
        : max_queries_per_batch - 1
    ]

    if args.validate:
        # Call flash attn
        # First for the single full-sized query
        out0, lse0 = fa3.flash_attn_with_kvcache(
            q=q_buf_large,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlen_large,
            cache_batch_idx=cache_idx_large,
            causal=bool(args.causal),
            num_splits=args.splits,
            return_softmax_lse=True,
            #num_splits=1
        )

        # Second for n-1 small queries
        out1_split1, lse1_split1 = fa3.flash_attn_with_kvcache(
            q=q_buf_small,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens_small,
            cache_batch_idx=cache_idxs_small,
            causal=bool(args.causal),
            num_splits=1,
            gqa_decoding=bool(args.gqa),
            return_softmax_lse=True,
        )

        # Same n-1 small queries, now with num_splits=args.splits
        out1, lse1 = fa3.flash_attn_with_kvcache(
            q=q_buf_small,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens_small,
            cache_batch_idx=cache_idxs_small,
            causal=bool(args.causal),
            num_splits=args.splits,
            gqa_decoding=bool(args.gqa),
            return_softmax_lse=True,
        )

        # Call flash attn
        # First for the single full-sized query
        out2 = fa2.flash_attn_with_kvcache(
            q=q_buf_large,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlen_large,
            cache_batch_idx=cache_idx_large,
            causal=bool(args.causal),
            num_splits=args.splits,
        )

        print('big')
        print('diff-max', (out0 - out2).abs().max().item(), cache_seqlens_small)
        print('diff-mean', (out0 - out2).abs().mean().item())

        # Second for n-1 small queries
        out3, lse_fa2 = fa2.flash_attn_with_kvcache(
            q=q_buf_small,
            k_cache=k_cache,
            v_cache=v_cache,
            cache_seqlens=cache_seqlens_small,
            cache_batch_idx=cache_idxs_small,
            causal=bool(args.causal),
            num_splits=args.splits,
            return_softmax_lse=True,
            #num_splits=1
        )

        print('small')  #, out1)
        print('lse', lse1, lse_fa2, (lse1 - lse_fa2).abs(), out1.shape)
        print('lse-dif-max', (lse1 - lse_fa2).abs().max().item())
        print('diff-max', (out1 - out3).abs().max().item())
        print('diff-mean', (out1 - out3).abs().mean().item())

    print('fa3', args.repeats)
    time_fa3_big = benchmark_fa_kv(
        fa3.flash_attn_with_kvcache,
        repeats=args.repeats,
        q=q_buf_large,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlen_large,
        cache_batch_idx=cache_idx_large,
        causal=bool(args.causal),
        num_splits=args.splits,
    )

    time_fa3_small = benchmark_fa_kv(
        fa3.flash_attn_with_kvcache,
        repeats=args.repeats,
        q=q_buf_small,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlens_small,
        cache_batch_idx=cache_idxs_small,
        causal=bool(args.causal),
        num_splits=args.splits,
    )

    print('fa2 ')

    time_fa2_big = benchmark_fa_kv(
        fa2.flash_attn_with_kvcache,
        repeats=args.repeats,
        q=q_buf_large,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlen_large,
        cache_batch_idx=cache_idx_large,
        causal=bool(args.causal),
        num_splits=args.splits
    )

    time_fa2_small = benchmark_fa_kv(
        fa2.flash_attn_with_kvcache,
        repeats=args.repeats,
        q=q_buf_small,
        k_cache=k_cache,
        v_cache=v_cache,
        cache_seqlens=cache_seqlens_small,
        cache_batch_idx=cache_idxs_small,
        causal=bool(args.causal),
        num_splits=args.splits
    )

    print('big (split, fa3, fa2, ratio):', args.splits, time_fa3_big * 1000000, time_fa2_big * 1000000, time_fa3_big / time_fa2_big)
    print('small (split, fa3, fa2, ratio):', args.splits, time_fa3_small * 1000000, time_fa2_small * 1000000, time_fa3_small / time_fa2_small)


if __name__ == "__main__":
    main()

================================================
FILE: hopper/test_torch_compile_and_export.py
================================================
import torch
from flash_attn_interface import flash_attn_func
from torch import nn


class EfficienctMultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=True):
        super().__init__()
        assert embed_size % num_heads == 0, f"{embed_size=} {num_heads=}"

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        self.use_flash_attn = use_flash_attn and (flash_attn_func is not None)

        self.qkv_proj = nn.Linear(embed_size, 3 * embed_size)
        self.out_proj = nn.Linear(embed_size, embed_size)
        self.dropout = dropout

    def forward(self, x, attention_mask=None):
        N, seq_length, _ = x.shape

        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(N, seq_length, self.num_heads, self.head_dim)
        k = k.view(N, seq_length, self.num_heads, self.head_dim)
        v = v.view(N, seq_length, self.num_heads, self.head_dim)

        if self.use_flash_attn and attention_mask is None:
            out = flash_attn_func(
                q, k, v
            )
            out = out.reshape(N, seq_length, self.embed_size)

        out = self.out_proj(out)

        return out


def create_model(batch_size=16, sequence_length=256, embedding_dim=2048, num_heads=16):
    model = EfficienctMultiHeadAttention(embedding_dim, num_heads).cuda().bfloat16()
    input_tensor = torch.randn(batch_size, sequence_length, embedding_dim).cuda().bfloat16()
    return model, input_tensor


def test_export_model():
    model, input_tensor = create_model()
    expected = torch.compile(model, backend="aot_eager")(input_tensor)
    loss = expected.sum()
    loss.backward()

    ep = torch.export.export(model, (input_tensor,))
    got = ep.module()(input_tensor,)
    assert torch.equal(expected, got)

    loss_2 = got.sum()
    loss_2.backward()

    assert torch.equal(loss, loss_2)


def test_compile_and_package_model():
    model, input_tensor = create_model()
    expected = torch.compile(model, backend="aot_eager")(input_tensor)

    exported = torch.export.export(model, (input_tensor,))
    torch._inductor.aoti_compile_and_package(
        exported,
        package_path="model.pt2",
    )

    compiled_model = torch._inductor.package.load_package("model.pt2")
    out = compiled_model(input_tensor,)
    assert torch.equal(expected, out)

================================================
FILE: hopper/test_util.py
================================================
import math

import torch
from einops import rearrange, repeat

from padding import pad_input, unpad_input


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)

    if zero_lengths:
        # Generate zero-lengths every 5 batches and the last batch.
        for i in range(batch_size):
            if i % 5 == 0:
                lengths[i] = 0
        lengths[-1] = 0
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask


def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    qv=None,
    kvpacked=False,
    qkvpacked=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d_v)
        query_padding_mask: (batch_size, seqlen_q), bool
        key_padding_mask: (batch_size, seqlen_k), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    d_v = v.shape[-1]
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d_v)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q, query_padding_mask, query_unused_mask
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
        qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, *rest = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            qv_unpad.detach() if qv is not None else None,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            qv.detach() if qv is not None else None,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    sink_token_length=0,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k if key_padding_mask is None else
        rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q if query_padding_mask is None else
        rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),
        )


def construct_chunk_mask(
    seqlen_q,
    seqlen_k,
    attention_chunk,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k if key_padding_mask is None else
rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") ) sq = ( seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") ) sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk # Subtract remainder instead of divide and then multiply to take care of negative values col_limit_left_chunk = row_idx + sk - sq - (row_idx + sk - sq) % attention_chunk return torch.logical_or( col_idx < col_limit_left_chunk, col_idx >= col_limit_left_chunk + attention_chunk ) def attention_ref( q, k, v, query_padding_mask=None, key_padding_mask=None, key_leftpad=None, attn_bias=None, dropout_p=0.0, dropout_mask=None, causal=False, qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), # -1 means infinite window size attention_chunk=0, sink_token_length=0, softcap=0.0, upcast=True, reorder_ops=False, intermediate_dtype=None, ): """ Arguments: q: (batch_size, seqlen_q, nheads, head_dim) k: (batch_size, seqlen_k, nheads, head_dim) v: (batch_size, seqlen_k, nheads, head_dim_v) qv: (batch_size, seqlen_q, nheads, head_dim_v) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) dropout_p: float dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) causal: whether to apply causal masking upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16. reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) without changing the math. This is to estimate the numerical error from operation reordering.
Output: output: (batch_size, seqlen_q, nheads, head_dim_v) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax before dropout """ if causal: window_size = (window_size[0], 0) dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() qv = qv.float() if qv is not None else None if q_descale is not None: q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) q = (q.float() * q_descale).to(q.dtype) qv = (qv.float() * q_descale).to(qv.dtype) if qv is not None else None if k_descale is not None: k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) if v_descale is not None: v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) seqlen_q, seqlen_k = q.shape[1], k.shape[1] k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) d = q.shape[-1] dv = v.shape[-1] softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) if not reorder_ops: scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) else: scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) if qv is not None: scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) if softcap > 0: scores = torch.tanh(scores / softcap) * softcap if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) local_mask = None if window_size[0] >= 0 or window_size[1] >= 0: local_mask = construct_local_mask( seqlen_q, seqlen_k, window_size, sink_token_length, query_padding_mask, key_padding_mask, key_leftpad=key_leftpad, device=q.device, ) if attention_chunk > 0: chunk_mask = construct_chunk_mask( seqlen_q, seqlen_k, attention_chunk, query_padding_mask, key_padding_mask, key_leftpad=key_leftpad, device=q.device, ) local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask if local_mask is not None: scores.masked_fill_(local_mask, float("-inf"))
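    # Illustrative worked example (not part of the upstream source; values are hypothetical):
    # with seqlen_q = 2, seqlen_k = 4, no padding masks and causal=True, window_size becomes (-1, 0),
    # so construct_local_mask masks col_idx > row_idx + seqlen_k - seqlen_q = row_idx + 2.
    # Query 0 attends to keys 0-2 and query 1 attends to keys 0-3, i.e. the causal diagonal
    # is aligned to the bottom-right corner of the (seqlen_q, seqlen_k) score matrix.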
if attn_bias is not None: scores = scores + attn_bias attention = torch.softmax(scores, dim=-1).to(v.dtype) # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) # Without this we might get NaN in dv if key_padding_mask is not None: attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) # Some rows might be completely masked out so we fill them with zero instead of NaN if local_mask is not None: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: attention_drop = attention if intermediate_dtype is not None: attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) ================================================ FILE: hopper/tile_scheduler.hpp ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
******************************************************************************/ #pragma once #include "cutlass/fast_math.h" #include "cutlass/arch/barrier.h" #include "named_barrier.hpp" #include "utils.h" namespace flash { /////////////////////////////////////////////////////////////////////////////// // Host side kernel arguments struct TileSchedulerArguments { // num_head is num_head_q if not PackGQA, else num_head_k int const num_blocks, num_head, num_batch, num_splits; int const qhead_per_khead; int const seqlen; // Only used if Varlen and cu_seqlens == nullptr and seqused == nullptr int const seqlen_k, headdim, headdim_v, element_size; // Used to calculate L2 swizzling int* const tile_count_semaphore = nullptr; int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; int const* const num_m_blocks_ptr = nullptr; int const* const varlen_batch_idx_ptr = nullptr; // int const* const num_n_blocks_ptr = nullptr; int const* const num_nheads_in_l2_ptr = nullptr; }; /////////////////////////////////////////////////////////////////////////////// template class SingleTileScheduler { public: using SharedStorage = int; // Device side kernel params struct Params { int const num_blocks, num_head, num_batch, num_splits; int const qhead_per_khead; int const seqlen; cutlass::FastDivmod nsplits_divmod; int const* const cu_seqlens; int const* const seqused; int const* const num_splits_dynamic_ptr = nullptr; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { assert(!Split || !Varlen || args.num_splits_dynamic_ptr != nullptr); assert(!Split || !Varlen || args.num_splits < (1 << 16)); // We use the top 16 bits to store num_splits return {args.num_blocks, args.num_head, args.num_batch, !Split ? 1 : args.num_splits, args.qhead_per_khead, args.seqlen, cutlass::FastDivmod(!Split ? 1 : args.num_splits), !Varlen ? nullptr : args.cu_seqlens, !Varlen ? 
nullptr : args.seqused, args.num_splits_dynamic_ptr}; } static dim3 get_grid_shape(Params const& params, int num_sm) { return {uint32_t(params.num_blocks), uint32_t((!Split ? 1 : params.num_splits) * params.num_head), uint32_t(params.num_batch)}; } struct WorkTileInfo { int block_idx = 0; int bidh = 0; int bidb = 0; int split_idx = 0; CUTLASS_DEVICE bool is_valid(Params const& params) const { return bidb >= 0; } CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { return {block_idx, bidh, bidb, !Split ? 0 : split_idx}; } }; CUTLASS_DEVICE SingleTileScheduler(SharedStorage* const smem_scheduler) { } template CUTLASS_DEVICE WorkTileInfo get_initial_work(Params const& params) const { WorkTileInfo work_info {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), 0}; if constexpr (Split) { int split_idx; work_info.bidh = params.nsplits_divmod.divmod(split_idx, work_info.bidh); work_info.split_idx = split_idx; } bool is_valid_tile = true; if constexpr (Varlen) { int seqlen = params.seqused ? params.seqused[work_info.bidb] : (params.cu_seqlens ? params.cu_seqlens[work_info.bidb + 1] - params.cu_seqlens[work_info.bidb] : params.seqlen); if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } is_valid_tile = work_info.block_idx * kBlock < seqlen; } if constexpr (Varlen && Split) { int num_splits_dynamic = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[work_info.bidb] : params.num_splits; is_valid_tile &= work_info.split_idx < num_splits_dynamic; // Use the top 16 bits to store num_splits work_info.split_idx |= (num_splits_dynamic << 16); } work_info.bidb = is_valid_tile ? 
work_info.bidb : -1; return work_info; } CUTLASS_DEVICE void init_consumer() const {} CUTLASS_DEVICE void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} template CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { return {0, 0, -1, 0}; } }; /////////////////////////////////////////////////////////////////////////////// template class StaticPersistentTileScheduler { public: using SharedStorage = int; // Device side kernel params struct Params { int total_blocks; cutlass::FastDivmod m_block_divmod, head_divmod; cutlass::FastDivmod nsplits_divmod; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { return {args.num_blocks * args.num_head * args.num_batch * (!Split ? 1 : args.num_splits), cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head * (!Split ? 1 : args.num_splits)), cutlass::FastDivmod(!Split ? 1 : args.num_splits)}; } static dim3 get_grid_shape(Params const& params, int num_sm) { return {uint32_t(num_sm)}; } struct WorkTileInfo { int tile_idx; CUTLASS_DEVICE bool is_valid(Params const& params) const { return tile_idx < params.total_blocks; } CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { int block, bidh, bidb; bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(block, tile_idx)); int split_idx = 0; if constexpr (Split) { bidh = params.nsplits_divmod.divmod(split_idx, bidh); } return {block, bidh, bidb, split_idx}; } }; CUTLASS_DEVICE StaticPersistentTileScheduler(SharedStorage* const smem_scheduler) {}; template CUTLASS_DEVICE WorkTileInfo get_initial_work(Params const& params) const { return {int(blockIdx.x)}; } CUTLASS_DEVICE void init_consumer() const {} CUTLASS_DEVICE void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} template CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { return 
{current_work.tile_idx + int(gridDim.x)}; } }; /////////////////////////////////////////////////////////////////////////////// template class DynamicPersistentTileScheduler { // This scheduler targets the causal (or local) case where each tile takes a different // amount of time. We use longest-processing-time-first scheduling: // the longest remaining tile is assigned to the first SM that's free. // SMs indicate they are free by incrementing a semaphore. // However, we have to make sure K & V still fit into L2 cache, so we perform scheduling // on "sections" of the head & batch dimension, each section consisting of e.g. 8 heads. // This is the L2 swizzling part. The size of each section is precomputed based on the // size of K & V and the L2 cache size. static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads; public: using SharedStorage = int; protected: SharedStorage* const tile_count_smem; public: // Device side kernel params struct Params { int const total_blocks; cutlass::FastDivmod const m_block_divmod, head_divmod; cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; cutlass::FastDivmod const l2_minor_residual_divmod; int const num_hb_quotient; int* const tile_count_semaphore; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { long long const size_one_kv_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size); int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V // Swizzle is the size of each "section". Round swizzle to a power of 2 // If not PackGQA already, the size of each section can increase by qhead_per_khead // Need to be careful about the case where only one head will fit auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; // Seems faster if swizzle is a power of 2 int const swizzle = (size_l2 < size_one_kv_head ?
1 : (1 << find_log2_floor(size_l2 / size_one_kv_head))) * (PackGQA ? 1 : args.qhead_per_khead); // If we're in the last section (called residual), we don't want to divide by // swizzle. Instead we want to divide by the remainder. int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; int const num_split_blocks = args.num_blocks * (!Split ? 1 : args.num_splits); // printf("num_split_blocks = %d, num_head = %d, num_batch = %d, swizzle = %d, PackGQA = %d, qhead_per_khead = %d, num_hb_remainder = %d\n", num_split_blocks, args.num_head, args.num_batch, swizzle, int(PackGQA), args.qhead_per_khead, num_hb_remainder); assert(args.tile_count_semaphore != nullptr); return {num_split_blocks * args.num_head * args.num_batch, cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * num_split_blocks), // don't divide by 0 cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), (args.num_head * args.num_batch) / swizzle, args.tile_count_semaphore}; } static dim3 get_grid_shape(Params const& params, int num_sm) { return {uint32_t(num_sm)}; } struct WorkTileInfo { int tile_idx; CUTLASS_DEVICE bool is_valid(Params const& params) const { return tile_idx < params.total_blocks; } CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { int block, bidh, bidb; int l2_mod, bidhb, bidhb_residual; bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); // If we're in the last section (called residual), we don't want to divide by // swizzle. Instead we want to divide by the remainder. 
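        // Illustrative worked example (not part of the upstream source; values are hypothetical):
        // num_blocks = 4 (no Split), num_head = 20, num_batch = 1, and swizzle = 8 (e.g. 4 MiB per
        // KV head against the 32 MB budget with qhead_per_khead = 1). Then num_hb_quotient = 2,
        // num_hb_remainder = 4, and l2_major_divmod divides by swizzle * num_split_blocks = 32.
        // tile_idx = 70 -> bidhb = 2, l2_mod = 6. Since bidhb == num_hb_quotient this is the residual
        // section: block = 6 / 4 = 1, bidhb_residual = 6 % 4 = 2, head-batch index = 2 * 8 + 2 = 18
        // (bidb = 0, bidh = 18), and the LPT reversal gives block = 4 - 1 - 1 = 2.
        // Dividing by swizzle instead would give bidhb_residual = 6 and head-batch index 22 >= 20.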
if (bidhb < params.num_hb_quotient) { block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); } else { block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); } bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); int split_idx = 0; if constexpr (Split) { split_idx = params.m_block_divmod.divmod(block, block); } // Longest-processing-time-first block = params.m_block_divmod.divisor - 1 - block; return {block, bidh, bidb, split_idx}; } }; CUTLASS_DEVICE DynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : tile_count_smem(smem_scheduler) {}; template CUTLASS_DEVICE WorkTileInfo get_initial_work(Params const& params) const { return {int(blockIdx.x)}; } CUTLASS_DEVICE void init_consumer() const { if (WarpSpecialized || cutlass::canonical_warp_idx_sync() > 0) { flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty } } CUTLASS_DEVICE void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { if (threadIdx.x % NumProducerThreads == 0) { current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); } } template CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { if constexpr (IsProducerWarp) { // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty if (threadIdx.x % NumProducerThreads == 0) { *tile_count_smem = current_work.tile_idx; } flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return {new_tile_idx}; } else { flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // 
TileCountSmemFull int tile_idx = *tile_count_smem; flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty return {tile_idx}; } } }; /////////////////////////////////////////////////////////////////////////////// template class SingleTileBwdLPTScheduler { public: using SharedStorage = int; // Device side kernel params struct Params { int const total_blocks; cutlass::FastDivmod const block_divmod, head_divmod; cutlass::FastDivmod const l2_minor_divmod, l2_major_divmod; cutlass::FastDivmod const l2_minor_residual_divmod; int const num_hb_quotient; int const seqlen; int const* const cu_seqlens; int const* const seqused; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { // Since it's the bwd pass, seqlen_k gets passed to args.seqlen and seqlen_q is passed to args.seqlen_k long long const size_one_qdo_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size); long long const size_one_dqaccum_head = long(args.seqlen_k) * long(args.headdim) * sizeof(float); long long const size_one_head = size_one_qdo_head + size_one_dqaccum_head; int const size_l2 = 40 * 1024 * 1024; // 40 MB for Q, dO, and dQaccum // Swizzle is the size of each "section". Round swizzle to a power of 2 // Need to be careful about the case where only one head will fit auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; // Seems faster if swizzle is a power of 2 int const swizzle = size_l2 < size_one_head ? 1 : (1 << find_log2_floor(size_l2 / size_one_head)); // If we're in the last section (called residual), we don't want to divide by // swizzle. Instead we want to divide by the remainder.
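    // Illustrative worked example (not part of the upstream source; values are hypothetical):
    // seqlen_q = 8192 (passed as args.seqlen_k), headdim = headdim_v = 128, element_size = 2.
    // size_one_qdo_head = 8192 * 256 * 2 B = 4 MiB and size_one_dqaccum_head = 8192 * 128 * 4 B = 4 MiB,
    // so size_one_head = 8 MiB and size_l2 / size_one_head = 40 MiB / 8 MiB = 5, which rounds
    // down to swizzle = 1 << find_log2_floor(5) = 4 heads per section.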
int const num_hb_remainder = (args.num_head * args.num_batch) % swizzle; // printf("num_blocks = %d, num_head = %d, num_batch = %d, size_one_head = %d, ratio = %d, swizzle = %d, num_hb_remainder = %d\n", args.num_blocks, args.num_head, args.num_batch, size_one_head, size_l2 / size_one_head, swizzle, num_hb_remainder); assert(args.tile_count_semaphore != nullptr); return {args.num_blocks * args.num_head * args.num_batch, cutlass::FastDivmod(args.num_blocks), cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(swizzle), cutlass::FastDivmod(swizzle * args.num_blocks), // don't divide by 0 cutlass::FastDivmod(num_hb_remainder > 0 ? num_hb_remainder : 1), (args.num_head * args.num_batch) / swizzle, args.seqlen, !Varlen ? nullptr : args.cu_seqlens, !Varlen ? nullptr : args.seqused}; } static dim3 get_grid_shape(Params const& params, int num_sm) { return {uint32_t(params.total_blocks)}; } struct WorkTileInfo { int block; int bidh; int bidb; CUTLASS_DEVICE bool is_valid(Params const& params) const { return bidb >= 0; } CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { return {block, bidh, bidb, 0 /*split_idx*/}; } }; CUTLASS_DEVICE SingleTileBwdLPTScheduler(SharedStorage* const smem_scheduler) { } template CUTLASS_DEVICE WorkTileInfo get_initial_work(Params const& params) const { int tile_idx = blockIdx.x; int block, bidh, bidb; int l2_mod, bidhb, bidhb_residual; bidhb = params.l2_major_divmod.divmod(l2_mod, tile_idx); // If we're in the last section (called residual), we don't want to divide by // swizzle. Instead we want to divide by the remainder. if (bidhb < params.num_hb_quotient) { block = params.l2_minor_divmod.divmod(bidhb_residual, l2_mod); } else { block = params.l2_minor_residual_divmod.divmod(bidhb_residual, l2_mod); } bidb = params.head_divmod.divmod(bidh, bidhb * params.l2_minor_divmod.divisor + bidhb_residual); bool is_valid_tile = true; int num_blocks; if constexpr (Varlen) { int seqlen = params.seqused ? 
params.seqused[bidb] : (params.cu_seqlens ? params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb] : params.seqlen); num_blocks = cute::ceil_div(seqlen, Int{}); is_valid_tile = block < num_blocks; } else { num_blocks = params.block_divmod.divisor; } if constexpr (SPT) { block = num_blocks - block - 1; } return {block, bidh, is_valid_tile ? bidb : -1}; } CUTLASS_DEVICE void init_consumer() const {} CUTLASS_DEVICE void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} template CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { return {0, 0, -1}; } }; /////////////////////////////////////////////////////////////////////////////// template class VarlenDynamicPersistentTileScheduler { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); static constexpr int NumThreads = WarpSpecialized ? NumMmaThreads + NumProducerThreads : NumMmaThreads; public: using SharedStorage = int4; protected: SharedStorage* const work_info_smem; public: // Device side kernel params struct Params { int num_head, num_batch; int const qhead_per_khead; int const seqlen; // int const max_kvblocks_in_l2; cutlass::FastDivmod head_divmod; cutlass::FastDivmod nsplits_divmod; int* const tile_count_semaphore; int const* const cu_seqlens; int const* const seqused; int const* const num_splits_dynamic_ptr; int const* const num_m_blocks_ptr; int const* const varlen_batch_idx_ptr; // int const* const num_n_blocks_ptr; int const* const num_nheads_in_l2_ptr; }; static Params to_underlying_arguments(TileSchedulerArguments const& args) { // If Split, for the purpose of scheduling, we pretend that instead there are // (args.num_splits * args.num_head) number of heads. 
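        // Illustrative worked example of the Split bit packing (not part of the upstream source;
        // values are hypothetical): bidh_actual = 5, split_idx = 3, num_splits = 4 packs to
        // 5 + (3 << 16) + (4 << 24) = 67305477 (0x04030005). get_block_coord recovers
        // bidh_actual = packed & 0xFFFF = 5 and split_idx = 3 + (4 << 16), i.e. split_idx in the
        // low 16 bits and num_splits in the high 16 bits. Hence num_head must stay below 1 << 16
        // and num_splits below 1 << 8.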
assert(args.tile_count_semaphore != nullptr); assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits // int const size_l2 = 50 * 1024 * 1024; // 50 MB // int const size_one_kvblock = kBlockN * (args.headdim + args.headdim_v) * args.element_size; // int max_kvblocks_in_l2 = size_l2 / size_one_kvblock; return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, // max_kvblocks_in_l2, cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore, args.cu_seqlens, args.seqused, args.num_splits_dynamic_ptr, args.num_m_blocks_ptr, args.varlen_batch_idx_ptr, // args.num_n_blocks_ptr, args.num_nheads_in_l2_ptr}; } static dim3 get_grid_shape(Params const& params, int num_sm) { return {uint32_t(num_sm)}; } struct WorkTileInfo { int tile_idx, block, bidh, bidb; CUTLASS_DEVICE bool is_valid(Params const& params) const { // if (blockIdx.x >= 0 && (threadIdx.x == 128 || threadIdx.x == 0)) { printf("blockIdx.x = %d, threadIdx.x = %d, checking valid, bidb = %d, params.num_batch = %d\n", blockIdx.x, threadIdx.x, bidb, params.num_batch); } return bidb < params.num_batch; } CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { auto get_actual_batch = [&](int virtual_batch) { if constexpr(Prepared && Sort) { return params.varlen_batch_idx_ptr[virtual_batch]; } else { return virtual_batch; } }; if constexpr (!Split) { return {block, bidh, get_actual_batch(bidb), 0 /*split_idx*/}; } else { // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift uint32_t bidh_packed = reinterpret_cast(bidh); uint32_t bidh_actual_u = bidh_packed & 0x0000FFFF; int bidh_actual = reinterpret_cast(bidh_actual_u); // Use the top 16 bits of split_idx to store num_splits and the next 16 bits
to store split_idx uint32_t split_idx_u = ((bidh_packed & 0x00FF0000) >> 16) + ((bidh_packed & 0xFF000000) >> 8); int split_idx = reinterpret_cast(split_idx_u); // int bidh_actual = params.nsplits_divmod.divmod(split_idx, bidh); // if (threadIdx.x == 128) { // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); // } return {block, bidh_actual, get_actual_batch(bidb), split_idx}; } } }; CUTLASS_DEVICE VarlenDynamicPersistentTileScheduler(SharedStorage* const smem_scheduler) : work_info_smem(smem_scheduler) {}; CUTLASS_DEVICE WorkTileInfo tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkTileInfo const& current_work) const { int lane = threadIdx.x % cutlass::NumThreadsPerWarp; auto get_num_m_blocks = [&] (int bidb_start) { int batch_idx = lane + bidb_start; if constexpr (Prepared) { return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 ? params.num_m_blocks_ptr[batch_idx] : 0; } else { int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); if (seqlen > kBlockM) { if (params.seqused) { seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; } else if (params.cu_seqlens) { int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); seqlen = next_cu_seqlen - cur_cu_seqlen; } else { seqlen = params.seqlen; } if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } } return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 ? cute::ceil_div(seqlen, kBlockM) : 0; // ? params.num_m_blocks_ptr[batch_idx] : 0; } }; auto get_num_splits = [&] (int bidb_start) { int batch_idx = lane + bidb_start; bool is_valid = batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1; if constexpr (!Split) { return is_valid ? 1 : 0; } else if constexpr(Prepared) { return is_valid ? 
params.num_splits_dynamic_ptr[batch_idx] : 0; } else { return is_valid ? params.nsplits_divmod.divisor : 0; } }; int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane int num_splits = get_num_splits(current_work.bidb); int num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; // Cumulative number of blocks for the next 31 batches int num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); // Total number of blocks for the next 31 batches int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); // Only the lower 16 bits are the actual bidh // int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); // int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes // if constexpr (Split) { // int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; // group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); // } // NEW: current_work.tile_idx holds group_start_tile for starting batch int group_end_tile = current_work.tile_idx + m_blocks_in_group * params.num_head; // Same for all lanes int bidb = current_work.bidb; // if (blockIdx.x <= 9 && threadIdx.x == 0) { // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); // } // if (threadIdx.x == 0 && blockIdx.x == 0) { printf("tile_idx = %d, group_end_tile = %d, num_m_blocks_cumulative = %d, m_blocks_in_group = %d\n", current_work.tile_idx, group_end_tile, 
num_m_blocks_cumulative, m_blocks_in_group); } while (group_end_tile <= next_tile_idx) { bidb += cutlass::NumThreadsPerWarp - 1; if (bidb >= params.num_batch) { // if (blockIdx.x <= 9 && threadIdx.x == 0) { // printf("Returning early, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); // } return {next_tile_idx, 0, 0, params.num_batch}; } num_m_blocks = get_num_m_blocks(bidb); num_splits = get_num_splits(bidb); num_split_m_blocks = !Split ? num_m_blocks : num_m_blocks * num_splits; num_m_blocks_cumulative = warp_prefix_sum(num_split_m_blocks); m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); group_end_tile += m_blocks_in_group * params.num_head; // if (blockIdx.x <= 9 && threadIdx.x == 0) { // printf("Bottom of while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group); // } } int group_start_tile = group_end_tile - m_blocks_in_group * params.num_head; // The next problem to process is the first one whose ending tile position // is greater than the tile index.
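        // Illustrative worked example (not part of the upstream source; values are hypothetical):
        // group_start_tile = 0, num_head = 2, no Split, and lanes 0..2 hold num_m_blocks = 3, 1, 2,
        // so the inclusive prefix sums are 3, 4, 6 and the per-batch end tiles are 6, 8, 12.
        // For next_tile_idx = 7 only lane 0 satisfies 6 <= 7, so batch_idx_in_group = 1;
        // group_start_tile then advances by 3 * 2 = 6, leaving mh_block = 1 (bidh = 1, block = 0
        // when LPT is off). Lanes with no batch contribute 0 blocks, so their end tile equals
        // group_end_tile, which the loop above guarantees is greater than next_tile_idx.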
int batch_idx_in_group = __popc(__ballot_sync(0xffffffff, group_start_tile + num_m_blocks_cumulative * params.num_head <= next_tile_idx)); // if (threadIdx.x == 31 || threadIdx.x == 0) { printf("blockIdx.x = %d, tidx %d, group_start_tile = %d, num_m_blocks_cumulative = %d, num_head = %d, next_tile_idx = %d, ballot = %x, batch_idx_in_group = %d\n", blockIdx.x, threadIdx.x, group_start_tile, num_m_blocks_cumulative, params.num_head, next_tile_idx, tmp, batch_idx_in_group); } bidb += batch_idx_in_group; num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } group_start_tile += (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; int mh_block = next_tile_idx - group_start_tile; int block, bidh; if constexpr (LPT) { if (!Split || num_splits == 1) { // NOTE: code for computing nheads_in_l2 directly left as reference // int num_n_blocks = params.num_n_blocks_ptr ? params.num_n_blocks_ptr[bidb] : num_m_blocks; // auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; // int nheads_in_l2 = params.max_kvblocks_in_l2 < num_n_blocks // ? 1 : 1 << find_log2_floor(params.max_kvblocks_in_l2 / num_n_blocks); // if constexpr (!PackGQA) { nheads_in_l2 *= params.qhead_per_khead; } // nheads_in_l2 = min(nheads_in_l2, params.num_head); auto get_nheads_in_l2 = [&](int batch_idx) { if constexpr(Prepared) { return params.num_nheads_in_l2_ptr[batch_idx]; } else { return !PackGQA ? params.qhead_per_khead : 1; } }; int nheads_in_l2 = get_nheads_in_l2(bidb); int mh_in_l2 = nheads_in_l2 * num_m_blocks; int section_idx = mh_block / mh_in_l2; int l2_mod = mh_block - section_idx * mh_in_l2; // tail section int nheads_remainder = params.num_head - section_idx * nheads_in_l2; int nheads_in_this_section = nheads_in_l2 <= nheads_remainder ? 
nheads_in_l2 : nheads_remainder; block = l2_mod / nheads_in_this_section; int bidh_residual = l2_mod - block * nheads_in_this_section; bidh = section_idx * nheads_in_l2 + bidh_residual; if constexpr(Split) { // remember to set num_splits = 1 in work tile uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(num_splits) << 24); bidh = reinterpret_cast(bidh_packed); } } else { // NOTE: heads-first traversal version left for reference // block = params.head_divmod.divmod(bidh, mh_block); // if constexpr (Split) { // int split_idx = block / num_m_blocks; // block = block - split_idx * num_m_blocks; // uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); // bidh = reinterpret_cast(bidh_packed); // } bidh = mh_block / num_m_blocks; block = mh_block - bidh * num_m_blocks; if constexpr (Split) { int bidh_actual = bidh / num_splits; int split_idx = bidh - bidh_actual * num_splits; uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); bidh = reinterpret_cast(bidh_packed); } } block = num_m_blocks - 1 - block; } else { bidh = mh_block / num_m_blocks; block = mh_block - bidh * num_m_blocks; if constexpr (Split) { int bidh_actual = bidh / num_splits; int split_idx = bidh - bidh_actual * num_splits; // TODO: idk why this gives wrong answer nondeterministically // int bidh_actual, split_idx; // split_idx = params.head_divmod.divmod(bidh_actual, bidh); // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); // if (threadIdx.x == 0) { // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d,
split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); // } bidh = reinterpret_cast(bidh_packed); } // if (blockIdx.x <= 9 && threadIdx.x == 0) { // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); // } } return {group_start_tile, block, bidh, bidb}; } template CUTLASS_DEVICE WorkTileInfo get_initial_work(Params const& params) const { if constexpr (IsProducerWarp) { WorkTileInfo work_info = tile_idx_to_work_tile(params, int(blockIdx.x), {0, 0, 0, 0}); if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); } flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return work_info; } else { return get_next_work(params, {0, 0, 0, 0}); } } CUTLASS_DEVICE void init_consumer() const { // Don't arrive at the TileCountSmemEmpty barrier here, because get_initial_work will do that } CUTLASS_DEVICE void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { if (threadIdx.x % NumProducerThreads == 0) { current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); } } template CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params, WorkTileInfo const& current_work) const { if constexpr (IsProducerWarp) { // thread 0 has the next tile_idx, just need to broadcast to the rest of warp 0 int new_tile_idx = __shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/); WorkTileInfo work_info = 
{__shfl_sync(0xffffffff, current_work.tile_idx, 1 /*lane*/), current_work.block, current_work.bidh, current_work.bidb}; work_info = tile_idx_to_work_tile(params, new_tile_idx, work_info); flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { *work_info_smem = make_int4(work_info.tile_idx, work_info.block, work_info.bidh, work_info.bidb); } flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull return work_info; } else { flash::named_barrier_sync(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier1 /*id*/); // TileCountSmemFull int4 work_info = *work_info_smem; flash::named_barrier_arrive(NumThreads, cutlass::arch::ReservedNamedBarriers::StreamkBarrier0 /*id*/); // TileCountSmemEmpty return WorkTileInfo{work_info.x, work_info.y, work_info.z, work_info.w}; } } }; /////////////////////////////////////////////////////////////////////////////// } // flash ================================================ FILE: hopper/tile_size.h ================================================ /****************************************************************************** * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ #pragma once #include // Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} constexpr std::tuple tile_size_fwd_sm90( int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { // return {same_hdim ? 192 : 64, same_hdim ? 
128 : 64, same_hdim, same_hdim};
            // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why
            // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131
            if (headdim_v == 512) {
                return {64, 64, false, false};
            } else if (headdim_v == 256) {
                return {128, 96, true, false};
            } else {
                // Switch to tile size 192 x 192 for now
                bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA;
                return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true};
            }
            // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen
            // return {192, is_causal || is_local ? 192 : 176, true, false};
        } else if (headdim <= 96) {
            return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true};
        } else if (headdim <= 128) {
            bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA;
            return {128, use_blockN_128 ? 128 : 176, true, true};
            // {128, 192, true, false} and {192, 128, false, true} are quite good too
            // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS
        } else if (headdim <= 192) {
            return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true};  // 128 x 112 hits the limit of smem
        } else {
            return {128, is_local ? 64 : 80, true, true};  // 128 x 80 hits the limit of smem
        }
    } else {
        if (headdim <= 64) {
            return {192, 160, true, true};
        } else if (headdim <= 96) {
            return {192, 128, true, true};
        } else if (headdim <= 128) {
            return {128, paged_kv_non_TMA ? 160 : (v_colmajor || (softcap && is_local) ? 192 : 224), true, true};
        } else if (headdim <= 192) {
            return {128, (paged_kv_non_TMA || softcap) && is_local ? 128 : 160, true, true};
        } else {
            return {128, is_local ? 64 : 128, true, !paged_kv_non_TMA};  // PagedKV uses more registers so we disabled IntraWGOverlap
        }
    }
}

// Return {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs}
constexpr std::tuple<int, int, int, int, bool> tile_size_fwd_sm8x(
        bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2,
        bool paged_kv=false, bool varlen_and_split=false,
        bool softcap=false, bool append_kv=false) {
    if (element_size == 2) {
        if (headdim <= 64) {
            return {128, varlen_and_split ? 80 : (is_local ? 96 : 112), 4, 1, false};
        } else if (headdim <= 96) {
            return {128, varlen_and_split || is_local ? 48 : 64, 4, 1, false};
        } else if (headdim <= 128) {
            bool const use_8_warps = sm86_or_89 | varlen_and_split;
            return {128, use_8_warps ? (varlen_and_split ? (is_local ? 96 : 112) : (is_local ? 96 : 128)) : (is_local ? 48 : 64), use_8_warps ? 8 : 4, 1, use_8_warps};
        } else if (headdim <= 192) {
            bool const kBlockN_64 = append_kv || is_local || varlen_and_split || paged_kv;
            return {128, kBlockN_64 ? 64 : 96, 8, sm86_or_89 ? 1 : 2, !kBlockN_64};
        } else {
            return {128, sm86_or_89 ? (append_kv ? 32 : (varlen_and_split || is_local ? 48 : 64)) : (append_kv ? 48 : (varlen_and_split || is_local ? 64 : 96)), 8, 1, sm86_or_89 && !append_kv};
        }
    } else {
        // Placeholder for now
        return {128, 64, 8, 2, false};
    }
}

================================================
FILE: hopper/utils.h
================================================
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
******************************************************************************/ #pragma once #include #include #include #include #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #include #endif #include #include #include #include #include #include "cuda_check.h" namespace flash { using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// // A wrapper for the kernel that is used to guard against compilation on // architectures that will never use the kernel. The purpose of this is to // reduce the size of the compiled binary. // Adapted from https://github.com/vllm-project/vllm/blob/4d29e91be84d27ca313d657eee92c067439a4c23/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh#L55 template struct enable_sm90 : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900) Kernel::operator()(std::forward(args)...); #endif } }; template struct enable_sm80_to_sm89 : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 890) Kernel::operator()(std::forward(args)...); #endif } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct MaxOp { __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

CUTLASS_HOST_DEVICE
int div_floor(cutlass::FastDivmod const& divmod, int dividend) {
    // Take care of the negative case: https://stackoverflow.com/questions/39304681/division-with-negative-dividend-but-rounded-towards-negative-infinity
    // Maybe the compiler will turn the -1 - * into bit negation operation, I haven't checked.
    return dividend >= 0 ?
divmod.divide(dividend) : -1 - divmod.divide(-1 - dividend); } CUTLASS_HOST_DEVICE int round_down(cutlass::FastDivmod const& divmod, int dividend) { return div_floor(divmod, dividend) * divmod.divisor; } CUTLASS_HOST_DEVICE int round_up(cutlass::FastDivmod const& divmod, int dividend) { return div_floor(divmod, dividend - 1) * divmod.divisor + divmod.divisor; } //////////////////////////////////////////////////////////////////////////////////////////////////// // For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) // For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) template CUTLASS_DEVICE auto convert_layout_acc_rowcol(Layout0 acc_layout) { if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 static_assert(decltype(size<0, 0>(acc_layout))::value == 2); static_assert(decltype(size<0, 1>(acc_layout))::value == 2); static_assert(decltype(rank(acc_layout))::value == 3); auto l = acc_layout; if constexpr (!Transposed) { return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); } else { return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); } } else { // SM80 static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) if constexpr (!Transposed) { return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); } else { return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); } } }; //////////////////////////////////////////////////////////////////////////////////////////////////// // For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. 
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) // For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) template CUTLASS_DEVICE auto convert_layout_acc_Aregs(Layout0 acc_layout) { using X = Underscore; if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 static_assert(decltype(size<0, 0>(acc_layout))::value == 2); static_assert(decltype(size<0, 1>(acc_layout))::value == 2); static_assert(decltype(rank(acc_layout))::value == 3); static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); } else { static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) // This combines the first two modes (<0, 0> and <0, 1>) into one mode. // Will require register shuffling later to be correct. return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) // This combination is right but doesn't work with register shuffling. 
// return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), // get<1>(acc_layout), // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); } } else { // SM80 static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); static_assert(mma_shape_K == 8 || mma_shape_K == 16); if constexpr (mma_shape_K == 8) { return acc_layout; } else { auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); } } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE auto convert_type_unsafe(Tensor const &tensor) { using From_type = typename Engine::value_type; static constexpr int numel = decltype(size(tensor))::value; cutlass::NumericArrayConverter convert_op; // HACK: this requires tensor to be "contiguous" auto frag = convert_op(*reinterpret_cast *>(tensor.data())); return make_tensor(make_rmem_ptr(&frag), tensor.layout()); // Unsafe because we're returning a tensor with memory allocated on the stack. If the compiler does not // inline this function, then the memory might not be valid. } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE void convert_type_out(Tensor const &tensor, Tensor &out) { // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong. 
using From_type = typename Engine::value_type; using To_type = typename EngineOut::value_type; static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type)); static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly"); Tensor frag = recast const>(tensor); Tensor out_frg = recast>(out); static_assert(size(frag) == size(out_frg)); cutlass::NumericArrayConverter convert_op; #pragma unroll for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); } } //////////////////////////////////////////////////////////////////////////////////////////////////// // Blocks until all but N previous cp.async.commit_group operations have committed. // This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all // (which is equivalent to commit_group then wait_group 0). // Instead we just call cp.async.wait_group 0, which is slightly faster. // https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 template CUTE_HOST_DEVICE void cp_async_wait() { #if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); #endif } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE auto mma_partition_fragment_AB(Mma const& mma, Tensor0 const& tensor0) { if constexpr (A) { return mma.partition_fragment_A(tensor0); } else { return mma.partition_fragment_B(tensor0); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE void gemm(TiledMma& tiled_mma, Tensor0 const& tCrA, Tensor1 const& tCrB, Tensor2& tCrC) { if constexpr (M_slice >= 0) { static constexpr int MMA_M = decltype(size<1>(tCrC))::value; static_assert(M_slice < MMA_M); // After logical_divide, C has shape ((2,2,V), (MMA_M, 1), MMA_N) Tensor tCrC_slice = cute::logical_divide(tCrC, 
Shape>{})(_, make_coord(Int{}, _), _); if constexpr (!SwapAB) { Tensor tCrA_slice = cute::logical_divide(tCrA, Shape>{})(_, make_coord(Int{}, _), _); gemm(tiled_mma, tCrA_slice, tCrB, tCrC_slice); } else { Tensor tCrB_slice = cute::logical_divide(tCrB, Shape>{})(_, make_coord(Int{}, _), _); gemm(tiled_mma, tCrA, tCrB_slice, tCrC_slice); } } else { constexpr bool Is_RS = !cute::is_base_of::value; // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const if constexpr (Is_RS) { if constexpr (!SwapAB) { warpgroup_fence_operand(const_cast(tCrA)); } else { warpgroup_fence_operand(const_cast(tCrB)); } } warpgroup_fence_operand(tCrC); warpgroup_arrive(); if constexpr (zero_init) { tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } static constexpr int kNumKIters = CUTE_STATIC_V(size<2>(tCrA)); static constexpr int kMaxKIters = 16; // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < std::min(kNumKIters, kMaxKIters); ++k_block) { if constexpr (!SwapAB) { cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); } else { cute::gemm(tiled_mma, tCrB(_,_,k_block), tCrA(_,_,k_block), tCrC); } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } // In the case of large kNumKIters, the compiler chooses to store the smem addresses // in registers, causing spills. This loop forces the compiler to recompute the addresses. if constexpr (kNumKIters > kMaxKIters) { // This will always be zero, just a way to force the compiler to recompute the smem // addresses. This results in USEL instructions. There's probably a better way to do this. int const k_offset = cutlass::canonical_warp_group_idx() < 128 ? 
0 : 1; CUTLASS_PRAGMA_UNROLL for (int k_block = kMaxKIters; k_block < kNumKIters; ++k_block) { if constexpr (!SwapAB) { cute::gemm(tiled_mma, tCrA(_,_,k_block + k_offset), tCrB(_,_,k_block + k_offset), tCrC); } else { cute::gemm(tiled_mma, tCrB(_,_,k_block + k_offset), tCrA(_,_,k_block + k_offset), tCrC); } tiled_mma.accumulate_ = GMMA::ScaleOut::One; } } warpgroup_commit_batch(); if constexpr (wg_wait >= 0) { warpgroup_wait(); } warpgroup_fence_operand(tCrC); if constexpr (Is_RS) { if constexpr (!SwapAB) { warpgroup_fence_operand(const_cast(tCrA)); } else { warpgroup_fence_operand(const_cast(tCrB)); } } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE void gemm_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, TiledMma tiled_mma, TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B, Hook fn) { if constexpr (SwapAB) { gemm_sm80(acc, tCrB, tCrA, tCsB, tCsA, tiled_mma, smem_tiled_copy_B, smem_tiled_copy_A, smem_thr_copy_B, smem_thr_copy_A, fn); } else { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } #pragma unroll for (int i = 0; i < size<2>(tCrA); ++i) { if (i < size<2>(tCrA) - 1) { if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } if (!B_in_regs) { 
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } } if constexpr (!std::is_same_v) { if (i == 0) { fn(); } } cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE void gemm_rs_sm80(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, ThrCopy smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); #pragma unroll for (int i = 0; i < size<2>(tCrA); ++i) { if (i < size<2>(tCrA) - 1) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE void gemm_sm100(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { static constexpr int rA = decltype(rank(tA))::value; static constexpr int rB = decltype(rank(tB))::value; static constexpr int rC = decltype(rank(tC))::value; static_assert(rA == 3 && rB == 3 && rC == 3); if constexpr (zero_init) { atom.accumulate_ = decltype(atom.accumulate_)::Zero; } CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tA); k_block++) { cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); atom.accumulate_ = decltype(atom.accumulate_)::One; } } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTE_HOST_DEVICE constexpr auto to_tiled_mma_sm100_ts( TiledMMA, cute::C, cute::integral_constant, 
cute::integral_constant, cute::integral_constant, cute::integral_constant>, TAs...>, TMs...>) { return TiledMMA>, TAs...>, TMs...>{}; } template CUTE_HOST_DEVICE constexpr auto to_tiled_mma_sm100_ts( TiledMMA, TAs...>, TMs...>) { return TiledMMA, TAs...>, TMs...>{}; } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE void copy(TiledCopy const &tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, Tensor const &predicate_K, const int max_MN=0) { // Decay TiledCopy to CopyAtom auto copy_atom = static_cast(tiled_copy); CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K // There's no case where !Clear_OOB_K && Clear_OOB_MN static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); auto has_with_bool = cute::is_valid([](auto t)->void_t().with(true))>{}, copy_atom); #pragma unroll for (int m = 0; m < size<1>(S); ++m) { bool predicate_mn = Is_even_MN || get<0>(identity_MN(_0{}, m, _0{})) < max_MN; if constexpr (Is_even_MN || !Clear_OOB_MN) { if (Is_even_MN || predicate_mn) { #pragma unroll for (int k = 0; k < size<2>(S); ++k) { if constexpr (Is_even_K || !Clear_OOB_K) { if (Is_even_K || predicate_K(k)) { cute::copy(copy_atom, S(_, m, k), D(_, m, k)); } } else { // Clear_OOB_K == true && Is_even_K == false // If copy traits can be transformed with a predicate value, do it, otherwise branch here if constexpr (has_with_bool) { cute::copy(copy_atom.with(predicate_K(k)), S(_, m, k), D(_, m, k)); } else { if (predicate_K(k)) { cute::copy(copy_atom, S(_, m, k), D(_, m, k)); } else { cute::clear(D(_, m, k)); } } } } } } else { // Clear_OOB_MN == true && Is_even_MN == false, also implies Clear_OOB_K == true if constexpr (!has_with_bool) { if (predicate_mn) { #pragma unroll for (int k = 0; k < size<2>(S); 
++k) { if (Is_even_K || predicate_K(k)) { cute::copy(copy_atom, S(_, m, k), D(_, m, k)); } else if (Clear_OOB_K) { cute::clear(D(_, m, k)); } } } else { cute::clear(D(_, m, _)); } } else { // combine the mn predicate with the k predicate #pragma unroll for (int k = 0; k < size<2>(S); ++k) { cute::copy(copy_atom.with(predicate_mn && (Is_even_K || predicate_K(k))), S(_, m, k), D(_, m, k)); } } } } } //////////////////////////////////////////////////////////////////////////////////////////////////// // Byte permute and shuffle to match register layout of // (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II. template CUTLASS_DEVICE void permute_Aregs_fp8(Fragment &frag) { // frag has shape ((4, 2, 2), MMA_M, MMA_N), each element is 8 bits static_assert(decltype(size<0, 0>(frag))::value == 4); static_assert(decltype(size<0, 1>(frag))::value == 2); static_assert(decltype(stride<0, 0>(frag))::value == 1); static_assert(decltype(stride<0, 1>(frag))::value == 4); static_assert(sizeof(typename Fragment::value_type) == 1); int quad_idx = threadIdx.x % 4; bool lane_03 = quad_idx == 0 || quad_idx == 3; int selector_upper = lane_03 ? 0x5410 : 0x1054; int selector_lower = lane_03 ? 0x7632 : 0x3276; static constexpr int upper_map[4] = {0, 3, 1, 2}; // static constexpr int lower_map[4] = {1, 2, 0, 3}; Tensor frag_64b = recast(frag); // ((1, 1, 2), MMA_M, MMA_N) #pragma unroll for (int i = 0; i < size(frag_64b); ++i) { uint32_t upper = frag_64b[i].x; uint32_t lower = frag_64b[i].y; uint32_t upper0 = lane_03 ? upper : lower; uint32_t lower0 = lane_03 ? 
lower : upper; upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4); // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4); lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 1, 4); frag_64b[i].x = __byte_perm(upper0, lower0, selector_upper); frag_64b[i].y = __byte_perm(upper0, lower0, selector_lower); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) { // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits static_assert(decltype(size<0, 0>(frag))::value == 2); static_assert(decltype(size<0, 1>(frag))::value == 2); static_assert(decltype(size<0, 2>(frag))::value % 2 == 0); static_assert(decltype(stride<0, 0>(frag))::value == 1); static_assert(sizeof(typename Fragment::value_type) == 4); Tensor frag_64b = group_modes<1, 3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N)) #pragma unroll for (int mi = 0; mi < size<1>(frag_64b); ++mi) { #pragma unroll for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) { cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi)); } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE void permute_output_fp8(Fragment &out) { // out has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits static_assert(decltype(size<0, 0>(out))::value == 2); static_assert(decltype(size<0, 1>(out))::value == 2); static_assert(decltype(size<0, 2>(out))::value % 2 == 0); static_assert(decltype(stride<0, 0>(out))::value == 1); static_assert(sizeof(typename Fragment::value_type) == 4); Tensor frag = group_modes<1, 3>(out); // ((2, 2, N / 8), (MMA_M, MMA_N)) #pragma unroll for (int mi = 0; mi < size<1>(frag); ++mi) { #pragma unroll for (int j = 0; j < size<0, 1>(frag); ++j) { #pragma unroll for (int i = 0; i < size<0, 2>(frag) / 2; ++i) { 
cutlass::swap(frag(make_coord(_1{}, j, 2 * i), mi), frag(make_coord(_0{}, j, 2 * i + 1), mi)); } } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE void permute_output_fp8_Vcolmajor(Fragment &frag) { // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 16 bits static_assert(decltype(size<0, 0>(frag))::value == 2); static_assert(decltype(size<0, 1>(frag))::value == 2); static_assert(decltype(stride<0, 0>(frag))::value == 1); static_assert(sizeof(typename Fragment::value_type) == 2 || sizeof(typename Fragment::value_type) == 4); int quad_idx = threadIdx.x % 4; bool lane_03 = quad_idx == 0 || quad_idx == 3; static constexpr int upper_map[4] = {0, 2, 3, 1}; // static constexpr int lower_map[4] = {2, 0, 1, 3}; // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); } using type2 = std::conditional_t; Tensor frag_2 = group_modes<1, 3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N)) // if (blockIdx.x == 0 && threadIdx.x == 128) { print(frag); printf("\n"); print(frag_2); } #pragma unroll for (int mi = 0; mi < size<1>(frag_2); ++mi) { #pragma unroll for (int j = 0; j < size<0, 1>(frag_2); ++j) { #pragma unroll for (int i = 0; i < size<0, 2>(frag_2) / 2; ++i) { type2 upper = frag_2(make_coord(_0{}, j, 2 * i), mi); type2 lower = frag_2(make_coord(_0{}, j, 2 * i + 1), mi); type2 upper0 = lane_03 ? upper : lower; type2 lower0 = lane_03 ? lower : upper; upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[quad_idx], 4); // lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[quad_idx], 4); lower0 = __shfl_sync(uint32_t(-1), lower0, upper_map[quad_idx] ^ 2, 4); frag_2(make_coord(_0{}, j, 2 * i), mi) = lane_03 ? upper0 : lower0; frag_2(make_coord(_0{}, j, 2 * i + 1), mi) = lane_03 ? 
lower0 : upper0; } } } // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(frag); } } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_DEVICE void apply_softcap(Tensor &tensor, float const softcap){ #pragma unroll for (int i = 0; i < size(tensor); ++i) { tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); } } template CUTLASS_DEVICE auto calculate_dtanh(Tensor &tensor){ Tensor out = make_fragment_like(tensor); #pragma unroll for (int i = 0; i < size(tensor); ++i) { out(i) = 1.f - (tensor(i) * tensor(i)); } return out; } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTE_DEVICE T warp_prefix_sum(T val) { int lane = threadIdx.x % cutlass::NumThreadsPerWarp; CUTLASS_PRAGMA_UNROLL for (int i = 1; i < cutlass::NumThreadsPerWarp; i <<= 1) { T partial_sum = __shfl_up_sync(0xffffffff, val, i); if (lane >= i) { val += partial_sum; } } return val; } //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTE_DEVICE T warp_uniform(T a) { return __shfl_sync(0xffffffff, a, 0); } //////////////////////////////////////////////////////////////////////////////////////////////////// CUTLASS_DEVICE int canonical_warp_group_idx_nosync() { return threadIdx.x / cutlass::NumThreadsPerWarpGroup; } //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace flash ================================================ FILE: setup.py ================================================ # Copyright (c) 2023, Tri Dao. 
import sys
import functools
import warnings
import os
import re
import ast
import glob
import shutil
from pathlib import Path
from typing import Literal, Optional
from packaging.version import parse, Version
import platform

from setuptools import setup, find_packages
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
    ROCM_HOME,
    IS_HIP_EXTENSION,
)


with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")

if BUILD_TARGET == "auto":
    if IS_HIP_EXTENSION:
        IS_ROCM = True
    else:
        IS_ROCM = False
else:
    if BUILD_TARGET == "cuda":
        IS_ROCM = False
    elif BUILD_TARGET == "rocm":
        IS_ROCM = True
    else:
        # Fail early with a clear message instead of a NameError on IS_ROCM below
        raise ValueError(f"Unsupported BUILD_TARGET={BUILD_TARGET!r}; expected 'auto', 'cuda' or 'rocm'")

PACKAGE_NAME = "flash_attn"

BASE_WHEEL_URL = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
)

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"

ROCM_BACKEND: Optional[Literal["triton", "ck"]] = None
if IS_ROCM:
    ROCM_BACKEND = "triton" if os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" else "ck"

NVCC_THREADS = os.getenv("NVCC_THREADS") or "4"


@functools.lru_cache(maxsize=None)
def cuda_archs() -> list[str]:
    # Semicolon-separated list of SM versions, e.g. "80;90"; read once and cached
    return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;110;120").split(";")
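Because `cuda_archs` is wrapped in `functools.lru_cache`, `FLASH_ATTN_CUDA_ARCHS` is read only once per process, and every caller receives the same list object. The sketch below is a more defensive variant. It is not part of `setup.py`, and `parse_cuda_archs` is a hypothetical name. It keeps the same environment variable and default, returns an immutable tuple, and drops the empty entry that a trailing `;` would otherwise produce.

```python
import functools
import os


@functools.lru_cache(maxsize=None)
def parse_cuda_archs() -> tuple[str, ...]:
    # Hypothetical variant of cuda_archs(): same env var and default list,
    # but returns an immutable tuple (safe to share from a cache) and skips
    # empty or whitespace-only entries such as the one after a trailing ";".
    raw = os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;110;120")
    return tuple(a.strip() for a in raw.split(";") if a.strip())


if __name__ == "__main__":
    os.environ["FLASH_ATTN_CUDA_ARCHS"] = "90;100;"
    # The result is memoized, so clear the cache after changing the env var.
    parse_cuda_archs.cache_clear()
    print(parse_cuda_archs())  # ('90', '100')
```

With the cache in place, changing the environment variable mid-process has no effect until `cache_clear()` is called. That is usually what a one-shot build script wants.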
def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
    if sys.platform.startswith("linux"):
        return f'linux_{platform.uname().machine}'
    elif sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    elif sys.platform == "win32":
        return "win_amd64"
    else:
        raise ValueError("Unsupported platform: {}".format(sys.platform))


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version


def add_cuda_gencodes(cc_flag, archs, bare_metal_version):
    """
    Adds -gencode flags based on nvcc capabilities:
      - sm_80/90 (regular)
      - sm_100/120 on CUDA >= 12.8
      - Use family-specific targets (100f, 120f) on CUDA >= 12.9, and 110f on CUDA >= 13.0
      - Map requested 110 -> 101 if CUDA < 13.0 (Thor rename)
      - Embed PTX for newest arch for forward compatibility
    """
    # Always-regular 80
    if "80" in archs:
        cc_flag += ["-gencode", "arch=compute_80,code=sm_80"]

    # Hopper 9.0 needs >= 11.8
    if bare_metal_version >= Version("11.8") and "90" in archs:
        cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]

    # Blackwell 10.x requires >= 12.8
    if bare_metal_version >= Version("12.8"):
        if "100" in archs:
            # CUDA 12.9 introduced "family-specific" for Blackwell (100f)
            if bare_metal_version >= Version("12.9"):
                cc_flag += ["-gencode", "arch=compute_100f,code=sm_100"]
            else:
                cc_flag += ["-gencode", "arch=compute_100,code=sm_100"]

        if "120" in archs:
            # sm_120 is supported in CUDA 12.8/12.9+ toolkits
            if bare_metal_version >= Version("12.9"):
                cc_flag += ["-gencode", "arch=compute_120f,code=sm_120"]
            else:
                cc_flag += ["-gencode", "arch=compute_120,code=sm_120"]

        # Thor rename: CUDA 12.8/12.9 use sm_101; CUDA 13.0+ uses sm_110
        if "110" in archs:
            if bare_metal_version >= Version("13.0"):
                cc_flag += ["-gencode", "arch=compute_110f,code=sm_110"]
            else:
                # Provide Thor support on CUDA 12.8/12.9 via sm_101
                if bare_metal_version >= Version("12.8"):
                    cc_flag += ["-gencode", "arch=compute_101,code=sm_101"]
                # else: no Thor support in older toolkits

    # PTX for newest requested arch (forward-compat)
    numeric = [a for a in archs if a.isdigit()]
    if numeric:
        newest = max(numeric, key=int)
        cc_flag += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"]

    return cc_flag


def get_hip_version():
    return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def check_if_rocm_home_none(global_option: str) -> None:
    if ROCM_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but hipcc was not found."
    )


def detect_hipify_v2():
    try:
        from torch.utils.hipify import __version__
        from packaging.version import Version
        if Version(__version__) >= Version("2.0.0"):
            return True
    except Exception as e:
        print("failed to detect pytorch hipify version, defaulting to version 1.0.0 behavior")
        print(e)
    return False


def append_nvcc_threads(nvcc_extra_args):
    return nvcc_extra_args + ["--threads", NVCC_THREADS]


def rename_cpp_to_cu(cpp_files):
    for entry in cpp_files:
        shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")


def validate_and_update_archs(archs):
    # List of allowed architectures
    allowed_archs = ["native", "gfx90a", "gfx950", "gfx942"]

    # Validate if each element in archs is in allowed_archs
    assert all(
        arch in allowed_archs for arch in archs
    ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention"


cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
if IS_ROCM:
    if ROCM_BACKEND == "triton":
        if os.path.isdir(".git"):
            subprocess.run(["git", "submodule", "update", "--init", "third_party/aiter"], check=True)
        else:
            assert os.path.isdir("third_party/aiter"), (
                "third_party/aiter is missing, please use source distribution or git clone"
            )
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--no-build-isolation", "third_party/aiter"],
            check=True,
        )
    elif ROCM_BACKEND == "ck":
        if os.path.isdir(".git"):
            subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"], check=True)
        else:
            assert os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py"), (
                "csrc/composable_kernel is missing, please use source distribution or git clone"
            )
else:
    # CUDA: cutlass submodule
    if os.path.isdir(".git"):
        subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True)
    else:
        assert os.path.exists("csrc/cutlass/include/cutlass/cutlass.h"), (
            "csrc/cutlass is missing, please use source distribution or git clone"
        )

if not SKIP_CUDA_BUILD and not IS_ROCM:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    check_if_cuda_home_none("flash_attn")
    # Check that nvcc reports CUDA 11.7 or newer, the minimum FlashAttention supports
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.7"):
            raise RuntimeError(
                "FlashAttention is only supported on CUDA 11.7 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )
        # Build -gencode (regular + PTX + family-specific 'f' when available)
        add_cuda_gencodes(cc_flag, set(cuda_archs()), bare_metal_version)
    else:
        # No nvcc present; warnings already emitted above
        pass

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    nvcc_flags = [
        "-O3",
        "-std=c++17",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        # "--ptxas-options=-v",
        # "--ptxas-options=-O2",
        # "-lineinfo",
        # "-DFLASHATTENTION_DISABLE_BACKWARD",
        # "-DFLASHATTENTION_DISABLE_DROPOUT",
        # "-DFLASHATTENTION_DISABLE_ALIBI",
        # "-DFLASHATTENTION_DISABLE_SOFTCAP",
        # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
        # "-DFLASHATTENTION_DISABLE_LOCAL",
    ]

    compiler_c17_flag = ["-O3", "-std=c++17"]
    # Add Windows-specific flags
    if sys.platform == "win32" and os.getenv('DISTUTILS_USE_SDK') == '1':
        nvcc_flags.extend(["-Xcompiler", "/Zc:__cplusplus"])
        compiler_c17_flag = ["-O2", "/std:c++17", "/Zc:__cplusplus"]

    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=[
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
            ],
            extra_compile_args={
                "cxx": compiler_c17_flag,
                "nvcc": append_nvcc_threads(nvcc_flags + cc_flag),
            },
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_attn",
                Path(this_dir) / "csrc" / "flash_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
elif not SKIP_CUDA_BUILD and IS_ROCM:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Skips CK C++ extension compilation if using Triton Backend
    if ROCM_BACKEND == "ck":
        ck_dir = "csrc/composable_kernel"

        # Use codegen to generate the kernel dispatch code
        if not os.path.exists("./build"):
            os.makedirs("build")

        optdim = os.getenv("OPT_DIM", "32,64,128,256")
        subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True)
        subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True)
        subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True)
        subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True)

        # Check if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
        # See https://github.com/pytorch/pytorch/pull/70650
        generator_flag = []
        torch_dir = torch.__path__[0]
        if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
            generator_flag = ["-DOLD_GENERATOR_PATH"]

        check_if_rocm_home_none("flash_attn")
        archs = os.getenv("GPU_ARCHS", "native").split(";")
        validate_and_update_archs(archs)

        if archs != ['native']:
            cc_flag = [f"--offload-arch={arch}" for arch in archs]
        else:
            arch = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0]
            cc_flag = [f"--offload-arch={arch}"]

        # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
        # torch._C._GLIBCXX_USE_CXX11_ABI
        # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
        if FORCE_CXX11_ABI:
            torch._C._GLIBCXX_USE_CXX11_ABI = True

        sources = ["csrc/flash_attn_ck/flash_api.cpp",
                   "csrc/flash_attn_ck/flash_common.cpp",
                   "csrc/flash_attn_ck/mha_bwd.cpp",
                   "csrc/flash_attn_ck/mha_fwd_kvcache.cpp",
                   "csrc/flash_attn_ck/mha_fwd.cpp",
                   "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
                   "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
            f"build/fmha_*wd*.cpp"
        )

        # Check if torch is using hipify v2. Until CK is updated with HIPIFY_V2 macro,
        # we must replace the incorrect APIs.
        maybe_hipify_v2_flag = []
        if detect_hipify_v2():
            maybe_hipify_v2_flag = ["-DHIPIFY_V2"]

        rename_cpp_to_cu(sources)

        renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
                           "csrc/flash_attn_ck/flash_common.cu",
                           "csrc/flash_attn_ck/mha_bwd.cu",
                           "csrc/flash_attn_ck/mha_fwd_kvcache.cu",
                           "csrc/flash_attn_ck/mha_fwd.cu",
                           "csrc/flash_attn_ck/mha_varlen_bwd.cu",
                           "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")

        cc_flag += ["-O3", "-std=c++20",
                    "-Wno-unknown-warning-option",
                    "-fbracket-depth=1024",
                    "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
                    "-fgpu-flush-denormals-to-zero",
                    "-DCK_ENABLE_BF16",
                    "-DCK_ENABLE_BF8",
                    "-DCK_ENABLE_FP16",
                    "-DCK_ENABLE_FP32",
                    "-DCK_ENABLE_FP64",
                    "-DCK_ENABLE_FP8",
                    "-DCK_ENABLE_INT8",
                    "-DCK_USE_XDL",
                    "-DUSE_PROF_API=1",
                    # "-DFLASHATTENTION_DISABLE_BACKWARD",
                    "-D__HIP_PLATFORM_HCC__=1"]

        cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"]

        # Imitate https://github.com/ROCm/composable_kernel/blob/c8b6b64240e840a7decf76dfaa13c37da5294c4a/CMakeLists.txt#L190-L214
        hip_version = get_hip_version()
        if hip_version > Version('5.5.00000'):
            cc_flag += ["-mllvm", "--lsr-drop-solution=1"]
        if hip_version > Version('5.7.23302'):
            cc_flag += ["-fno-offload-uniform-block"]
        if hip_version > Version('6.1.40090'):
            cc_flag += ["-mllvm", "-enable-post-misched=0"]
        if hip_version > Version('6.2.41132'):
            cc_flag += ["-mllvm", "-amdgpu-early-inline-all=true",
                        "-mllvm", "-amdgpu-function-calls=false"]
        if hip_version > Version('6.2.41133') and hip_version < Version('6.3.00000'):
            cc_flag += ["-mllvm", "-amdgpu-coerce-illegal-types=1"]

        extra_compile_args = {
            "cxx": ["-O3", "-std=c++20"] + generator_flag + maybe_hipify_v2_flag,
            "nvcc": cc_flag + generator_flag + maybe_hipify_v2_flag,
        }

        include_dirs = [
            Path(this_dir) / "csrc" / "composable_kernel" / "include",
            Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
            Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
        ]

        ext_modules.append(
            CUDAExtension(
                name="flash_attn_2_cuda",
                sources=renamed_sources,
                extra_compile_args=extra_compile_args,
                include_dirs=include_dirs,
            )
        )


def get_package_version():
    with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    else:
        return str(public_version)


def get_wheel_url():
    torch_version_raw = parse(torch.__version__)
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    flash_version = get_package_version()
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    if IS_ROCM:
        torch_hip_version = get_hip_version()
        hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    else:
        # Determine the version numbers that will be used to determine the correct wheel
        # We're using the CUDA version used to build torch, not the one currently installed
        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
        torch_cuda_version = parse(torch.version.cuda)
        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
        # to save CI time. Minor versions should be compatible.
        torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
        # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
        cuda_version = f"{torch_cuda_version.major}"

        # Determine wheel URL based on CUDA version, torch version, python version and OS
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)

    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuit the standard full build pipeline.
    """

    def run(self):
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            os.rename(wheel_filename, wheel_path)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()


class NinjaBuildExtension(BuildExtension):
    def __init__(self, *args, **kwargs) -> None:
        # do not override env MAX_JOBS if already exists
        if not os.environ.get("MAX_JOBS"):
            import psutil

            nvcc_threads = max(1, int(NVCC_THREADS))

            # calculate the maximum allowed NUM_JOBS based on cores
            max_num_jobs_cores = max(1, os.cpu_count() // 2)

            # calculate the maximum allowed NUM_JOBS based on free memory
            free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)  # free memory in GB
            # Assume worst-case peak observed memory usage of ~5GB per NVCC thread.
            # Limit: peak_threads = max_jobs * nvcc_threads and peak_threads * 5GB <= free_memory.
            max_num_jobs_memory = max(1, int(free_memory_gb / (5 * nvcc_threads)))

            # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
            max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
            print(
                f"Auto set MAX_JOBS to `{max_jobs}`, NVCC_THREADS to `{nvcc_threads}`. "
                "If you see memory pressure, please use a lower `MAX_JOBS=N` or `NVCC_THREADS=N` value."
            )
            os.environ["MAX_JOBS"] = str(max_jobs)

        super().__init__(*args, **kwargs)


# Build install_requires based on platform
if ROCM_BACKEND == "triton":
    # Note: torch is excluded because pip resolves it to CUDA PyTorch from PyPI, overwriting any pre-installed ROCm PyTorch. Users must have torch installed.
    install_requires = [
        "einops",
        "triton==3.5.1",
    ]
else:
    install_requires = [
        "torch",
        "einops",
    ]

setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "benchmarks",
            "flash_attn.egg-info",
            "flash_attn.cute",
            "flash_attn.cute.*",
        )
    ),
    author="Tri Dao",
    author_email="tri@tridao.me",
    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/flash-attention",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
    if ext_modules
    else {
        "bdist_wheel": CachedWheelsCommand,
    },
    python_requires=">=3.9",
    install_requires=install_requires,
    setup_requires=[
        "packaging",
        "psutil",
        "ninja",
    ],
)


================================================
FILE: tests/cute/benchmark_block_sparsity.py
================================================
"""
Comparative benchmark: CuTe DSL vs Native PyTorch block sparsity computation.
"""

import torch
from dataclasses import dataclass
from typing import Callable, Optional, List
from tabulate import tabulate
from tqdm import tqdm
import itertools

from cutlass.cute.runtime import from_dlpack
from cutlass.cute.testing import benchmark as cute_benchmark
import cutlass.cute as cute

from flash_attn.cute.compute_block_sparsity import BlockSparsityKernel
from flash_attn.cute.block_sparsity import BlockSparseTensors
from mask_mod_definitions import (
    get_mask_pair,
    random_doc_id_tensor,
    flex_document_mask,
    cute_document_mask,
)
from torch.nn.attention.flex_attention import create_block_mask
from triton.testing import do_bench

# Configure torch.compile cache to prevent memory buildup
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class BenchmarkConfig:
    """Configuration for a benchmark run."""

    batch_size: int
    num_heads: int
    seqlen_q: int
    seqlen_k: int
    mask_name: str
    tile_m: int = 128
    tile_n: int = 128
    use_fast_sampling: bool = False
    aux_tensors_cute: Optional[list] = None


@dataclass(frozen=True)
class BenchmarkResult:
    """Result of a single benchmark run."""

    config: BenchmarkConfig
    cute_time_ms: Optional[float]
    pytorch_time_ms: Optional[float]
    error_message: Optional[str] = None


def benchmark_pytorch_block_sparsity(
    config: BenchmarkConfig,
    mask_fn: Callable,
) -> Optional[float]:
    """
    Benchmark PyTorch block mask creation (compiled).
    Returns: creation_time_ms
    """
    device = "cuda"

    try:
        cbm = torch.compile(create_block_mask)

        def run_benchmark():
            return cbm(
                mask_fn,
                config.batch_size,
                config.num_heads,
                config.seqlen_q,
                config.seqlen_k,
                device=device,
            )

        creation_time_ms = do_bench(run_benchmark, warmup=10, rep=100)

        return creation_time_ms

    except Exception as e:
        print(f"PyTorch benchmark failed ({config.mask_name}): {e}")
        import traceback

        traceback.print_exc()
        return None


def benchmark_cute_block_sparsity(
    config: BenchmarkConfig,
    mask_fn: Callable,
) -> Optional[float]:
    """
    Benchmark CuTe block sparsity kernel.
    Returns: creation_time_ms
    """
    device = "cuda"

    try:
        num_m_blocks = (config.seqlen_q + config.tile_m - 1) // config.tile_m
        num_n_blocks = (config.seqlen_k + config.tile_n - 1) // config.tile_n

        mask_block_cnt = torch.zeros(
            (config.batch_size, config.num_heads, num_m_blocks),
            device=device,
            dtype=torch.int32,
        )
        mask_block_idx = torch.zeros(
            (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks),
            device=device,
            dtype=torch.int32,
        )
        full_block_cnt = torch.zeros(
            (config.batch_size, config.num_heads, num_m_blocks),
            device=device,
            dtype=torch.int32,
        )
        full_block_idx = torch.zeros(
            (config.batch_size, config.num_heads, num_m_blocks, num_n_blocks),
            device=device,
            dtype=torch.int32,
        )

        # Convert to CuTe tensors
        mask_cnt_cute = from_dlpack(
            mask_block_cnt.detach(), assumed_align=4
        ).mark_layout_dynamic(leading_dim=2)
        mask_idx_cute = from_dlpack(
            mask_block_idx.detach(), assumed_align=4
        ).mark_layout_dynamic(leading_dim=3)
        full_cnt_cute = from_dlpack(
            full_block_cnt.detach(), assumed_align=4
        ).mark_layout_dynamic(leading_dim=2)
        full_idx_cute = from_dlpack(
            full_block_idx.detach(), assumed_align=4
        ).mark_layout_dynamic(leading_dim=3)

        blocksparse_tensors = BlockSparseTensors(
            mask_block_cnt=mask_cnt_cute,
            mask_block_idx=mask_idx_cute,
            full_block_cnt=full_cnt_cute,
            full_block_idx=full_idx_cute,
        )

        # Create kernel
        use_aux = (
            config.aux_tensors_cute is not None and len(config.aux_tensors_cute) > 0
        )
        kernel = BlockSparsityKernel(
            mask_mod=mask_fn,
            tile_mn=(config.tile_m, config.tile_n),
            compute_full_blocks=True,
            use_aux_tensors=use_aux,
            use_fast_sampling=config.use_fast_sampling,
        )

        # Compile kernel
        compiled_kernel = cute.compile(
            kernel,
            blocksparse_tensors,
            config.seqlen_q,
            config.seqlen_k,
            config.aux_tensors_cute,
        )

        def generate_tensors():
            from cutlass.cute.testing import JitArguments

            return JitArguments(
                blocksparse_tensors,
                config.seqlen_q,
                config.seqlen_k,
                config.aux_tensors_cute,
            )

        creation_time_us = cute_benchmark(
            compiled_kernel,
            workspace_generator=generate_tensors,
            warmup_iterations=10,
            iterations=100,
        )

        torch.cuda.synchronize(device)
        creation_time_ms = creation_time_us / 1000.0
        return creation_time_ms

    except Exception as e:
        print(f"CuTe benchmark failed: {e}")
        return None


def run_benchmark(
    config: BenchmarkConfig,
    pytorch_mask_fn: Callable,
    cute_mask_fn: Callable,
) -> BenchmarkResult:
    """Run benchmarks for both implementations."""

    print(
        f"Benchmarking {config.mask_name} - B={config.batch_size}, H={config.num_heads}, "
        f"M={config.seqlen_q}, N={config.seqlen_k}"
    )

    # Benchmark PyTorch
    pytorch_time = benchmark_pytorch_block_sparsity(config, pytorch_mask_fn)

    # Benchmark CuTe
    cute_time = benchmark_cute_block_sparsity(config, cute_mask_fn)

    return BenchmarkResult(
        config=config,
        cute_time_ms=cute_time,
        pytorch_time_ms=pytorch_time,
    )


def generate_configs(
    batch_sizes: List[int],
    num_heads: List[int],
    seqlens: List[int],
    mask_names: List[str],
) -> List[BenchmarkConfig]:
    """Generate all benchmark configurations."""
    configs = []
    for B, H, S, mask_name in itertools.product(
        batch_sizes, num_heads, seqlens, mask_names
    ):
        configs.append(
            BenchmarkConfig(
                batch_size=B,
                num_heads=H,
                seqlen_q=S,
                seqlen_k=S,
                mask_name=mask_name,
            )
        )
    return configs


def print_results(results: List[BenchmarkResult]):
    successful_results = [
        r
        for r in results
        if r.cute_time_ms is not None and r.pytorch_time_ms is not None
    ]

    if not successful_results:
        print("No successful benchmark results to display")
        return

    headers = [
        "B",
        "H",
        "M",
        "N",
        "Mask Type",
        "CuTe Time (ms)",
        "PyTorch Time (ms)",
        "Speedup",
    ]

    rows = []
    for result in successful_results:
        speedup = (
            result.pytorch_time_ms / result.cute_time_ms
            if result.cute_time_ms > 0
            else 0
        )

        rows.append(
            [
                result.config.batch_size,
                result.config.num_heads,
                result.config.seqlen_q,
                result.config.seqlen_k,
                result.config.mask_name,
                f"{result.cute_time_ms:.4f}",
                f"{result.pytorch_time_ms:.4f}",
                f"{speedup:.2f}x",
            ]
        )

    # Sort by batch, head, seqlen, then mask type
    rows.sort(key=lambda x: (x[0], x[1], x[2], x[4]))

    print("\n" + "=" * 100)
    print("CuTe DSL vs PyTorch Block Sparsity Benchmark Results")
    print("=" * 100)
    print(tabulate(rows, headers=headers, tablefmt="github"))
    print("=" * 100)


def main():
    """Run the comparative benchmark."""

    # Configuration
    batch_sizes = [1, 4, 8]
    num_heads = [8, 16]
    seqlens = [1024, 2048, 4096, 8192]
    mask_names = [
        "causal",
        "sliding_window",
        "prefix_lm",
        "dilated_sliding_window",
        "document",
    ]

    device = "cuda"
    max_seqlen = max(seqlens)
    max_batch = max(batch_sizes)
    max_heads = max(num_heads)

    # Create document IDs using the helper from mask_mod_definitions
    doc_ids = random_doc_id_tensor(max_heads, max_batch, max_seqlen, device=device)
    doc_ids_cute = from_dlpack(doc_ids.detach(), assumed_align=4).mark_layout_dynamic(
        leading_dim=2
    )

    # Generate base configurations
    base_configs = generate_configs(batch_sizes, num_heads, seqlens, mask_names)

    # Update configs with aux tensors for document masking
    configs = []
    for config in base_configs:
        if config.mask_name == "document":
            # Add aux tensors for document masking
            configs.append(
                BenchmarkConfig(
                    batch_size=config.batch_size,
                    num_heads=config.num_heads,
                    seqlen_q=config.seqlen_q,
                    seqlen_k=config.seqlen_k,
                    mask_name=config.mask_name,
                    tile_m=config.tile_m,
                    tile_n=config.tile_n,
                    use_fast_sampling=False,
                    aux_tensors_cute=[doc_ids_cute],
                )
            )
        else:
            configs.append(config)

    # Run benchmarks
    results = []
    print(f"Running {len(configs)} benchmark configurations...")
    for config in tqdm(configs, desc="Benchmarking"):
        try:
            # Get mask pair from mask_mod_definitions
            mask_kwargs = {}
            if config.mask_name == "sliding_window":
                mask_kwargs["window_size"] = 128  # Default window size

            cute_mask_fn, pytorch_mask_fn = get_mask_pair(
                config.mask_name,
                seqlen_q=config.seqlen_q,
                seqlen_k=config.seqlen_k,
                **mask_kwargs,
            )

            # For document masking, create wrapper that captures doc_ids
            if config.mask_name == "document":
                # PyTorch wrapper
                def pytorch_mask_fn(b, h, q, kv):
                    return flex_document_mask(b,
                                              h, q, kv, doc_ids)

                # CuTe wrapper - reuse cute_document_mask with aux_tensors
                cute_mask_fn = cute_document_mask

            result = run_benchmark(config, pytorch_mask_fn, cute_mask_fn)
            results.append(result)

        except Exception as e:
            print(f"Failed to run config {config}: {e}")
            results.append(
                BenchmarkResult(
                    config=config,
                    cute_time_ms=None,
                    pytorch_time_ms=None,
                    error_message=str(e),
                )
            )
        finally:
            torch.cuda.empty_cache()
            torch._dynamo.reset()

    print_results(results)


if __name__ == "__main__":
    main()


================================================
FILE: tests/cute/benchmark_mask_mod.py
================================================
"""
FlashAttention benchmarking script with Flex Attention-style mask mod support and varlen sequences.
"""

from dataclasses import dataclass
import math
from typing import Any, Dict, Optional, Tuple

import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import numpy as np
import torch

from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90
from mask_mod_definitions import (
    get_mask_pair,
    random_doc_id_tensor,
)
from flash_attn.cute.block_sparsity import (
    BlockSparseTensorsTorch,
    to_cute_block_sparse_tensors,
)
from flash_attn.cute.compute_block_sparsity import compute_block_sparsity


@dataclass
class BenchmarkConfig:
    """Benchmark configuration"""

    # Model parameters
    headdim: int
    headdim_v: int
    nheads: int
    nheads_kv: int
    dtype: torch.dtype

    # Sequence parameters
    batch_size: int = 2
    seqlen_q: int = 8192
    seqlen_k: int = 8192

    # Varlen parameters
    use_varlen: bool = False
    min_seqlen_q: Optional[int] = None  # If None, use seqlen_q // 2
    max_seqlen_q: Optional[int] = None  # If None, use seqlen_q
    min_seqlen_k: Optional[int] = None  # If None, use seqlen_k // 2
    max_seqlen_k: Optional[int] = None  # If None, use seqlen_k

    # Mask parameters
    use_mask_mod: bool = True
    mask_mod_name: str = "causal"
    # Evaluated once at class definition from the class-level default ("causal"),
    # so this default is always False unless has_aux_tensors is passed explicitly.
    has_aux_tensors: bool = mask_mod_name == "document"
    # Sliding window parameter (used when mask_mod_name == "sliding_window")
    window_size: int = 128

    # Attention parameters
    causal: bool = False
    is_local: bool = False
    window_left: Optional[int] = 128  # For base Flash Attention local
    window_right: Optional[int] = 0  # For base Flash Attention local
    softcap: Optional[float] = None
    use_learnable_sink: bool = False

    # Kernel configuration
    tile_m: int = 128
    tile_n: int = 128
    num_stages: int = 2
    num_threads: int = 384
    intra_wg_overlap: bool = True
    mma_pv_is_rs: bool = True

    # Benchmark parameters
    warmup_iters: int = 10
    benchmark_iters: int = 25
    verbose: bool = False
    seed: int = 42


class FlashAttentionBenchmark:
    def __init__(self, config: BenchmarkConfig):
        self.config = config
        torch.manual_seed(config.seed)
        np.random.seed(config.seed)

        # Verify SM90 compute capability
        compute_capability = torch.cuda.get_device_capability()
        assert compute_capability >= (9, 0), (
            f"Requires SM90+, got SM{compute_capability[0]}{compute_capability[1]}"
        )
        # causal overrides use_mask_mod
        if config.causal:
            config.use_mask_mod = False

        if config.use_mask_mod:
            self.mask_mod_cute, self.mask_mod_flex = get_mask_pair(
                config.mask_mod_name,
                seqlen_q=config.seqlen_q,
                seqlen_k=config.seqlen_k,
                window_size=config.window_size,
            )
        else:
            self.mask_mod_cute = None
            self.mask_mod_flex = None

        self._validate_config()

    def _validate_config(self):
        config = self.config

        assert config.headdim <= 256, "headdim must be <= 256"
        assert config.headdim_v <= 256, "headdim_v must be <= 256"
        assert config.nheads % config.nheads_kv == 0, "nheads must be divisible by nheads_kv"

        alignment = 16 // config.dtype.itemsize
        assert config.headdim % alignment == 0, f"headdim must be divisible by {alignment}"
        assert config.headdim_v % alignment == 0, f"headdim_v must be divisible by {alignment}"

        # Validate is_local configuration
        if config.is_local:
            assert config.window_left is not None or config.window_right is not None, (
                "When is_local=True, at least one of window_left or window_right must be set"
            )
            assert not config.use_mask_mod, (
                "Cannot use both is_local and use_mask_mod simultaneously"
            )
            assert not config.causal, "Cannot use both is_local and causal simultaneously"

        # Validate mask_mod configuration
        if config.use_mask_mod and config.mask_mod_name == "sliding_window":
            assert config.window_size > 0, (
                "window_size must be positive when using sliding_window mask"
            )

    def _generate_varlen_seqlens(self, min_len: int, max_len: int) -> Tuple[torch.Tensor, int]:
        """Generate random sequence lengths and compute cumulative lengths."""
        seqlens = torch.randint(
            min_len, max_len + 1, (self.config.batch_size,), dtype=torch.int32, device="cuda"
        )
        cu_seqlens = torch.cat(
            [
                torch.zeros(1, dtype=torch.int32, device="cuda"),
                torch.cumsum(seqlens, dtype=torch.int32, dim=0),
            ]
        )
        total_tokens = cu_seqlens[-1].item()
        return cu_seqlens, total_tokens

    def _create_tensors(self) -> Dict[str, torch.Tensor]:
        config = self.config
        device = "cuda"

        if config.use_varlen:
            # Set defaults for varlen range
            min_q = config.min_seqlen_q if config.min_seqlen_q is not None else config.seqlen_q // 2
            max_q = config.max_seqlen_q if config.max_seqlen_q is not None else config.seqlen_q
            min_k = config.min_seqlen_k if config.min_seqlen_k is not None else config.seqlen_k // 2
            max_k = config.max_seqlen_k if config.max_seqlen_k is not None else config.seqlen_k

            # Generate cu_seqlens
            cu_seqlens_q, total_q = self._generate_varlen_seqlens(min_q, max_q)
            cu_seqlens_k, total_k = self._generate_varlen_seqlens(min_k, max_k)

            # Varlen shape: (total_tokens, nheads, headdim)
            q = torch.randn(
                total_q, config.nheads, config.headdim, dtype=config.dtype, device=device
            )
            k = torch.randn(
                total_k, config.nheads_kv, config.headdim, dtype=config.dtype, device=device
            )
            v = torch.randn(
                total_k, config.nheads_kv, config.headdim_v, dtype=config.dtype, device=device
            )
            out = torch.empty(
                total_q, config.nheads, config.headdim_v, dtype=config.dtype, device=device
            )
            lse = torch.empty(config.nheads, total_q, dtype=torch.float32, device=device)

            tensors
= { "q": q.contiguous(), "k": k.contiguous(), "v": v.contiguous(), "out": out.contiguous(), "lse": lse.contiguous(), "cu_seqlens_q": cu_seqlens_q.contiguous(), "cu_seqlens_k": cu_seqlens_k.contiguous(), } if config.verbose: print(f"Varlen: total_q={total_q}, total_k={total_k}") print(f"Q seqlens: {cu_seqlens_q[1:] - cu_seqlens_q[:-1]}") print(f"K seqlens: {cu_seqlens_k[1:] - cu_seqlens_k[:-1]}") else: # Standard shape: (batch, seqlen, nheads, headdim) q = torch.randn( config.batch_size, config.seqlen_q, config.nheads, config.headdim, dtype=config.dtype, device=device, ) k = torch.randn( config.batch_size, config.seqlen_k, config.nheads_kv, config.headdim, dtype=config.dtype, device=device, ) v = torch.randn( config.batch_size, config.seqlen_k, config.nheads_kv, config.headdim_v, dtype=config.dtype, device=device, ) out = torch.empty( config.batch_size, config.seqlen_q, config.nheads, config.headdim_v, dtype=config.dtype, device=device, ) lse = torch.empty( config.batch_size, config.nheads, config.seqlen_q, dtype=torch.float32, device=device, ) tensors = { "q": q.contiguous(), "k": k.contiguous(), "v": v.contiguous(), "out": out.contiguous(), "lse": lse.contiguous(), } if config.use_learnable_sink: learnable_sink = torch.rand(config.nheads, dtype=torch.bfloat16, device=device) tensors["learnable_sink"] = learnable_sink.contiguous() # Compute block sparsity when using mask_mod if config.use_mask_mod: if config.mask_mod_name == "document": doc_id = random_doc_id_tensor( nheads=config.nheads, batch=config.batch_size, seqlen_q=config.seqlen_q, device=device ) tensors["aux_tensors"] = [doc_id.contiguous()] _, blocksparse_torch_tensors = compute_block_sparsity( tile_m=self.config.tile_m, tile_n=self.config.tile_n, batch_size=self.config.batch_size, num_heads=self.config.nheads, seqlen_q=self.config.seqlen_q, seqlen_k=self.config.seqlen_k, mask_mod=self.mask_mod_cute, device=device, cu_seqlens_q=tensors.get("cu_seqlens_q"), cu_seqlens_k=tensors.get("cu_seqlens_k"),
aux_tensors=tensors.get("aux_tensors"), ) if blocksparse_torch_tensors is not None: tensors["block_sparse_tensors"] = blocksparse_torch_tensors if config.verbose: total_full = blocksparse_torch_tensors.full_block_cnt.sum().item() total_partial = blocksparse_torch_tensors.mask_block_cnt.sum().item() if config.use_varlen: # Compute max possible blocks across all sequences max_blocks = 0 for i in range(config.batch_size): seq_len_q = ( tensors["cu_seqlens_q"][i + 1] - tensors["cu_seqlens_q"][i] ).item() seq_len_k = ( tensors["cu_seqlens_k"][i + 1] - tensors["cu_seqlens_k"][i] ).item() n_blocks_q = (seq_len_q + config.tile_m - 1) // config.tile_m n_blocks_k = (seq_len_k + config.tile_n - 1) // config.tile_n max_blocks += n_blocks_q * n_blocks_k * config.nheads else: n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m max_blocks = n_blocks_k * n_blocks_q * config.nheads * config.batch_size skipped = max_blocks - total_full - total_partial print( f"Block stats: Full={total_full}, Partial={total_partial}, " f"Skipped={skipped}/{max_blocks}" ) return tensors def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[Any, tuple]: config = self.config dtype_map = { torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, } cute_dtype = dtype_map[config.dtype] qhead_per_kvhead = config.nheads // config.nheads_kv kernel = FlashAttentionForwardSm90( cute_dtype, config.headdim, config.headdim_v, qhead_per_kvhead, is_causal=config.causal, is_local=config.is_local, pack_gqa=False, tile_m=config.tile_m, tile_n=config.tile_n, num_stages=config.num_stages, num_threads=config.num_threads, intra_wg_overlap=config.intra_wg_overlap, mma_pv_is_rs=config.mma_pv_is_rs, mask_mod=self.mask_mod_cute, Q_in_regs=False, has_aux_tensors=config.has_aux_tensors or "aux_tensors" in tensors, ) softmax_scale = 1.0 / math.sqrt(config.headdim) current_stream =
cuda.CUstream(torch.cuda.current_stream().cuda_stream) # Convert tensors to cute q_cute = from_dlpack(tensors["q"].detach(), assumed_align=16).mark_layout_dynamic( leading_dim=tensors["q"].ndim - 1 ) k_cute = from_dlpack(tensors["k"].detach(), assumed_align=16).mark_layout_dynamic( leading_dim=tensors["k"].ndim - 1 ) v_cute = from_dlpack(tensors["v"].detach(), assumed_align=16).mark_layout_dynamic( leading_dim=tensors["v"].ndim - 1 ) out_cute = from_dlpack(tensors["out"].detach(), assumed_align=16).mark_layout_dynamic( leading_dim=tensors["out"].ndim - 1 ) lse_cute = from_dlpack(tensors["lse"].detach(), assumed_align=4).mark_layout_dynamic( leading_dim=tensors["lse"].ndim - 1 ) # Varlen tensors cu_seqlens_q_cute = ( from_dlpack(tensors["cu_seqlens_q"].detach(), assumed_align=4).mark_layout_dynamic( leading_dim=0 ) if "cu_seqlens_q" in tensors else None ) cu_seqlens_k_cute = ( from_dlpack(tensors["cu_seqlens_k"].detach(), assumed_align=4).mark_layout_dynamic( leading_dim=0 ) if "cu_seqlens_k" in tensors else None ) learnable_sink_cute = ( from_dlpack(tensors["learnable_sink"].detach(), assumed_align=4).mark_layout_dynamic( leading_dim=0 ) if "learnable_sink" in tensors else None ) blocksparse_tensors_cute = ( to_cute_block_sparse_tensors(tensors["block_sparse_tensors"]) if "block_sparse_tensors" in tensors else None ) if "aux_tensors" in tensors: aux_tensors_cute = [] for i in range(len(tensors["aux_tensors"])): buf = from_dlpack(tensors["aux_tensors"][i].detach(), assumed_align=4) aux_tensors_cute.append(buf.mark_layout_dynamic(leading_dim=2)) else: aux_tensors_cute = None # Window parameters for is_local window_left_cute = ( cutlass.Int32(config.window_left) if config.window_left is not None else None ) window_right_cute = ( cutlass.Int32(config.window_right) if config.window_right is not None else None ) compiled = cute.compile( kernel, q_cute, k_cute, v_cute, out_cute, lse_cute, softmax_scale, current_stream, cu_seqlens_q_cute, cu_seqlens_k_cute, None, # 
seqused_q None, # seqused_k None, # page_table window_left_cute, window_right_cute, learnable_sink_cute, blocksparse_tensors_cute, aux_tensors_cute, # None, ) args = ( q_cute, k_cute, v_cute, out_cute, lse_cute, softmax_scale, current_stream, cu_seqlens_q_cute, cu_seqlens_k_cute, None, None, None, window_left_cute, window_right_cute, learnable_sink_cute, blocksparse_tensors_cute, aux_tensors_cute, # None, ) return compiled, args def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float: config = self.config # Estimate sparsity for known mask patterns if config.is_local: # Local attention with window_left and window_right window_left = config.window_left if config.window_left is not None else 0 window_right = config.window_right if config.window_right is not None else 0 total_window = window_left + window_right + 1 # +1 for current position sparsity_ratio = min(1.0, total_window / config.seqlen_k) elif config.use_mask_mod: if config.mask_mod_name in ["identity", "identity_partial"]: sparsity_ratio = 1.0 elif config.mask_mod_name in ["causal", "block_causal"]: sparsity_ratio = 0.5 elif config.mask_mod_name == "sliding_window": # get_mask_pair applies window_size on both sides: 2 * window_size + 1 positions sparsity_ratio = min(1.0, (2 * config.window_size + 1) / config.seqlen_k) elif config.mask_mod_name == "block_diagonal": block_size = 128 # Must match cute_block_diagonal_mask / flex_block_diagonal_mask num_blocks = (config.seqlen_k + block_size - 1) // block_size sparsity_ratio = 1.0 / num_blocks if num_blocks > 1 else 1.0 elif config.mask_mod_name == "document": # Attended cells per (batch, head) row = sum over documents of doc_len^2 vals = tensors["aux_tensors"][0] rows = vals.reshape(-1, vals.shape[-1]).long() total = sum(torch.bincount(row).square().sum().item() for row in rows) / rows.shape[0] sparsity_ratio = total / (config.seqlen_q * config.seqlen_k) else: sparsity_ratio = 1.0 elif config.causal: sparsity_ratio = 0.5 else: sparsity_ratio = 1.0 if config.use_varlen: # Compute FLOPs per sequence and sum 
total_flops = 0 cu_q = tensors["cu_seqlens_q"] cu_k = tensors["cu_seqlens_k"] for i in range(config.batch_size):
seq_len_q = (cu_q[i + 1] - cu_q[i]).item() seq_len_k = (cu_k[i + 1] - cu_k[i]).item() # Adjust sparsity for local attention in varlen case if config.is_local: window_left = config.window_left if config.window_left is not None else 0 window_right = config.window_right if config.window_right is not None else 0 total_window = window_left + window_right + 1 seq_sparsity = min(1.0, total_window / seq_len_k) elif config.use_mask_mod and config.mask_mod_name == "sliding_window": seq_sparsity = min(1.0, (2 * config.window_size + 1) / seq_len_k) else: seq_sparsity = sparsity_ratio num_cells = int(seq_len_q * seq_len_k * seq_sparsity) if config.headdim == config.headdim_v: flops_this_seq = 4 * config.nheads * num_cells * config.headdim else: flops_this_seq = ( 2 * config.nheads * num_cells * config.headdim + 2 * config.nheads * num_cells * config.headdim_v ) total_flops += flops_this_seq return total_flops else: num_cells = int(config.seqlen_q * config.seqlen_k * sparsity_ratio) if config.headdim == config.headdim_v: flops_per_batch = 4 * config.nheads * num_cells * config.headdim else: flops_per_batch = ( 2 * config.nheads * num_cells * config.headdim + 2 * config.nheads * num_cells * config.headdim_v ) return flops_per_batch * config.batch_size def benchmark(self) -> Dict[str, Any]: config = self.config tensors = self._create_tensors() compiled_kernel, args = self._compile_kernel(tensors) # Warmup for _ in range(config.warmup_iters): compiled_kernel(*args) torch.cuda.synchronize() # Benchmark times = [] for _ in range(config.benchmark_iters): start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() compiled_kernel(*args) end.record() torch.cuda.synchronize() times.append(start.elapsed_time(end)) times_tensor = torch.tensor(times) mean_time = times_tensor.mean().item() std_time = times_tensor.std().item() if len(times) > 1 else 0.0 total_flops = self._calculate_flops(tensors) tflops = total_flops / (mean_time * 1e-3) / 1e12 # Bandwidth
calculation bytes_per_element = config.dtype.itemsize if config.use_varlen: total_q = tensors["q"].shape[0] total_k = tensors["k"].shape[0] memory_accessed = ( total_q * config.nheads * config.headdim * bytes_per_element + total_k * config.nheads_kv * config.headdim * bytes_per_element + total_k * config.nheads_kv * config.headdim_v * bytes_per_element + total_q * config.nheads * config.headdim_v * bytes_per_element ) else: memory_accessed = ( config.batch_size * config.seqlen_q * config.nheads * config.headdim * bytes_per_element + config.batch_size * config.seqlen_k * config.nheads_kv * config.headdim * bytes_per_element + config.batch_size * config.seqlen_k * config.nheads_kv * config.headdim_v * bytes_per_element + config.batch_size * config.seqlen_q * config.nheads * config.headdim_v * bytes_per_element ) bandwidth_gbps = memory_accessed / (mean_time * 1e-3) / 1e9 results = { "mean_time_ms": mean_time, "std_time_ms": std_time, "tflops": tflops, "bandwidth_gbps": bandwidth_gbps, } if config.verbose: self._print_results(results) return results def _print_results(self, results: Dict[str, Any]): config = self.config # Basic configuration if config.use_varlen: print( f"Shape: B={config.batch_size} (varlen), HD={config.headdim}, " f"NH={config.nheads}, NKV={config.nheads_kv}" ) else: print( f"Shape: B={config.batch_size}, Q={config.seqlen_q}, K={config.seqlen_k}, " f"HD={config.headdim}, NH={config.nheads}, NKV={config.nheads_kv}" ) # Attention pattern attn_info = [] if config.causal: attn_info.append("causal") if config.is_local: window_info = f"local(L={config.window_left},R={config.window_right})" attn_info.append(window_info) if config.use_mask_mod: if config.mask_mod_name == "sliding_window": attn_info.append(f"mask_mod={config.mask_mod_name}(w={config.window_size})") else: attn_info.append(f"mask_mod={config.mask_mod_name}") if config.use_varlen: attn_info.append("varlen") if attn_info: print(f"Attention: {', '.join(attn_info)}") # Performance metrics 
print(f"Time: {results['mean_time_ms']:.3f} ± {results['std_time_ms']:.3f} ms") print(f"Throughput: {results['tflops']:.2f} TFLOPS") print(f"Bandwidth: {results['bandwidth_gbps']:.1f} GB/s") if __name__ == "__main__": B = 2 config = BenchmarkConfig( headdim=128, headdim_v=128, nheads=16, nheads_kv=16, dtype=torch.bfloat16, batch_size=B, # batch_size=1, seqlen_q=8192, # seqlen_q=128, seqlen_k=8192, # seqlen_k=192, use_varlen=False, use_mask_mod=False, mask_mod_name="causal", window_size=128, # Configurable window size for mask_mod use_learnable_sink=False, causal=True, is_local=False, verbose=True, ) # Example 2: Base Flash Attention Local # config = BenchmarkConfig( # headdim=64, # headdim_v=64, # nheads=64, # nheads_kv=8, # dtype=torch.bfloat16, # batch_size=2, # seqlen_q=8192, # seqlen_k=8192, # use_varlen=False, # use_mask_mod=False, # causal=False, # is_local=True, # window_left=128, # Left window size for base local attention # window_right=0, # Right window size for base local attention # verbose=True, # ) benchmark = FlashAttentionBenchmark(config) results = benchmark.benchmark() ================================================ FILE: tests/cute/conftest.py ================================================ import os import subprocess import logging import tempfile import json import time from pathlib import Path from getpass import getuser def _get_gpu_ids(): visible = os.environ.get("CUDA_VISIBLE_DEVICES") if visible: return [g.strip() for g in visible.split(",")] try: result = subprocess.run( ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"], capture_output=True, text=True, timeout=5, ) if result.returncode == 0: return result.stdout.strip().splitlines() except (FileNotFoundError, subprocess.TimeoutExpired): pass logging.warning("Failed to get gpu ids, using default '0'") return ["0"] def pytest_configure(config): tmp = Path(tempfile.gettempdir()) / getuser() / "flash_attention_tests" tmp.mkdir(parents=True, exist_ok=True) worker_id = os.environ.get("PYTEST_XDIST_WORKER")
logging.basicConfig( format=config.getini("log_file_format"), filename=str(tmp / f"tests_{worker_id}.log"), level=config.getini("log_file_level"), ) if not worker_id: return worker_num = int(worker_id.replace("gw", "")) # cache gpu_ids, because nvidia-smi is expensive when we launch many workers doing torch initialization # Always elect worker_0 to get gpu_ids. cached_gpu_ids = tmp / "gpu_ids.json" if worker_num == 0: gpu_ids = _get_gpu_ids() with cached_gpu_ids.open(mode="w") as f: json.dump(gpu_ids, f) else: while not cached_gpu_ids.exists(): time.sleep(1) with cached_gpu_ids.open() as f: gpu_ids = json.load(f) os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[worker_num % len(gpu_ids)] def pytest_collection_finish(session): # file_name -> test_name -> counter test_counts: dict[str, dict[str, int]] = {} for item in session.items: funcname = item.function.__name__ parent = test_counts.setdefault(item.parent.name, {}) parent[funcname] = parent.setdefault(funcname, 0) + 1 print(json.dumps(test_counts, indent=2)) ================================================ FILE: tests/cute/mask_mod_definitions.py ================================================ from typing import Callable, Optional import random import math import cutlass import cutlass.cute as cute import torch from flash_attn.cute import utils from flash_attn.cute.block_sparsity import fast_sampling # ============================================================================= # CuTe mask_mod functions (for kernel compilation) # All use signature: (batch, head, m_idx, n_idx, seqlen_info, aux_tensors) # ============================================================================= # ============================================================================= # mask_mod functions that don't use global indices # ============================================================================= @fast_sampling @cute.jit def cute_causal_mask( batch: cute.TensorSSA, head: cute.TensorSSA, m_idx: cute.TensorSSA, n_idx: 
cute.TensorSSA, seqlen_info, aux_tensors: None, ) -> cute.TensorSSA: offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) return n_idx <= (m_idx + offset_ssa) def get_cute_causal_mask(offset: int): return cute_causal_mask def get_cute_block_causal_mask(offset: int): @fast_sampling @cute.jit def _cute_block_causal_mask( batch: cute.TensorSSA, head: cute.TensorSSA, m_idx: cute.TensorSSA, n_idx: cute.TensorSSA, seqlen_info, aux_tensors: None, ) -> cute.TensorSSA: offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) return n_idx <= (m_idx + offset_ssa) return _cute_block_causal_mask def get_cute_sliding_window_mask(window_left: int, window_right: int, offset: int): @fast_sampling @cute.jit def _cute_sliding_window_mask( batch: cute.TensorSSA, head: cute.TensorSSA, m_idx: cute.TensorSSA, n_idx: cute.TensorSSA, seqlen_info, aux_tensors, ) -> cute.TensorSSA: offset = seqlen_info.seqlen_k - seqlen_info.seqlen_q offset_ssa = utils.scalar_to_ssa(offset, cutlass.Int32) window_left_ssa = utils.scalar_to_ssa(window_left, cutlass.Int32) window_right_ssa = utils.scalar_to_ssa(window_right, cutlass.Int32) center = m_idx + offset_ssa lower = center - window_left_ssa upper = center + window_right_ssa return (n_idx >= lower) & (n_idx <= upper) return _cute_sliding_window_mask @fast_sampling @cute.jit def cute_block_diagonal_mask( batch: cute.TensorSSA, head: cute.TensorSSA, m_idx: cute.TensorSSA, n_idx: cute.TensorSSA, seqlen_info, aux_tensors, ) -> cute.TensorSSA: block_size_ssa = utils.scalar_to_ssa(128, cutlass.Int32) return (m_idx // block_size_ssa) == (n_idx // block_size_ssa) @cute.jit def cute_mini_causal_mask( batch: cute.TensorSSA, head: cute.TensorSSA, m_idx: cute.TensorSSA, n_idx: cute.TensorSSA, seqlen_info, aux_tensors, ) -> cute.TensorSSA: tile_size_ssa = utils.scalar_to_ssa(128, cutlass.Int32) m_mod = m_idx % tile_size_ssa n_mod = n_idx % tile_size_ssa return m_mod >= n_mod @fast_sampling @cute.jit def 
cute_prefix_lm_mask( batch: cute.TensorSSA, head: cute.TensorSSA, m_idx: cute.TensorSSA, n_idx: cute.TensorSSA, seqlen_info, aux_tensors, ) -> cute.TensorSSA: """Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.""" prefix_size_ssa = utils.scalar_to_ssa(512, cutlass.Int32) both_in_prefix = (m_idx < prefix_size_ssa) & (n_idx < prefix_size_ssa) causal_part = m_idx >= n_idx return both_in_prefix | causal_part @cute.jit def cute_dilated_sliding_window_mask( batch: cute.TensorSSA, head: cute.TensorSSA, m_idx: cute.TensorSSA, n_idx: cute.TensorSSA, seqlen_info, aux_tensors, ) -> cute.TensorSSA: """Dilated sliding window: every other position in a 256-position window.""" window_size_ssa = utils.scalar_to_ssa(256, cutlass.Int32) dilation_ssa = utils.scalar_to_ssa(2, cutlass.Int32) in_window = (m_idx >= n_idx) & (m_idx - n_idx < window_size_ssa) dilated = ((m_idx - n_idx) % dilation_ssa) == utils.scalar_to_ssa(0, cutlass.Int32) return in_window & dilated @fast_sampling @cute.jit def cute_document_mask( batch: cute.TensorSSA, head: cute.TensorSSA, m_idx: cute.TensorSSA, n_idx: cute.TensorSSA, seqlen_info, aux_tensors: list, ) -> cute.TensorSSA: doc_id = aux_tensors[0] m_doc = utils.scalar_to_ssa(doc_id[batch[0], head[0], m_idx[0]], cutlass.Int32) n_doc = utils.scalar_to_ssa(doc_id[batch[0], head[0], n_idx[0]], cutlass.Int32) return m_doc == n_doc @fast_sampling @cute.jit def cute_ima_mask( batch: cute.TensorSSA, head: cute.TensorSSA, m_idx: cute.TensorSSA, n_idx: cute.TensorSSA, seqlen_info, aux_tensors, ) -> cute.TensorSSA: bias = aux_tensors[0] threshold = utils.scalar_to_ssa(bias[n_idx[0]], cutlass.Int32) return n_idx >= threshold # ============================================================================= # mask_mod functions that use global indices (for use with variable sequence length) # Global indices computed as: m_idx_global = m_idx + seqlen_info.offset_q # n_idx_global = n_idx + seqlen_info.offset_k # 
============================================================================= # TODO: Add varlen mask implementations here # ============================================================================= # Eager reference functions (PyTorch/Flex Attention signatures) # ============================================================================= def get_flex_causal_mask(offset: int): def _flex_causal_mask(b, h, q_idx, kv_idx): return kv_idx <= q_idx + offset return _flex_causal_mask def get_flex_block_causal_mask(offset: int): def _flex_block_causal_mask(b, h, q_idx, kv_idx): return kv_idx <= q_idx + offset return _flex_block_causal_mask def get_flex_sliding_window_mask(window_left: int, window_right: int, offset: int): def _flex_sliding_window_mask(b, h, q_idx, kv_idx): center = q_idx + offset lower = center - window_left upper = center + window_right return (kv_idx >= lower) & (kv_idx <= upper) return _flex_sliding_window_mask def flex_block_diagonal_mask(b, h, q_idx, kv_idx): block_size = 128 return (q_idx // block_size) == (kv_idx // block_size) def flex_mini_causal_mask(b, h, q_idx, kv_idx): return (q_idx % 128) >= (kv_idx % 128) def flex_prefix_lm_mask(b, h, q_idx, kv_idx): """Prefix LM mask: first 512 tokens attend bidirectionally, rest use causal masking.""" prefix_size = 512 both_in_prefix = (q_idx < prefix_size) & (kv_idx < prefix_size) causal_part = q_idx >= kv_idx return both_in_prefix | causal_part def flex_dilated_sliding_window_mask(b, h, q_idx, kv_idx): """Dilated sliding window: every other position in a 256-position window.""" window_size = 256 dilation = 2 in_window = (q_idx >= kv_idx) & (q_idx - kv_idx < window_size) dilated = ((q_idx - kv_idx) % dilation) == 0 return in_window & dilated def flex_document_mask(b, h, q_idx, kv_idx, doc_id): return doc_id[b, h, q_idx] == doc_id[b, h, kv_idx] def flex_ima_mask(b, h, q_idx, kv_idx, bias): return kv_idx >= bias[kv_idx] # ============================================================================= # 
Utility functions # ============================================================================= def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): """Generate synthetic document ids shared across heads.""" doc_ids_tensor = torch.zeros(batch, nheads, seqlen_q, dtype=torch.int32, device=device) for b in range(batch): N = seqlen_q max_segments = max(1, math.ceil(math.sqrt(max(N // 4, 1)))) n = random.randint(1, max_segments) n = min(n, N) cuts = sorted(random.sample(range(1, N), n - 1)) lengths = [b - a for a, b in zip((0, *cuts), (*cuts, N))] base_doc_ids = torch.repeat_interleave( torch.arange(len(lengths), device=device, dtype=torch.int32), torch.tensor(lengths, device=device, dtype=torch.int32), ) for h in range(nheads): doc_ids_tensor[b, h, :] = base_doc_ids return doc_ids_tensor # ============================================================================= # Mask registry and factory functions # ============================================================================= STATIC_MASKS = { "block_diagonal": (cute_block_diagonal_mask, flex_block_diagonal_mask), "mini_causal": (cute_mini_causal_mask, flex_mini_causal_mask), "prefix_lm": (cute_prefix_lm_mask, flex_prefix_lm_mask), "dilated_sliding_window": ( cute_dilated_sliding_window_mask, flex_dilated_sliding_window_mask, ), "document": (cute_document_mask, flex_document_mask), "ima": (cute_ima_mask, flex_ima_mask), } PARAMETERIZED_MASK_FACTORIES = { "causal": (get_cute_causal_mask, get_flex_causal_mask), "block_causal": (get_cute_block_causal_mask, get_flex_block_causal_mask), "sliding_window": (get_cute_sliding_window_mask, get_flex_sliding_window_mask), } def get_mask_pair(mask_name, seqlen_q=None, seqlen_k=None, window_size=None): """Get (cute_mask, flex_mask) pair for the given mask name. For static masks, seqlen info is not needed. For parameterized masks, seqlen_q and seqlen_k are required. 
""" if mask_name in STATIC_MASKS: return STATIC_MASKS[mask_name] if mask_name not in PARAMETERIZED_MASK_FACTORIES: raise ValueError(f"Unknown mask: {mask_name}") if seqlen_q is None or seqlen_k is None: raise ValueError( f"Parameterized mask '{mask_name}' requires seqlen_q and seqlen_k" ) cute_factory, flex_factory = PARAMETERIZED_MASK_FACTORIES[mask_name] offset = seqlen_k - seqlen_q if mask_name == "sliding_window": if window_size is None: raise ValueError("sliding_window mask requires window_size parameter") cute_mask = cute_factory(window_size, window_size, offset) flex_mask = flex_factory(window_size, window_size, offset) else: cute_mask = cute_factory(offset) flex_mask = flex_factory(offset) return cute_mask, flex_mask if __name__ == "__main__": doc_ids = random_doc_id_tensor(1, 2, 128) print(f"{doc_ids = }") ================================================ FILE: tests/cute/score_mod_definitions.py ================================================ import torch import cutlass import cutlass.cute as cute from cutlass._mlir.dialects import math as mlir_math import operator # ============================================================================= # Score_mod functions that don't use global indices # All use signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors) # ============================================================================= @cute.jit def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): return tSrS_ssa @cute.jit def score_mod_identity_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): return tSrS_ssa @cute.jit def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): mask = operator.ge(q_idx, kv_idx) return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf"))) @cute.jit def score_mod_causal_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): mask = cute.make_rmem_tensor(kv_idx.shape, 
        dtype=cutlass.Boolean
    )
    kv_idx0 = kv_idx[0]
    q_idx0 = q_idx[0]
    for i in cutlass.range_constexpr(cute.size(mask.shape)):
        mask[i] = q_idx0 >= kv_idx0 + i
    mask_ssa = mask.load()
    return cute.where(mask_ssa, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf")))


@cute.jit
def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    diff = q_idx - kv_idx
    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype)
    return tSrS_ssa + abs_diff.to(cutlass.Float32)


@cute.jit
def score_mod_rel_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    q_idx0 = q_idx[0]
    kv_idx0 = kv_idx[0]
    diff0 = q_idx0 - kv_idx0
    abs_diff = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype)
    for i in cutlass.range_constexpr(cute.size(kv_idx.shape)):
        diffi = diff0 - i
        abs_diff[i] = mlir_math.absi(diffi)
    return tSrS_ssa + abs_diff.load().to(cutlass.Float32)


@cute.jit
def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    diff = q_idx - kv_idx
    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype)
    scaled = abs_diff * cute.full_like(abs_diff, 2)
    return tSrS_ssa + scaled.to(cutlass.Float32)


@cute.jit
def score_mod_rel_bias_x2_vectorized(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    q_idx0 = q_idx[0]
    kv_idx0 = kv_idx[0]
    diff0 = q_idx0 - kv_idx0
    abs_diff_x2 = cute.make_rmem_tensor(kv_idx.shape, dtype=diff0.dtype)
    for i in cutlass.range_constexpr(cute.size(kv_idx.shape)):
        diffi = diff0 - i
        abs_diff_x2[i] = mlir_math.absi(diffi) * 2
    return tSrS_ssa + abs_diff_x2.load().to(cutlass.Float32)


@cute.jit
def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    return tSrS_ssa * cute.full_like(tSrS_ssa, 2)


score_mod_times_two_vectorized = score_mod_times_two


@cute.jit
def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    score = tSrS_ssa.to(cutlass.Float32)
    slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8)
    slope = cute.math.exp2(
        slope_exp.to(cutlass.Float32) * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634)
    )
    diff = q_idx - kv_idx
    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype).to(cutlass.Float32)
    return score - slope * abs_diff


@cute.jit
def score_mod_alibi_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    score = tSrS_ssa.to(cutlass.Float32)
    slope_exp = (h_idx + cute.full_like(h_idx, 1)) * cute.full_like(h_idx, -8)
    slope = cute.math.exp2(
        slope_exp.to(cutlass.Float32) * cute.full_like(score, 0.125 * 0.6931471805599453 * 1.4426950408889634)
    )
    diff0 = q_idx[0] - kv_idx[0]
    abs_diff = cute.make_rmem_tensor(kv_idx.shape, diff0.dtype)
    for i in cutlass.range_constexpr(cute.size(abs_diff.shape)):
        diffi = diff0 - i
        abs_diff[i] = mlir_math.absi(diffi)
    return score - slope * abs_diff.load().to(cutlass.Float32)


@cute.jit
def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    diff = q_idx - kv_idx
    abs_diff = cute.TensorSSA(mlir_math.absi(diff), diff.shape, diff.dtype)
    mask = operator.le(abs_diff, cute.full_like(abs_diff, 256))
    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf")))


@cute.jit
def score_mod_block_diagonal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    q_block = q_idx // 64
    kv_block = kv_idx // 64
    mask = operator.eq(q_block, kv_block)
    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf")))


@cute.jit
def score_mod_causal_v2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    diff = q_idx - kv_idx
    mask = operator.ge(diff, cute.full_like(diff, 0))
    return cute.where(mask, tSrS_ssa, cute.full_like(tSrS_ssa, float("-inf")))


@cute.jit
def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    batch_bias = aux_tensors[0]
    dtype = batch_bias.element_type
    b_frag = cute.make_fragment(1, cutlass.Int32)
    b_frag.store(b_idx)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = batch_bias[b_frag[0]]
    bias_val = (bias_frag.load()).to(cutlass.Float32)
    return tSrS_ssa + bias_val


@cute.jit
def score_mod_batch_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    batch_bias = aux_tensors[0]
    dtype = batch_bias.element_type
    b_idx0 = b_idx[0]
    bias_frag = cute.make_rmem_tensor(1, dtype)
    bias_frag[0] = batch_bias[b_idx0]
    bias_val = (bias_frag.load()).to(cutlass.Float32)
    return tSrS_ssa + bias_val


@cute.jit
def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    head_bias = aux_tensors[0]
    pos_bias = aux_tensors[1]
    dtype = head_bias.element_type

    h_frag = cute.make_fragment(1, cutlass.Int32)
    h_frag.store(h_idx)
    head_val_frag = cute.make_fragment(1, dtype)
    head_val_frag[0] = head_bias[h_frag[0]]
    head_val = (head_val_frag.load()).to(cutlass.Float32)

    q_frag = cute.make_fragment(1, cutlass.Int32)
    q_frag.store(q_idx)
    pos_val_frag = cute.make_fragment(1, dtype)
    pos_val_frag[0] = pos_bias[q_frag[0]]
    pos_val = (pos_val_frag.load()).to(cutlass.Float32)

    return tSrS_ssa + head_val + pos_val


@cute.jit
def score_mod_dual_buffer_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    head_bias = aux_tensors[0]
    pos_bias = aux_tensors[1]
    dtype = head_bias.element_type

    head_val_frag = cute.make_fragment(1, dtype)
    head_val_frag[0] = head_bias[h_idx[0]]
    head_val = (head_val_frag.load()).to(cutlass.Float32)

    pos_val_frag = cute.make_fragment(1, dtype)
    pos_val_frag[0] = pos_bias[q_idx[0]]
    pos_val = (pos_val_frag.load()).to(cutlass.Float32)

    return tSrS_ssa + head_val + pos_val


# =============================================================================
# Score_mod functions that use global indices
# All use signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors)
# Global indices computed as: q_idx_global = q_idx + seqlen_info.offset_q (and similarly for kv)
# =============================================================================


@cute.jit
def score_mod_global_kv_bias(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Per-token bias using global kv index."""
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type
    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]
    return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32)


@cute.jit
def score_mod_global_q_bias(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Per-token bias using global q index."""
    offset_q = seqlen_info.offset_q
    q_idx_global = q_idx + offset_q
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type
    q_frag = cute.make_fragment(1, cutlass.Int32)
    q_frag.store(q_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[q_frag[0]]
    return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32)


@cute.jit
def score_mod_global_rel_plus_kv_bias(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Relative position (logical) + per-token bias (global kv)."""
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type

    rel_pos = q_idx - kv_idx
    rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype)
    rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.1)

    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]

    return tSrS_ssa + rel_bias + (bias_frag.load()).to(cutlass.Float32)


@cute.jit
def score_mod_global_q_and_kv_bias(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Both q and kv global indices."""
    offset_q = seqlen_info.offset_q
    q_idx_global = q_idx + offset_q
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    q_bias = aux_tensors[0]
    kv_bias = aux_tensors[1]
    dtype = q_bias.element_type

    q_frag = cute.make_fragment(1, cutlass.Int32)
    q_frag.store(q_idx_global)
    q_bias_frag = cute.make_fragment(1, dtype)
    q_bias_frag[0] = q_bias[q_frag[0]]

    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    kv_bias_frag = cute.make_fragment(1, dtype)
    kv_bias_frag[0] = kv_bias[kv_frag[0]]

    return (
        tSrS_ssa
        + (q_bias_frag.load()).to(cutlass.Float32)
        + (kv_bias_frag.load()).to(cutlass.Float32)
    )


@cute.jit
def score_mod_global_logical_rel_plus_kv_bias(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Logical relative + global-indexed per-token bias."""
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type

    rel_pos = q_idx - kv_idx
    rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype)
    rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.01)

    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]

    return tSrS_ssa + rel_bias + (bias_frag.load()).to(cutlass.Float32)


# "Stress tests" - score_mods with complex global index usage


@cute.jit
def score_mod_stress_complex_arithmetic(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """All indices in complex arithmetic."""
    offset_q = seqlen_info.offset_q
    q_idx_global = q_idx + offset_q
    bias = aux_tensors[0]
    dtype = bias.element_type

    # Use absolute value instead of squaring to avoid overflow with large sequences
    rel_pos = q_idx - kv_idx
    rel_pos_abs = cute.TensorSSA(mlir_math.absi(rel_pos), rel_pos.shape, rel_pos.dtype)
    rel_bias = rel_pos_abs.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001)

    q_frag = cute.make_fragment(1, cutlass.Int32)
    q_frag.store(q_idx_global)
    bias_q_frag = cute.make_fragment(1, dtype)
    bias_q_frag[0] = bias[q_frag[0]]
    bias_q = (bias_q_frag.load()).to(cutlass.Float32)

    scale = (b_idx + cute.full_like(b_idx, 1)) * (h_idx + cute.full_like(h_idx, 1))
    scale_f32 = scale.to(cutlass.Float32) * 0.001

    result = tSrS_ssa + rel_bias + bias_q * scale_f32
    return result


@cute.jit
def score_mod_stress_conditional_mask(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Conditional masking with global vs logical."""
    offset_q = seqlen_info.offset_q
    q_idx_global = q_idx + offset_q
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type

    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]
    bias_val = (bias_frag.load()).to(cutlass.Float32)

    is_causal = operator.ge(q_idx, kv_idx)

    global_diff = q_idx_global - kv_idx_global
    is_nearby = operator.le(
        cute.TensorSSA(mlir_math.absi(global_diff), global_diff.shape, global_diff.dtype),
        cute.full_like(global_diff, 512),
    )

    both_conditions = is_causal & is_nearby
    return cute.where(both_conditions, tSrS_ssa + bias_val, cute.full_like(tSrS_ssa, float("-inf")))


@cute.jit
def score_mod_stress_multi_buffer(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Multiple aux tensors with different indexing."""
    offset_q = seqlen_info.offset_q
    q_idx_global = q_idx + offset_q
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k

    batch_bias = aux_tensors[0]
    head_scale = aux_tensors[1]
    q_pos_bias = aux_tensors[2]
    kv_pos_bias = aux_tensors[3]
    rel_pos_scale = aux_tensors[4]

    dtype = batch_bias.element_type

    b_frag = cute.make_fragment(1, cutlass.Int32)
    b_frag.store(b_idx)
    bb_frag = cute.make_fragment(1, dtype)
    bb_frag[0] = batch_bias[b_frag[0]]
    bb_val = (bb_frag.load()).to(cutlass.Float32)

    h_frag = cute.make_fragment(1, cutlass.Int32)
    h_frag.store(h_idx)
    hs_frag = cute.make_fragment(1, dtype)
    hs_frag[0] = head_scale[h_frag[0]]
    hs_val = (hs_frag.load()).to(cutlass.Float32)

    qg_frag = cute.make_fragment(1, cutlass.Int32)
    qg_frag.store(q_idx_global)
    qpb_frag = cute.make_fragment(1, dtype)
    qpb_frag[0] = q_pos_bias[qg_frag[0]]
    qpb_val = (qpb_frag.load()).to(cutlass.Float32)

    kvg_frag = cute.make_fragment(1, cutlass.Int32)
    kvg_frag.store(kv_idx_global)
    kvpb_frag = cute.make_fragment(1, dtype)
    kvpb_frag[0] = kv_pos_bias[kvg_frag[0]]
    kvpb_val = (kvpb_frag.load()).to(cutlass.Float32)

    rel_idx = q_idx - kv_idx + cute.full_like(q_idx, 512)
    rel_idx_clamped = cute.where(
        operator.lt(rel_idx, cute.full_like(rel_idx, 0)), cute.full_like(rel_idx, 0), rel_idx
    )
    rel_idx_clamped = cute.where(
        operator.gt(rel_idx_clamped, cute.full_like(rel_idx_clamped, 1024)),
        cute.full_like(rel_idx_clamped, 1024),
        rel_idx_clamped,
    )
    ri_frag = cute.make_fragment(1, cutlass.Int32)
    ri_frag.store(rel_idx_clamped)
    rps_frag = cute.make_fragment(1, dtype)
    rps_frag[0] = rel_pos_scale[ri_frag[0]]
    rps_val = (rps_frag.load()).to(cutlass.Float32)

    return tSrS_ssa * hs_val + bb_val + qpb_val + kvpb_val + rps_val * cute.full_like(tSrS_ssa, 0.1)


@cute.jit
def score_mod_stress_global_offset(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """Verify global - logical = offset."""
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type

    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]

    return tSrS_ssa + (bias_frag.load()).to(cutlass.Float32)


@cute.jit
def score_mod_stress_xor_pattern(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    """XOR-based pattern using index bits."""
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    token_bias = aux_tensors[0]
    dtype = token_bias.element_type

    xor_logical = q_idx ^ kv_idx
    pattern_logical = xor_logical & cute.full_like(xor_logical, 0xFF)
    pattern_bias = pattern_logical.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001)

    kv_frag = cute.make_fragment(1, cutlass.Int32)
    kv_frag.store(kv_idx_global)
    bias_frag = cute.make_fragment(1, dtype)
    bias_frag[0] = token_bias[kv_frag[0]]

    return (
        tSrS_ssa
        + pattern_bias
        + (bias_frag.load()).to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.1)
    )


@cute.jit
def score_mod_debug_global_idx(
    tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors
):
    # Don't read from aux_tensors at all - just add the global index as bias
    offset_k = seqlen_info.offset_k
    kv_idx_global = kv_idx + offset_k
    bias = kv_idx_global.to(cutlass.Float32) * cute.full_like(tSrS_ssa, 0.001)
    return tSrS_ssa + bias


# =============================================================================
# Eager reference functions
# =============================================================================


def identity_eager(score, b, h, q_idx, kv_idx):
    return score


def causal_eager(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))


def rel_bias_eager(score, b, h, q_idx, kv_idx):
    return score + torch.abs(q_idx - kv_idx)


def rel_bias_x2_eager(score, b, h, q_idx, kv_idx):
    return score + 2 * torch.abs(q_idx - kv_idx)


def times_two_eager(score, b, h, q_idx, kv_idx):
    return score * 2


def alibi_eager(score, b, h, q_idx, kv_idx):
    slope = 2 ** (-8 * (h + 1) / 8)
    return score - slope * torch.abs(q_idx - kv_idx)


def sliding_window_eager(score, b, h, q_idx, kv_idx):
    return torch.where(torch.abs(q_idx - kv_idx) <= 256, score, float("-inf"))


def block_diagonal_eager(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx // 64 == kv_idx // 64, score, float("-inf"))


def causal_v2_eager(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx - kv_idx >= 0, score, float("-inf"))


def batch_bias_factory(bias_tensor):
    def mod(score, b, h, q_idx, kv_idx):
        return score + bias_tensor[b]

    return mod


def dual_buffer_factory(head_bias, pos_bias):
    def mod(score, b, h, q_idx, kv_idx):
        return score + head_bias[h] + pos_bias[q_idx]

    return mod


def packed_kv_bias_factory(bias_tensor, cu_seqlens_k):
    def mod(score, b, h, q_idx, kv_idx):
        # Calculate valid length for this sequence
        start = cu_seqlens_k[b]
        seq_len = cu_seqlens_k[b + 1] - start

        # Clamp kv_idx
        safe_kv_idx = torch.clamp(kv_idx, max=seq_len - 1)

        return score + bias_tensor[start + safe_kv_idx]

    return mod


def packed_q_bias_factory(bias_tensor, cu_seqlens_q):
    def mod(score, b, h, q_idx, kv_idx):
        start = cu_seqlens_q[b]
        seq_len = cu_seqlens_q[b + 1] - start

        # Clamp q_idx
        safe_q_idx = torch.clamp(q_idx, max=seq_len - 1)

        return score + bias_tensor[start + safe_q_idx]

    return mod


def packed_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k):
    def mod(score, b, h, q_idx, kv_idx):
        start = cu_seqlens_k[b]
        seq_len = cu_seqlens_k[b + 1] - start

        # Clamp kv_idx
        safe_kv_idx = torch.clamp(kv_idx, max=seq_len - 1)

        rel_bias = torch.abs(q_idx - kv_idx).float() * 0.1
        return score + rel_bias + bias_tensor[start + safe_kv_idx]

    return mod


def packed_q_and_kv_bias_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k):
    def mod(score, b, h, q_idx, kv_idx):
        # Handle Q bounds
        q_start = cu_seqlens_q[b]
        q_len = cu_seqlens_q[b + 1] - q_start
        safe_q_idx = torch.clamp(q_idx, max=q_len - 1)

        # Handle KV bounds
        kv_start = cu_seqlens_k[b]
        kv_len = cu_seqlens_k[b + 1] - kv_start
        safe_kv_idx = torch.clamp(kv_idx, max=kv_len - 1)

        return score + q_bias[q_start + safe_q_idx] + kv_bias[kv_start + safe_kv_idx]

    return mod


def packed_logical_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k):
    def mod(score, b, h, q_idx, kv_idx):
        rel_bias = torch.abs(q_idx - kv_idx).float() * 0.01
        return score + rel_bias + bias_tensor[cu_seqlens_k[b] + kv_idx]

    return mod


def stress_complex_arithmetic_factory(bias, cu_seqlens_q):
    def mod(score, b, h, q_idx, kv_idx):
        # Use absolute value instead of squaring to avoid overflow with large sequences
        rel_pos_abs = torch.abs(q_idx - kv_idx)
        q_global = cu_seqlens_q[b] + q_idx
        bias_q = bias[q_global]
        scale = (b + 1) * (h + 1) * 0.001
        rel_bias = rel_pos_abs * 0.001
        return score + rel_bias + bias_q * scale

    return mod


def stress_conditional_mask_factory(token_bias, cu_seqlens_q, cu_seqlens_k):
    def mod(score, b, h, q_idx, kv_idx):
        kv_global = cu_seqlens_k[b] + kv_idx
        bias_val = token_bias[kv_global]
        is_causal = q_idx >= kv_idx
        q_global = cu_seqlens_q[b] + q_idx
        global_diff = q_global - kv_global
        is_nearby = torch.abs(global_diff) <= 512
        both_conditions = is_causal & is_nearby
        return torch.where(both_conditions, score + bias_val, float("-inf"))

    return mod


def stress_multi_buffer_factory(
    batch_bias,
    head_scale,
    q_pos_bias,
    kv_pos_bias,
    rel_pos_scale,
    cu_seqlens_q,
    cu_seqlens_k,
    max_rel_pos=512,
):
    def mod(score, b, h, q_idx, kv_idx):
        bb_val = batch_bias[b]
        hs_val = head_scale[h]
        qpb_val = q_pos_bias[cu_seqlens_q[b] + q_idx]
        kvpb_val = kv_pos_bias[cu_seqlens_k[b] + kv_idx]
        rel_idx = (q_idx - kv_idx + max_rel_pos).clamp(0, max_rel_pos * 2)
        rps_val = rel_pos_scale[rel_idx]
        return score * hs_val + bb_val + qpb_val + kvpb_val + rps_val * 0.1

    return mod


def stress_global_offset_factory(token_bias, cu_seqlens_k):
    def mod(score, b, h, q_idx, kv_idx):
        return score + token_bias[cu_seqlens_k[b] + kv_idx]

    return mod


def stress_xor_pattern_factory(token_bias, cu_seqlens_q, cu_seqlens_k):
    def mod(score, b, h, q_idx, kv_idx):
        xor_logical = q_idx ^ kv_idx
        pattern_bias = (xor_logical & 0xFF).float() * 0.001
        kv_global = cu_seqlens_k[b] + kv_idx
        return score + pattern_bias + token_bias[kv_global] * 0.1

    return mod


def debug_global_idx_factory(bias, cu_seqlens_k):
    offsets = cu_seqlens_k.tolist()

    def mod(score, b, h, q_idx, kv_idx):
        global_kv = offsets[b] + kv_idx
        return score + global_kv.float() * 0.001

    return mod

================================================
FILE: tests/cute/test_block_sparsity.py
================================================
"""Tests for block sparsity computation in flash attention."""

import pytest
import torch
from torch.nn.attention.flex_attention import create_block_mask

from mask_mod_definitions import get_mask_pair
from flash_attn.cute.compute_block_sparsity import compute_block_sparsity


def _call_compute_block_sparsity(
    batch_size,
    nheads,
    seqlen_q,
    seqlen_k,
    tile_m,
    tile_n,
    mask_name,
    window_size=None,
    aux_tensors=None,
    use_fast_sampling=False,
):
    """Call compute_block_sparsity and return torch tensors."""
    cute_mask, _ = get_mask_pair(
        mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size
    )
    _, torch_tensors = compute_block_sparsity(
        tile_m=tile_m,
        tile_n=tile_n,
        batch_size=batch_size,
        num_heads=nheads,
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        mask_mod=cute_mask,
        aux_tensors=aux_tensors,
        device="cuda",
        use_fast_sampling=use_fast_sampling,
    )
    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = torch_tensors
    return mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx


def _compare_block_sparsity(
    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx,
    mask_block_cnt_ref, mask_block_idx_ref, full_block_cnt_ref, full_block_idx_ref,
    batch_size, nheads, seqlen_q, seqlen_k, tile_m, tile_n,
):
    """Compare block sparsity against reference, handling boundary block semantics.

    PyTorch treats OOB regions as masked, so boundary blocks with all in-bounds
    elements unmasked appear as "partial" in PyTorch but "full" in CuTe.
    This applies to BOTH boundary m_blocks (OOB q_idx) and boundary n_blocks (OOB kv_idx).
    """
    if not isinstance(mask_block_cnt, torch.Tensor):
        return False, f"mask_block_cnt is not a tensor: {type(mask_block_cnt)}"

    n_blocks_q = mask_block_cnt.shape[2]

    # Identify boundary blocks
    last_m_block = (seqlen_q - 1) // tile_m
    last_n_block = (seqlen_k - 1) // tile_n
    m_is_boundary = seqlen_q % tile_m != 0
    n_is_boundary = seqlen_k % tile_n != 0

    def is_boundary_n_block(n_block):
        return n_is_boundary and n_block == last_n_block

    def is_boundary_m_block(m_block):
        return m_is_boundary and m_block == last_m_block

    for b in range(batch_size):
        for h in range(nheads):
            for m in range(n_blocks_q):
                cute_mask_cnt = mask_block_cnt[b, h, m].item()
                cute_full_cnt = full_block_cnt[b, h, m].item()
                ref_mask_cnt = mask_block_cnt_ref[b, h, m].item()
                ref_full_cnt = full_block_cnt_ref[b, h, m].item()

                cute_mask_set = set(mask_block_idx[b, h, m, :cute_mask_cnt].tolist())
                cute_full_set = set(full_block_idx[b, h, m, :cute_full_cnt].tolist())
                ref_mask_set = set(mask_block_idx_ref[b, h, m, :ref_mask_cnt].tolist())
                ref_full_set = set(full_block_idx_ref[b, h, m, :ref_full_cnt].tolist())

                # A block is "boundary-affected" if EITHER the m_block OR n_block is at boundary
                def is_boundary_affected(n_block):
                    return is_boundary_m_block(m) or is_boundary_n_block(n_block)

                # Blocks that are full in CuTe but not in ref
                full_in_cute_not_ref = cute_full_set - ref_full_set
                for n_block in full_in_cute_not_ref:
                    if not is_boundary_affected(n_block):
                        return False, (
                            f"Non-boundary block mismatch at [{b},{h},{m}]: "
                            f"n_block {n_block} is full in CuTe but not in ref"
                        )
                    # Boundary-affected: CuTe says full, ref should say partial
                    if n_block not in ref_mask_set:
                        # Check if ref skipped it entirely (all masked)
                        # This is valid for boundary blocks
                        pass

                # Blocks that are partial in CuTe but full in ref (would be a bug)
                partial_in_cute_full_in_ref = cute_mask_set & ref_full_set
                if partial_in_cute_full_in_ref:
                    return False, (
                        f"Block mismatch at [{b},{h},{m}]: "
                        f"n_blocks {sorted(partial_in_cute_full_in_ref)} are partial in CuTe but full in ref"
                    )

                # Check non-boundary blocks match exactly
                non_boundary_cute_full = {
                    n for n in cute_full_set if not is_boundary_affected(n)
                }
                non_boundary_ref_full = {
                    n for n in ref_full_set if not is_boundary_affected(n)
                }
                if non_boundary_cute_full != non_boundary_ref_full:
                    return False, (
                        f"Non-boundary full block mismatch at [{b},{h},{m}]: "
                        f"CuTe={sorted(non_boundary_cute_full)}, ref={sorted(non_boundary_ref_full)}"
                    )

                non_boundary_cute_mask = {
                    n for n in cute_mask_set if not is_boundary_affected(n)
                }
                non_boundary_ref_mask = {
                    n for n in ref_mask_set if not is_boundary_affected(n)
                }
                if non_boundary_cute_mask != non_boundary_ref_mask:
                    return False, (
                        f"Non-boundary partial block mismatch at [{b},{h},{m}]: "
                        f"CuTe={sorted(non_boundary_cute_mask)}, ref={sorted(non_boundary_ref_mask)}"
                    )

    return True, ""


# Test configurations
SEQLEN_PAIRS = [
    # Small aligned
    (64, 64), (128, 128), (256, 256), (512, 512),
    # Rectangular
    (128, 256), (256, 128), (512, 256), (256, 512),
    # Large aligned
    (1024, 1024), (2048, 2048), (4096, 4096), (8192, 8192),
    # Large unaligned
    (1000, 1000), (2000, 2000), (4000, 4000),
    # Edge cases with unaligned seqlens
    (113, 203), (127, 127), (129, 129), (255, 255), (257, 257),
    (1023, 1023), (1025, 1025), (2047, 2047), (2049, 2049),
]

TILE_SIZES = [
    # Standard powers of 2
    (32, 32), (64, 64), (128, 128), (256, 256),
    # Rectangular
    (32, 64), (64, 32), (64, 128), (128, 64), (128, 256), (256, 128),
    # Unusual sizes
    (40, 40), (48, 48), (96, 96), (112, 112),
    (32, 128), (128, 32), (40, 96), (96, 40),
]


@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS)
@pytest.mark.parametrize("tile_m,tile_n", TILE_SIZES)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("nheads", [1, 4])
@pytest.mark.parametrize("mask_name", ["block_diagonal", "mini_causal"])
def test_fixed_length_masks(
    seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name
):
    """Test fixed-length masks."""
    seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = (
        _call_compute_block_sparsity(
            batch_size, nheads, seqlen_q, seqlen_k, tile_m, tile_n, mask_name,
            use_fast_sampling=False,
        )
    )

    _, mask_mod_flex = get_mask_pair(mask_name)
    block_mask = create_block_mask(
        mask_mod_flex,
        B=batch_size,
        H=nheads,
        Q_LEN=seqlen_q,
        KV_LEN=seqlen_k,
        device="cuda",
        BLOCK_SIZE=(tile_m, tile_n),
    )
    (
        _,
        _,
        mask_block_cnt_ref,
        mask_block_idx_ref,
        full_block_cnt_ref,
        full_block_idx_ref,
        *_,
    ) = block_mask.as_tuple()

    print("CuTe results:")
    print(f" mask_block_cnt: {mask_block_cnt}")
    print(f" full_block_cnt: {full_block_cnt}")
    print(f" mask_block_idx: {mask_block_idx}")
    print(f" full_block_idx: {full_block_idx}")
    print("Torch results:")
    print(f" mask_block_cnt: {mask_block_cnt_ref}")
    print(f" full_block_cnt: {full_block_cnt_ref}")
    print(f" mask_block_idx: {mask_block_idx_ref}")
    print(f" full_block_idx: {full_block_idx_ref}")

    all_match, error_msg = _compare_block_sparsity(
        mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx,
        mask_block_cnt_ref, mask_block_idx_ref, full_block_cnt_ref, full_block_idx_ref,
        batch_size, nheads, seqlen_q, seqlen_k, tile_m, tile_n,
    )

    assert all_match, f"Mismatch: {error_msg}"


@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS)
@pytest.mark.parametrize(
    "tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)]
)
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("nheads", [1, 4])
@pytest.mark.parametrize(
    "mask_name,window_size",
    [("causal", None), ("sliding_window", 64), ("sliding_window", 256)],
)
def test_parameterized_masks(
    seqlen_q, seqlen_k, tile_m, tile_n, batch_size, nheads, mask_name, window_size
):
    """Test parameterized masks."""
    if mask_name == "sliding_window" and seqlen_q > seqlen_k:
        pytest.skip("Sliding window not supported for seqlen_q > seqlen_k")

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = (
        _call_compute_block_sparsity(
            batch_size, nheads,
            seqlen_q, seqlen_k, tile_m, tile_n, mask_name,
            window_size=window_size,
        )
    )

    _, mask_mod_flex = get_mask_pair(
        mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size
    )
    block_mask = create_block_mask(
        mask_mod_flex,
        B=batch_size,
        H=nheads,
        Q_LEN=seqlen_q,
        KV_LEN=seqlen_k,
        device="cuda",
        BLOCK_SIZE=(tile_m, tile_n),
    )
    (
        _,
        _,
        mask_block_cnt_ref,
        mask_block_idx_ref,
        full_block_cnt_ref,
        full_block_idx_ref,
        *_,
    ) = block_mask.as_tuple()

    all_match, error_msg = _compare_block_sparsity(
        mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx,
        mask_block_cnt_ref, mask_block_idx_ref, full_block_cnt_ref, full_block_idx_ref,
        batch_size, nheads, seqlen_q, seqlen_k, tile_m, tile_n,
    )

    assert all_match, f"Mismatch: {error_msg}"


@pytest.mark.parametrize(
    "seqlen_q,seqlen_k,tile_m,tile_n",
    [
        (1, 1, 64, 64),
        (63, 63, 64, 64),
        (65, 65, 64, 64),
        (129, 129, 128, 128),
        (100, 200, 64, 128),
    ],
)
def test_edge_cases(seqlen_q, seqlen_k, tile_m, tile_n):
    """Test edge cases with unaligned dimensions."""
    batch_size, nheads = 1, 1
    seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = (
        _call_compute_block_sparsity(
            batch_size, nheads, seqlen_q, seqlen_k, tile_m, tile_n, "causal",
        )
    )

    _, mask_mod_flex = get_mask_pair("causal", seqlen_q=seqlen_q, seqlen_k=seqlen_k)
    block_mask = create_block_mask(
        mask_mod_flex,
        B=batch_size,
        H=nheads,
        Q_LEN=seqlen_q,
        KV_LEN=seqlen_k,
        device="cuda",
        BLOCK_SIZE=(tile_m, tile_n),
    )
    (
        _,
        _,
        mask_block_cnt_ref,
        mask_block_idx_ref,
        full_block_cnt_ref,
        full_block_idx_ref,
        *_,
    ) = block_mask.as_tuple()

    all_match, error_msg = _compare_block_sparsity(
        mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx,
        mask_block_cnt_ref, mask_block_idx_ref, full_block_cnt_ref, full_block_idx_ref,
        batch_size, nheads, seqlen_q, seqlen_k, tile_m, tile_n,
    )

    assert all_match, f"Mismatch: {error_msg}"


@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS)
@pytest.mark.parametrize(
    "tile_m,tile_n", [(64, 64), (128, 128), (64, 128), (128, 64), (256, 256)]
)
@pytest.mark.parametrize("nheads", [1, 4])
@pytest.mark.parametrize("mask_name", ["causal", "block_diagonal"])
def test_fast_sampling(seqlen_q, seqlen_k, tile_m, tile_n, nheads, mask_name):
    """Test fast sampling mode (5-point sampling)."""
    batch_size = 1
    seqlen_unaligned = (seqlen_q % tile_m != 0) or (seqlen_k % tile_n != 0)

    mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = (
        _call_compute_block_sparsity(
            batch_size, nheads, seqlen_q, seqlen_k, tile_m, tile_n, mask_name,
            use_fast_sampling=True,
        )
    )

    _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k)
    block_mask = create_block_mask(
        mask_mod_flex,
        B=batch_size,
        H=nheads,
        Q_LEN=seqlen_q,
        KV_LEN=seqlen_k,
        device="cuda",
        BLOCK_SIZE=(tile_m, tile_n),
    )
    (
        _,
        _,
        mask_block_cnt_ref,
        mask_block_idx_ref,
        full_block_cnt_ref,
        full_block_idx_ref,
        *_,
    ) = block_mask.as_tuple()

    all_match, error_msg = _compare_block_sparsity(
        mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx,
        mask_block_cnt_ref, mask_block_idx_ref, full_block_cnt_ref, full_block_idx_ref,
        batch_size, nheads, seqlen_q, seqlen_k, tile_m, tile_n,
    )

    assert all_match, f"Mismatch: {error_msg}"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

================================================
FILE: tests/cute/test_flash_attn.py
================================================
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
import math
import itertools
import os
import random
import re

import pytest
import torch

from einops import rearrange, repeat

try:
    from flash_attn.layers.rotary import apply_rotary_emb
except ImportError:
    apply_rotary_emb = None

from flash_attn.cute.testing import (
    attention_ref,
    generate_qkv,
    generate_random_padding_mask,
    pad_input,
    unpad_input,
    maybe_fake_tensor_mode,
    is_fake_mode,
)
from flash_attn.cute.interface import (
    flash_attn_func,
    flash_attn_varlen_func,
)

# torch FakeTensorMode enables fast cutedsl kernel compilation without allocating the actual GPU memory or running the kernel.
# When operating on fake tensors, we cannot perform data-dependent operations (e.g., `tensor.max()`).
USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1
DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
# SplitKV is not supported on SM90
IS_SM90 = torch.cuda.get_device_capability()[0] == 9
IS_SM100 = torch.cuda.get_device_capability()[0] == 10
TEST_BWD_ONLY = False
VERBOSE = True


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("has_learnable_sink", [False, True])
# @pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
# @pytest.mark.parametrize("local_enum", [0])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128, 192])
# @pytest.mark.parametrize("d", [128, 192])
@pytest.mark.parametrize("d", [64, 96, 128, 192, 256])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (3, 3),
        (64, 32),
        (64, 128),
        (128, 128),
        (128, 192),
        (256, 256),
        (239, 1),
        (799, 3),
        (113, 203),
        (113, 128),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
        (4096, 4096),
        (4224, 4224),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@maybe_fake_tensor_mode(USE_FAKE_TENSOR)
def test_flash_attn_output(
    seqlen_q,
    seqlen_k,
    d,
    causal,
    local_enum,
    softcap,
    deterministic,
    has_qv,
    has_learnable_sink,
    mha_type,
    dtype,
):
    local = local_enum > 0
    if local and causal:
        pytest.skip()
    device = "cuda"
    # set seed
    seed = 0
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    batch_size = 9 if seqlen_k <= 2048 else 2
    # batch_size = 2
    nheads = 6
    # nheads = 1
    nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0]
    attention_chunk_vals = [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        q_ref = torch.randn(
            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
        )
        if softcap > 0.0:
            # Ensure the values of qk are at least within softcap range.
            q_ref = q_ref * softcap / 4
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        v_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        if has_qv:
            qv_ref = (
                torch.randn(
                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
        else:
            qv_ref = None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (
            (None, None) if not local else tuple(random.randrange(0, seqlen_k) for _ in range(2))
        )
        if local_enum == 2:
            window_size = (None, -window_size[1])
        elif local_enum == 3:
            window_size = (-window_size[0], None)
        if local:
            print("window size = ", window_size)
        # window_size = (-1, -1) if not local else (16, 0)
        if has_learnable_sink:
            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
        else:
            learnable_sink = None
        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [
                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2
                for _ in range(3)
            ]
        else:
            q_descale, k_descale, v_descale = None, None, None
        q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None
        out_ref, attn_ref = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            None,
            None,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            learnable_sink=learnable_sink,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            None,
            None,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            learnable_sink=learnable_sink,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )

        # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv)
        # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float()
        # # if qv is not None:
        # #     qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float()
        # m = qk.amax(-1, keepdim=True)
        # s_tmp = torch.exp((qk - m) / math.sqrt(d))
        # exp_sum = s_tmp.sum(-1)
        # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())
        # # lse_ref = torch.logsumexp(qk, dim=-1)

        # Numerical error if we just do any arithmetic on out_ref
        if not is_fake_mode():
            fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
            rtol = 2 if softcap == 0.0 else 3

            print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
            print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
        # num_splits_vals = [1, 3]
        pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False]
        # pack_gqa_vals = [False]
        # SplitKV is not supported for hdim >= 192
        num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
            # SplitKV not supported on SM90 - skip this iteration
            if IS_SM90 and num_splits > 1:
                continue
            if IS_SM100 and (d >= 192 and dv >= 192):  # hdim 192 and 256 are not supported on SM100
                continue
            out, lse = flash_attn_func(
                q,
                k,
                v,
                causal=causal,
                # qv=qv,
                # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
                window_size=window_size,
                # attention_chunk=attention_chunk,
                softcap=softcap,
                learnable_sink=learnable_sink,
                pack_gqa=pack_gqa,
                num_splits=num_splits,
                deterministic=deterministic,
            )
            if is_fake_mode():
                # no more flash_attn cutedsl calls for the rest of the loop
                # skip data-dependent postprocessing
                continue
            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
            # if not causal:
            #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
            # breakpoint()

            # Check that FlashAttention's numerical error is at most rtol times (2x, or 3x
            # with softcap) the numerical error of a Pytorch implementation, plus fwd_atol.
            assert (out - out_ref).abs().max().item() <= rtol * (
                out_pt - out_ref
            ).abs().max().item() + fwd_atol

        if (
            dtype != torch.float8_e4m3fn
            and not has_qv
            and not dv > 256
            and not attention_chunk != 0
            and softcap == 0.0
            and ((dv == d and d <= 128) or (d == 192 and dv == 128))
            and learnable_sink is None
            # and False
            and not ((causal or local) and seqlen_k < seqlen_q)
        ):
            if d > 192 and IS_SM90:
                pytest.xfail("hdim > 192 backward: SM90 not supported yet")
            if d != dv and mha_type != "mha" and IS_SM90:
                pytest.xfail("SM90 GQA bwd currently requires headdim == headdim_v")
            g = torch.randn_like(out)
            # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            if is_fake_mode():
                # no more flash_attn cutedsl calls for the rest of the loop
                # skip data-dependent postprocessing
                continue
            # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
            # assert (softmax_d - do_o).abs().max().item() <= 1e-5
            # assert dq_accum.abs().max().item() == 0.0

            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
            # P = torch.softmax(qk, -1)
            # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1))
            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
            # breakpoint()

            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(
                out_ref, (q_ref, k_ref, v_ref), g
            )
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
            if VERBOSE:
                diff_dq = (dq - dq_ref).abs()
                max_idx = diff_dq.argmax()
                coords = torch.unravel_index(max_idx, diff_dq.shape)
                print(f"dQ max diff: {diff_dq.max().item()}")
                print(f" at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}")

                diff_dk = (dk - dk_ref).abs()
                max_idx = diff_dk.argmax()
                coords = torch.unravel_index(max_idx, diff_dk.shape)
                print(f"dK max diff: {diff_dk.max().item()}")
                print(f" at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}")

                diff_dv = (dv - dv_ref).abs()
                max_idx = diff_dv.argmax()
                coords = torch.unravel_index(max_idx, diff_dv.shape)
                print(f"dV max diff: {diff_dv.max().item()}")
                print(f" at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}")
            # breakpoint()
            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dq - dq_ref).abs().max().item() <= rtol * (
                dq_pt - dq_ref
            ).abs().max().item() + dq_atol
            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dk - dk_ref).abs().max().item() <= rtol * (
                dk_pt - dk_ref
            ).abs().max().item() + dk_atol
            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dv - dv_ref).abs().max().item() <= rtol * (
                dv_pt - dv_ref
            ).abs().max().item() + dv_atol


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
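

# Illustrative sketch, not part of flash_attn and not called by any test here:
# the forward and backward checks in these tests share one acceptance rule. A
# kernel result passes if its max abs error against the upcast reference
# (`out_ref`) is at most `rtol` times the max abs error of a plain PyTorch
# implementation run in the low-precision dtype (`out_pt`), plus an absolute
# slack. The slack is estimated from `ref + 0.3 - 0.3 - ref`: how far a single
# add/subtract round-trip moves values of this magnitude in `ref`'s dtype,
# doubled. The helper names `_rounding_atol` and `_within_reference_tolerance`
# are hypothetical.
import torch


def _rounding_atol(ref: torch.Tensor, extra: float = 0.0) -> float:
    # Twice the rounding error of an add/subtract round-trip in ref's dtype;
    # `extra` mirrors the 3e-4 the backward checks add when softcap != 0.
    return 2 * (ref + 0.3 - 0.3 - ref).abs().max().item() + extra


def _within_reference_tolerance(
    out: torch.Tensor,
    out_ref: torch.Tensor,
    out_pt: torch.Tensor,
    rtol: float = 2.0,
    extra_atol: float = 0.0,
) -> bool:
    # Error of the kernel under test and of the PyTorch baseline, both measured
    # against the same reference tensor.
    kernel_err = (out - out_ref).abs().max().item()
    baseline_err = (out_pt - out_ref).abs().max().item()
    return kernel_err <= rtol * baseline_err + _rounding_atol(out_ref, extra_atol)


# Example (CPU, bf16): an exact result passes, an output shifted by 1.0 fails.
#   ref = torch.linspace(-2, 2, 64).to(torch.bfloat16)
#   pt = ref + 0.01
#   _within_reference_tolerance(ref.clone(), ref, pt)  -> True
#   _within_reference_tolerance(ref + 1.0, ref, pt)    -> False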
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("has_learnable_sink", [False, True])
@pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
# @pytest.mark.parametrize("local_enum", [0])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
@pytest.mark.parametrize("add_unused_qkv", [False])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", [128, 192])
@pytest.mark.parametrize("d", [64, 128, 192])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        # (1, 1),
        # (1, 3),
        # (2, 1),
        (511, 1),
        (3, 513),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (307, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize("varlen_mode", ["random", "third", "full"])
# @pytest.mark.parametrize("varlen_mode", ["full"])
@pytest.mark.parametrize(
    "zero_lengths_q, zero_lengths_k",
    [
        (False, False),
        (True, False),
        (False, True),
        (True, True),
    ],
)
@pytest.mark.parametrize(
    "unpad_q, unpad_kv",
    [
        (True, True),
        (False, False),
        (True, False),
        (False, True),
    ],
)
@maybe_fake_tensor_mode(USE_FAKE_TENSOR)
def test_flash_attn_varlen_output(
    seqlen_q,
    seqlen_k,
    d,
    add_unused_qkv,
    causal,
    local_enum,
    softcap,
    deterministic,
    has_qv,
    has_learnable_sink,
    mha_type,
    dtype,
    varlen_mode,
    zero_lengths_q,
    zero_lengths_k,
    unpad_q,
    unpad_kv,
):
    local = local_enum > 0
    if local and causal:
        pytest.skip()
    if (
        causal or local
    ):  # Right now reference only supports causal attention with seqlen_k == seqlen_q
        seqlen_k = seqlen_q
    device = "cuda"
    # set seed
    seed = seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)
    random.seed(seed)
    torch.random.manual_seed(seed)
    batch_size = 49 if seqlen_q <= 512 else 7
    nheads = 6
    # nheads = 1
    nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0]
    attention_chunk_vals = [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        q_ref = torch.randn(
            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
        )
        if softcap > 0.0:
            # Ensure the values of qk are at least within softcap range.
            q_ref = (q_ref * softcap / 4).detach().requires_grad_()
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        v_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        if has_qv:
            qv_ref = (
                torch.randn(
                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
        else:
            qv_ref = None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (
            (None, None)
            if not local
            else tuple(random.randrange(0, seqlen_k) for _ in range(2))
        )
        if local_enum == 2:
            window_size = (None, window_size[1])
        elif local_enum == 3:
            window_size = (window_size[0], None)
        if local:
            print("window size = ", window_size)
        if has_learnable_sink:
            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
        else:
            learnable_sink = None
        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [
                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
                * 2
                for _ in range(3)
            ]
        else:
            q_descale, k_descale, v_descale = None, None, None
        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach() if has_qv else None
        query_padding_mask = generate_random_padding_mask(
            seqlen_q,
            batch_size,
            device,
            mode=varlen_mode,
            zero_lengths=zero_lengths_q,
        )
        key_padding_mask = generate_random_padding_mask(
            seqlen_k,
            batch_size,
            device,
            mode=varlen_mode,
            zero_lengths=zero_lengths_k,
        )

        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
            if add_unused:
                another_mask = generate_random_padding_mask(max_seq_len, bs, device)
                attn_mask = torch.logical_and(padding_mask, another_mask)
                unused_mask = torch.logical_xor(
                    torch.logical_or(padding_mask, another_mask), attn_mask
                )
            else:
                attn_mask = padding_mask
                unused_mask = None
            return attn_mask, unused_mask

        query_padding_mask, query_unused_mask = _gen_unused_masks(
            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
        )
        # query_padding_mask[:] = True
        # query_unused_mask = None
        key_padding_mask, key_unused_mask = _gen_unused_masks(
            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
        )

        if causal or local:
            key_padding_mask = query_padding_mask

        (
            q_unpad,
            k_unpad,
            v_unpad,
            qv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            qv,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            qv=qv,
            kvpacked=False,
            query_unused_mask=query_unused_mask,
            key_unused_mask=key_unused_mask,
        )
        if unpad_q:
            print("cu_seqlens_q = ", cu_seqlens_q)
        else:
            print("seqused_q = ", seqused_q)
        if unpad_kv:
            print("cu_seqlens_k = ", cu_seqlens_k)
        else:
            print("seqused_k = ", seqused_k)
        q_unpad, k_unpad, v_unpad = [
            x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
        ]
        out_ref, attn_ref = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            learnable_sink=learnable_sink,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            learnable_sink=learnable_sink,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )

        if not is_fake_mode():
            print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
            print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

            if query_unused_mask is not None:
                q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")

            # Numerical error if we just do any arithmetic on out_ref
            fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
            rtol = 2 if softcap == 0.0 else 3

        pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False]
        # pack_gqa_vals = [False]
        # num_splits_vals = [1, 3]
        # SplitKV is not supported for hdim >= 192
        num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
            # SplitKV not supported on SM90 - skip this iteration
            if IS_SM90 and num_splits > 1:
                continue
            out_unpad, lse = flash_attn_varlen_func(
                q_unpad if unpad_q else q,
                k_unpad if unpad_kv else k,
                v_unpad if unpad_kv else v,
                cu_seqlens_q=cu_seqlens_q if unpad_q else None,
                cu_seqlens_k=cu_seqlens_k if unpad_kv else None,
                max_seqlen_q=seqlen_q,
                max_seqlen_k=seqlen_k,
                seqused_q=seqused_q if not unpad_q else None,
                seqused_k=seqused_k if not unpad_kv else None,
                causal=causal,
                # qv=qv_unpad,
                # q_descale=q_descale,
                # k_descale=k_descale, v_descale=v_descale,
                window_size=window_size,
                # attention_chunk=attention_chunk,
                learnable_sink=learnable_sink,
                softcap=softcap,
                num_splits=num_splits,
                pack_gqa=pack_gqa,
                deterministic=deterministic,
            )
            out = output_pad_fn(out_unpad) if unpad_q else out_unpad
            if is_fake_mode():
                # no more flash_attn cutedsl calls for the rest of the loop
                # skip data-dependent postprocessing
                continue
            if query_unused_mask is not None:
                out.masked_fill_(q_zero_masking, 0.0)
            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
            # if not causal:
            #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
            # breakpoint()

            # Check that FlashAttention's numerical error is at most rtol times (2x, or 3x
            # with softcap) the numerical error of a Pytorch implementation, plus fwd_atol.
            assert (out - out_ref).abs().max().item() <= rtol * (
                out_pt - out_ref
            ).abs().max().item() + fwd_atol

        if (
            dtype != torch.float8_e4m3fn
            and not has_qv
            and not dv > 256
            and not attention_chunk != 0
            and ((dv == d and d <= 128) or (d == 192 and dv == 128))
            and not has_learnable_sink
            # and False
        ):
            if d > 192 and IS_SM90:
                pytest.xfail("hdim > 192 backward: SM90 not supported yet")
            if d != dv and mha_type != "mha" and IS_SM90:
                pytest.xfail("SM90 GQA bwd currently requires headdim == headdim_v")
            g_unpad = torch.randn_like(out_unpad)
            # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
            # import flash_attn_3_cuda
            # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(
            #     g_unpad,
            #     q_unpad,
            #     k_unpad,
            #     v_unpad,
            #     out_unpad,
            #     lse,
            #     None,
            #     None,
            #     None,
            #     cu_seqlens_q,
            #     cu_seqlens_k,
            #     None, None,
            #     max_seqlen_q,
            #     max_seqlen_k,
            #     d ** (-0.5),
            #     causal,
            #     window_size[0], window_size[1],
            #     softcap,
            #     deterministic,
            #     0,  # sm_margin
            # )
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
                out_unpad,
                (
                    q_unpad if unpad_q else q,
                    k_unpad if unpad_kv else k,
                    v_unpad if unpad_kv else v,
                ),
                g_unpad,
            )
            if is_fake_mode():
                # no more flash_attn cutedsl calls for the rest of the loop
                # skip data-dependent postprocessing
                continue
            dq = dq_pad_fn(dq_unpad) if unpad_q else dq_unpad
            dk = dk_pad_fn(dk_unpad) if unpad_kv else dk_unpad
            dv = dk_pad_fn(dv_unpad) if unpad_kv else dv_unpad
            if key_unused_mask is not None:
                k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
                dk.masked_fill_(k_zero_masking, 0.0)
                dv.masked_fill_(k_zero_masking, 0.0)
            if query_unused_mask is not None:
                dq.masked_fill_(q_zero_masking, 0.0)
            if not unpad_kv:
                dk.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0)
                dv.masked_fill_(rearrange(~key_padding_mask, "b s -> b s 1 1"), 0.0)
            if not unpad_q:
                dq.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
            # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
            # assert (softmax_d - do_o).abs().max().item() <= 1e-5
            # assert dq_accum.abs().max().item() == 0.0
            g = output_pad_fn(g_unpad) if unpad_q else g_unpad

            # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()
            # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
            # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
            # P = torch.softmax(qk, -1)
            # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))
            # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
            # dV = torch.einsum('bhts,bthd->bshd', P, g.float())
            # dK = torch.einsum('bhts,bthd->bshd', dP, q.float())

            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(
                out_ref, (q_ref, k_ref, v_ref), g
            )
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
            if VERBOSE:
                diff_dq = (dq - dq_ref).abs()
                max_idx = diff_dq.argmax()
                coords = torch.unravel_index(max_idx, diff_dq.shape)
                print(f"dQ max diff: {diff_dq.max().item()}")
                print(f" at coordinates {tuple(c.item() for c in coords)}: dQ={dq[coords].item()}, dQ_ref={dq_ref[coords].item()}")

                diff_dk = (dk - dk_ref).abs()
                max_idx = diff_dk.argmax()
                coords = torch.unravel_index(max_idx, diff_dk.shape)
                print(f"dK max diff: {diff_dk.max().item()}")
                print(f" at coordinates {tuple(c.item() for c in coords)}: dK={dk[coords].item()}, dK_ref={dk_ref[coords].item()}")

                diff_dv = (dv - dv_ref).abs()
                max_idx = diff_dv.argmax()
                coords = torch.unravel_index(max_idx, diff_dv.shape)
                print(f"dV max diff: {diff_dv.max().item()}")
                print(f" at coordinates {tuple(c.item() for c in coords)}: dV={dv[coords].item()}, dV_ref={dv_ref[coords].item()}")
            # breakpoint()
            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dq - dq_ref).abs().max().item() <= rtol * (
                dq_pt - dq_ref
            ).abs().max().item() + dq_atol
            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dk - dk_ref).abs().max().item() <= rtol * (
                dk_pt - dk_ref
            ).abs().max().item() + dk_atol
            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dv - dv_ref).abs().max().item() <= rtol * (
                dv_pt - dv_ref
            ).abs().max().item() + dv_atol


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("has_learnable_sink", [False, True])
# @pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("new_kv", [False, True])
@pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False])
# @pytest.mark.parametrize("has_rotary_seqlens", [False, True])
@pytest.mark.parametrize("has_rotary_seqlens", [False])
# @pytest.mark.parametrize("rotary_interleaved", [False, True])
@pytest.mark.parametrize("rotary_interleaved", [True])
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128]))
# @pytest.mark.parametrize("page_size", [None, 128])
# @pytest.mark.parametrize("page_size", [128])
# @pytest.mark.parametrize("has_leftpad", [False, True])
@pytest.mark.parametrize("has_leftpad", [False])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("varlen_q", [False, True])
# @pytest.mark.parametrize("varlen_q", [False])
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize("d", [192])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        # # (1, 128 * 1024),
        # # (16, 128 * 1024),
        # (128, 128),
        # (256, 512),  # To test appending KV with more than 1 block
        # (2048, 3577),  # Enough tiles to test persistent scheduler
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@maybe_fake_tensor_mode(USE_FAKE_TENSOR)
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    varlen_q,
    has_batch_idx,
    has_leftpad,
    page_size,
    rotary_fraction,
    rotary_interleaved,
    has_rotary_seqlens,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    new_kv,
    has_learnable_sink,
    mha_type,
    dtype,
):
    if page_size is not None and seqlen_k % page_size != 0:
        pytest.skip()
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if rotary_fraction == 0.0 and has_rotary_seqlens:
        pytest.skip()
    device = "cuda"
    # set seed
    seed = 0
    random.seed(seed)
    torch.random.manual_seed(seed)
    batch_size = 5
    # batch_size = 1
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # nheads = 1
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    dv_vals = [d]
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) else [0]
    attention_chunk_vals = [0]
    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        # has_qv = d == 64 and dv >= 256
        has_qv = False
        q = (
            torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
            .to(dtype)
            .to(dtype_ref)
        )
        if has_qv:
            qv = (
                torch.randn(
                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
        else:
            qv = None
        if varlen_q:
            query_padding_mask = generate_random_padding_mask(
                seqlen_q, batch_size, device, mode="random"
            )
            q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(
                q, query_padding_mask
            )
            output_pad_fn = lambda output_unpad: pad_input(
                output_unpad, indices_q, batch_size, seqlen_q
            )
            qv_unpad = (
                rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None
            )
        else:
            query_padding_mask = None
            q_unpad = q
            qv_unpad = qv
            cu_seqlens_q, max_seqlen_q = None, None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (
            (None, None)
            if not local
            else tuple(random.randrange(0, seqlen_k) for _ in range(2))
        )
        if has_learnable_sink:
            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
        else:
            learnable_sink = None

        seqlen_new = (
            seqlen_q
            if seqlen_new_eq_seqlen_q
            else random.randrange(1, seqlen_q + 1)
        )
        cu_seqlens_k_new = None
        key_new_padding_mask = None
        if new_kv:
            k = (
                torch.randn(
                    batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
            v = (
                torch.randn(
                    batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
            if varlen_q:  # k & v are also varlen
                key_new_padding_mask = generate_random_padding_mask(
                    seqlen_new, batch_size, device, mode="random"
                )
                k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(
                    k, key_new_padding_mask
                )
                v_unpad, *rest = unpad_input(v, key_new_padding_mask)
            else:
                k_unpad, v_unpad = k, v
        else:
            k, v, k_unpad, v_unpad = None, None, None, None
        if page_size is None:
            k_cache = (
                torch.randn(
                    batch_size_cache,
                    seqlen_k,
                    nheads_k,
                    d,
                    device=device,
                    dtype=dtype_ref,
                )
                .to(dtype)
                .to(dtype_ref)
            )
            v_cache = (
                torch.randn(
                    batch_size_cache,
                    seqlen_k,
                    nheads_k,
                    dv,
                    device=device,
                    dtype=dtype_ref,
                )
                .to(dtype)
                .to(dtype_ref)
            )
            page_table = None
        else:
            (
                k_cache,
                v_cache,
                page_table,
                k_cache_paged,
                v_cache_paged,
                num_blocks,
            ) = _generate_block_kvcache(
                seqlen_k,
                page_size,
                batch_size_cache,
                nheads_k,
                d,
                dv,
                device,
                dtype,
                dtype_ref,
            )
        if not is_fake_mode():
            cache_seqlens = torch.randint(
                0 if new_kv else 1,
                # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
                (
                    (
                        seqlen_k
                        - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new)
                        + 1
                    )
                    if new_kv
                    else (seqlen_k + 1)
                ),
                (batch_size,),
                dtype=torch.int32,
                device=device,
            )
        else:
            cache_seqlens = torch.ones(
                batch_size,
                dtype=torch.int32,
                device=device,
            )
        if has_leftpad:
            if not is_fake_mode():
                cache_leftpad = torch.cat(
                    [
                        torch.randint(
                            0,
                            cache_seqlens[i].item(),
                            (1,),
                            dtype=torch.int32,
                            device=device,
                        )
                        if cache_seqlens[i].item() > 0
                        else torch.zeros(1, dtype=torch.int32, device=device)
                        for i in range(batch_size)
                    ]
                )
            else:
                cache_leftpad = torch.zeros(batch_size, dtype=torch.int32, device=device)
        else:
            cache_leftpad = None
        if has_batch_idx:
            if not is_fake_mode():
                cache_batch_idx = torch.randperm(
                    batch_size_cache, dtype=torch.int32, device=device
                )[:batch_size]
            else:
                cache_batch_idx = torch.arange(
                    batch_size, dtype=torch.int32, device=device
                )
        else:
            cache_batch_idx = None
        arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
        cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
        if not new_kv:
            key_padding_mask = arange < cache_seqlens_expanded
        else:
            k_new_seqlens = (
                key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new
            )
            key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens
        if has_leftpad:
            key_padding_mask = torch.logical_and(
                key_padding_mask,
                arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k),
            )
        # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
        rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2
        if rotary_dim > 0:
            angle = (
                torch.rand(
                    seqlen_k if page_size is None else num_blocks * page_size,
                    rotary_dim // 2,
                    device=device,
                )
                * 2
                * math.pi
            )
            cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
            sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
            if causal or local:
                q_ro = apply_rotary_emb(
                    q,
                    cos,
                    sin,
                    seqlen_offsets=rotary_seqlens,
                    interleaved=rotary_interleaved,
                )
            else:
                q_ro = rearrange(
                    apply_rotary_emb(
                        rearrange(q, "b s h d -> b 1 (s h) d"),
                        cos,
                        sin,
                        seqlen_offsets=rotary_seqlens,
                        interleaved=rotary_interleaved,
                    ),
                    "b 1 (s h) d -> b s h d",
                    s=seqlen_q,
                )
            # q_ro = q
            k_ro = apply_rotary_emb(
                k,
                cos,
                sin,
                seqlen_offsets=rotary_seqlens,
                interleaved=rotary_interleaved,
            )
        else:
            cos, sin = None, None
            q_ro, k_ro = q, k
        # k_cache[:, 64:] = -1
        k_cache_ref = (
            k_cache if not has_batch_idx else k_cache[cache_batch_idx]
        ).clone()
        v_cache_ref = (
            v_cache if not has_batch_idx else v_cache[cache_batch_idx]
        ).clone()
        if new_kv:
            update_mask = torch.logical_and(
                cache_seqlens_expanded <= arange,
                arange < cache_seqlens_expanded + k_new_seqlens,
            )
            k_to_update = rearrange(k_ro, "b s ... -> (b s) ...")
            v_to_update = rearrange(v, "b s ... -> (b s) ...")
            if varlen_q:
                k_to_update = k_to_update[indices_k]
                v_to_update = v_to_update[indices_k]
            k_cache_ref[update_mask] = k_to_update
            v_cache_ref[update_mask] = v_to_update
        k_cache_rep = repeat(
            k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k
        )
        v_cache_rep = repeat(
            v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k
        )
        out_ref, _ = attention_ref(
            q_ro,
            k_cache_rep,
            v_cache_rep,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv,
            window_size=window_size,
            learnable_sink=learnable_sink,
            attention_chunk=attention_chunk,
            key_leftpad=cache_leftpad,
        )
        out_pt, _ = attention_ref(
            q_ro,
            k_cache_rep,
            v_cache_rep,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv,
            window_size=window_size,
            learnable_sink=learnable_sink,
            attention_chunk=attention_chunk,
            upcast=False,
            reorder_ops=True,
            key_leftpad=cache_leftpad,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )
        q = q.to(dtype)
        q_unpad = q_unpad.to(dtype) if varlen_q else None
        k_cache = k_cache.to(dtype)
        v_cache = v_cache.to(dtype)
        k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None
        v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None
        k = k.to(dtype) if k is not None else None
        v = v.to(dtype) if v is not None else None
        k_unpad = k_unpad.to(dtype) if k_unpad is not None else None
        v_unpad = v_unpad.to(dtype) if v_unpad is not None else None
        qv = qv.to(dtype) if qv is not None else None
        qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None
        cos = cos.to(dtype) if cos is not None else None
        sin = sin.to(dtype) if sin is not None else None
        k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()
        v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()
        # num_splits_vals = [1, 0]
        # SplitKV is not supported for hdim >= 192
        num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
        # precompute_metadata_vals = [False, True]
        precompute_metadata_vals = [False]
        for num_splits, precompute_metadata in itertools.product(
            num_splits_vals, precompute_metadata_vals
        ):
            # SplitKV not supported on SM90 - skip this iteration
            if IS_SM90 and num_splits > 1:
                continue
            # if precompute_metadata:
            #     scheduler_metadata = get_scheduler_metadata(
            #         batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
            #         cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,
            #         cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,
            #         max_seqlen_k_new=seqlen_new, page_size=page_size,
            #         causal=causal, window_size=window_size, attention_chunk=attention_chunk,
            #         num_splits=num_splits
            #     )
            # else:
            #     scheduler_metadata = None
            scheduler_metadata = None
            # Repeat to test metadata reuse
            for _ in range(1 if not precompute_metadata else 2):
                if page_size is None:
                    k_cache.copy_(k_cache_saved)
                    v_cache.copy_(v_cache_saved)
                else:
                    k_cache_paged.copy_(k_cache_saved)
                    v_cache_paged.copy_(v_cache_saved)
                # out, lse, *rest = flash_attn_with_kvcache(
                out, lse, *rest = flash_attn_varlen_func(
                    q if not varlen_q else q_unpad,
                    k_cache if page_size is None else k_cache_paged,
                    v_cache if page_size is None else v_cache_paged,
                    # k if not new_kv or not varlen_q else k_unpad,
                    # v if not new_kv or not varlen_q else v_unpad,
                    # qv=qv if not varlen_q else qv_unpad,
                    # rotary_cos=cos,
                    # rotary_sin=sin,
                    seqused_k=cache_seqlens,
                    # cache_batch_idx=cache_batch_idx,
                    # cache_leftpad=cache_leftpad,
                    page_table=page_table,
                    cu_seqlens_q=cu_seqlens_q,
                    # cu_seqlens_k_new=cu_seqlens_k_new,
                    # rotary_seqlens=rotary_seqlens,
                    causal=causal,
                    window_size=window_size,
                    learnable_sink=learnable_sink,
                    # attention_chunk=attention_chunk,
                    # rotary_interleaved=rotary_interleaved,
                    # scheduler_metadata=scheduler_metadata,
                    num_splits=num_splits,
                    # return_softmax_lse=True
                )
                if varlen_q:
                    out = output_pad_fn(out)
                if is_fake_mode():
                    # no more flash_attn cutedsl calls for the rest of the loop
                    # skip data-dependent postprocessing
                    continue
                # out = flash_attn_with_kvcache(
                #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
                # )
                # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
                # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
                # m = qk.amax(-1, keepdim=True)
                # s_tmp = torch.exp((qk - m) / math.sqrt(d))
                # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
                # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
                # probs = torch.softmax(qk, dim=-1)
                print(f"Output max diff: {(out - out_ref).abs().max().item()}")
                print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
                print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
                print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
                # breakpoint()

                # If new KV was appended, check that it was written into the cache.
                if new_kv:
                    if page_size is None:
                        k_cache_select = (
                            k_cache.to(dtype_ref)
                            if not has_batch_idx
                            else k_cache.to(dtype_ref)[cache_batch_idx]
                        )
                        v_cache_select = (
                            v_cache.to(dtype_ref)
                            if not has_batch_idx
                            else v_cache.to(dtype_ref)[cache_batch_idx]
                        )
                    else:
                        k_cache_select = rearrange(
                            k_cache_paged.to(dtype_ref)[
                                (
                                    page_table
                                    if not has_batch_idx
                                    else page_table[cache_batch_idx]
                                ).flatten()
                            ],
                            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                            b=batch_size,
                        )[:, :seqlen_k].to(dtype_ref)
                        v_cache_select = rearrange(
                            v_cache_paged.to(dtype_ref)[
                                (
                                    page_table
                                    if not has_batch_idx
                                    else page_table[cache_batch_idx]
                                ).flatten()
                            ],
                            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                            b=batch_size,
                        )[:, :seqlen_k].to(dtype_ref)
                    k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
                    v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
                    if dtype is not torch.float8_e4m3fn:
                        assert torch.equal(v_cache_select, v_cache_ref)
                    else:
                        assert torch.allclose(
                            v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3
                        )
                    # breakpoint()
                    # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
                    if rotary_dim == 0:
                        assert torch.equal(k_cache_select, k_cache_ref)
                    else:
                        # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
                        #     breakpoint()
                        if dtype is not torch.float8_e4m3fn:
                            assert torch.allclose(
                                k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3
                            )
                        else:
                            assert torch.allclose(
                                k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1
                            )
                # Check that FlashAttention's max error is at most `mult` times (2x, or 4x for FP8)
                # the numerical error of a Pytorch implementation, and its mean error at most
                # `mult_mean` times (1.5x, or 3x for FP8).
                mult = 4 if dtype == torch.float8_e4m3fn else 2
                assert (out - out_ref).abs().max().item() <= mult * (
                    out_pt - out_ref
                ).abs().max().item() + 1e-5
                mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5
                assert (out - out_ref).abs().mean().item() <= mult_mean * (
                    out_pt - out_ref
                ).abs().mean().item()


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)])
@maybe_fake_tensor_mode(USE_FAKE_TENSOR)
def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, causal, dtype):
    from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd

    device = "cuda"
    torch.random.manual_seed(42)
    batch_size = 2
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v =
torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) out, lse = _flash_attn_fwd(q, k, v, causal=causal, return_lse=True) dout = torch.randn_like(out) dq_ref, dk_ref, dv_ref = _flash_attn_bwd(q, k, v, out, dout, lse, causal=causal) dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) dq_out, dk_out, dv_out = _flash_attn_bwd( q, k, v, out, dout, lse, causal=causal, dq=dq, dk=dk, dv=dv ) if is_fake_mode(): return assert dq_out is dq assert dk_out is dk assert dv_out is dv assert torch.allclose(dq, dq_ref, atol=1e-5, rtol=1e-5) assert torch.allclose(dk, dk_ref, atol=1e-5, rtol=1e-5) assert torch.allclose(dv, dv_ref, atol=1e-5, rtol=1e-5) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128), (256, 256)]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_lse_grad(seqlen_q, seqlen_k, d, causal, dtype): """Test that gradient flows through the returned LSE tensor.""" device = "cuda" torch.random.manual_seed(42) batch_size = 2 nheads = 4 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) out, lse = flash_attn_func(q, k, v, causal=causal, return_lse=True) if is_fake_mode(): return assert lse is not None assert lse.requires_grad # Compute loss = sum(out * g) + sum(lse * dlse_weight) to test gradient flows through both g = torch.randn_like(out) dlse_weight = torch.randn_like(lse) loss = (out * g).sum() + (lse * dlse_weight).sum() dq, dk, dv = torch.autograd.grad(loss, (q, k, v)) # Compare against reference: manually compute what the gradients should be # Reference: standard attention in float q_ref = 
q.detach().float().requires_grad_() k_ref = k.detach().float().requires_grad_() v_ref = v.detach().float().requires_grad_() # q_ref: (batch, seqlen_q, nheads, d), k_ref: (batch, seqlen_k, nheads, d) -> qk: (batch, nheads, seqlen_q, seqlen_k) qk = torch.einsum("bshd,bthd->bhst", q_ref, k_ref) / (d ** 0.5) if causal: mask = torch.triu(torch.ones(seqlen_q, seqlen_k, device=device, dtype=torch.bool), diagonal=seqlen_k - seqlen_q + 1) qk = qk.masked_fill(mask, float("-inf")) lse_ref = torch.logsumexp(qk, dim=-1) # (batch, nheads, seqlen_q) p = torch.softmax(qk, dim=-1) # v_ref: (batch, seqlen_k, nheads, d) out_ref = torch.einsum("bhst,bthd->bshd", p, v_ref) loss_ref = (out_ref * g.float()).sum() + (lse_ref * dlse_weight.float()).sum() dq_ref, dk_ref, dv_ref = torch.autograd.grad(loss_ref, (q_ref, k_ref, v_ref)) # Use relaxed tolerances since flash_attn operates in bf16 while reference is float32. # The reference is also not a perfect bf16 simulation (it doesn't reorder ops), so # we use a generous tolerance. print(f"dQ max diff: {(dq.float() - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk.float() - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv.float() - dv_ref).abs().max().item()}") # Absolute tolerance: bf16 has ~0.004-0.02 error for these sizes atol = 0.02 assert (dq.float() - dq_ref).abs().max().item() <= atol, "dQ error too large" assert (dk.float() - dk_ref).abs().max().item() <= atol, "dK error too large" assert (dv.float() - dv_ref).abs().max().item() <= atol, "dV error too large" # Also test: gradient with only dLSE (no dO) out2, lse2 = flash_attn_func(q, k, v, causal=causal, return_lse=True) loss_lse_only = (lse2 * dlse_weight).sum() dq2, dk2, dv2 = torch.autograd.grad(loss_lse_only, (q, k, v)) q_ref2 = q.detach().float().requires_grad_() k_ref2 = k.detach().float().requires_grad_() qk2 = torch.einsum("bshd,bthd->bhst", q_ref2, k_ref2) / (d ** 0.5) if causal: qk2 = qk2.masked_fill(mask, float("-inf")) lse_ref2 = torch.logsumexp(qk2, dim=-1) loss_ref2 = (lse_ref2 * 
dlse_weight.float()).sum() 
dq_ref2, dk_ref2 = torch.autograd.grad(loss_ref2, (q_ref2, k_ref2)) print(f"LSE-only dQ max diff: {(dq2.float() - dq_ref2).abs().max().item()}") print(f"LSE-only dK max diff: {(dk2.float() - dk_ref2).abs().max().item()}") # dV should be zero when only LSE gradient flows (LSE doesn't depend on V) print(f"LSE-only dV max: {dv2.abs().max().item()}") assert dv2.abs().max().item() == 0.0, "dV should be zero when loss depends only on LSE" @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize("seqlen_q,seqlen_k", [(128, 128)]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_lse_grad_unused(seqlen_q, seqlen_k, d, causal, dtype): """Test return_lse=True when LSE is returned but not used in the loss. With set_materialize_grads(False), dlse should be None (not a zero tensor), so no extra zeroing kernel is launched. Gradients should match the standard backward (without return_lse). 
""" device = "cuda" torch.random.manual_seed(42) batch_size = 2 nheads = 4 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) g = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) # Case 1: return_lse=False (standard path, lse marked non-differentiable) out1, lse1 = flash_attn_func(q, k, v, causal=causal, return_lse=False) if is_fake_mode(): return dq1, dk1, dv1 = torch.autograd.grad(out1, (q, k, v), g) # Case 2: return_lse=True but lse NOT used in loss (dlse should be None) out2, lse2 = flash_attn_func(q, k, v, causal=causal, return_lse=True) dq2, dk2, dv2 = torch.autograd.grad(out2, (q, k, v), g) # Case 3: return_lse=True and lse IS used in loss out3, lse3 = flash_attn_func(q, k, v, causal=causal, return_lse=True) dlse_weight = torch.randn_like(lse3) loss3 = (out3 * g).sum() + (lse3 * dlse_weight).sum() dq3, dk3, dv3 = torch.autograd.grad(loss3, (q, k, v)) # Cases 1 and 2 should produce identical gradients assert torch.equal(dq1, dq2), "dQ should be identical when LSE is unused" assert torch.equal(dk1, dk2), "dK should be identical when LSE is unused" assert torch.equal(dv1, dv2), "dV should be identical when LSE is unused" # Case 3 should differ from case 1 (LSE gradient adds extra contribution to dQ, dK) assert not torch.equal(dq1, dq3), "dQ should differ when LSE gradient is included" assert not torch.equal(dk1, dk3), "dK should differ when LSE gradient is included" # dV should be the same since LSE doesn't depend on V assert torch.equal(dv1, dv3), "dV should be identical since LSE doesn't depend on V" print("Case 1 vs 2 (unused LSE): dQ diff =", (dq1 - dq2).abs().max().item()) print("Case 1 vs 3 (used LSE): dQ diff =", (dq1 - dq3).abs().max().item()) print("Case 1 vs 3 (used LSE): dK diff =", (dk1 - 
dk3).abs().max().item()) print("Case 1 vs 3 (used LSE): dV diff =", (dv1 - dv3).abs().max().item()) def _generate_block_kvcache( seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref ): num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 k_cache_paged = ( torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref) .to(dtype) .to(dtype_ref) ) v_cache_paged = ( torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref) .to(dtype) .to(dtype_ref) ) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", b=batch_size, ) k_cache = rearrange( k_cache_paged[page_table.flatten()], "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k] v_cache = rearrange( v_cache_paged[page_table.flatten()], "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k] return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks @pytest.mark.parametrize("page_size", [16, 64, 256]) @pytest.mark.parametrize("seqlen_q", [64, 128, 256]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_paged_deepseek(seqlen_q, page_size): """Regression test: paged non-TMA with DeepSeek MLA shape (d=192, dv=128). seqlen_q<=128 triggers q_stage=1, seqlen_q>128 triggers q_stage=2. 
""" if IS_SM90: pytest.skip("paged KV not supported on SM90") device = "cuda" dtype = torch.bfloat16 d, dv = 192, 128 nheads = 16 nheads_kv = 16 torch.random.manual_seed(0) q = torch.randn(seqlen_q, nheads, d, device=device, dtype=dtype) k = torch.randn(seqlen_q, nheads_kv, d, device=device, dtype=dtype) v = torch.randn(seqlen_q, nheads_kv, dv, device=device, dtype=dtype) cu_seqlens = torch.tensor([0, seqlen_q], dtype=torch.int32, device=device) # Non-paged reference out_ref, _ = flash_attn_varlen_func( q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_q, causal=True, ) # Paged num_pages = (seqlen_q + page_size - 1) // page_size k_cache_paged = torch.zeros(num_pages, page_size, nheads_kv, d, device=device, dtype=dtype) v_cache_paged = torch.zeros(num_pages, page_size, nheads_kv, dv, device=device, dtype=dtype) for i in range(seqlen_q): k_cache_paged[i // page_size, i % page_size] = k[i] v_cache_paged[i // page_size, i % page_size] = v[i] page_table = torch.arange(num_pages, dtype=torch.int32, device=device).unsqueeze(0) cache_seqlens = torch.tensor([seqlen_q], dtype=torch.int32, device=device) out, _ = flash_attn_varlen_func( q, k_cache_paged, v_cache_paged, cu_seqlens_q=cu_seqlens, cu_seqlens_k=None, max_seqlen_q=seqlen_q, max_seqlen_k=None, seqused_k=cache_seqlens, page_table=page_table, causal=True, ) if is_fake_mode(): return print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.equal(out, out_ref) @pytest.mark.parametrize("head_dim", [4, 148, 288]) def test_flash_attn_invalid_head_dim(head_dim): device = "cuda" dtype = torch.bfloat16 batch_size, seqlen, nheads = 1, 64, 4 q = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype) k = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype) v = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype) with 
pytest.raises(AssertionError, match=re.escape(f"(head_dim, head_dim_v)=({head_dim}, {head_dim}) is not supported on SM")): flash_attn_func(q, k, v) ================================================ FILE: tests/cute/test_flash_attn_combine.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. import os import pytest import torch from flash_attn.cute.testing import ( maybe_fake_tensor_mode, is_fake_mode, ) from flash_attn.cute.interface import ( flash_attn_combine, ) USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 def attention_combine_ref(out_partial, lse_partial): """ out_partial: (num_splits, batch_size, seqlen, nheads, d) lse_partial: (num_splits, batch_size, seqlen, nheads) """ lse = torch.logsumexp(lse_partial, dim=0) scale = torch.exp(lse_partial - lse) scale = torch.where( torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale ) out = (scale.unsqueeze(-1) * out_partial).sum(0) return out, lse def check_combine_results(out, lse, out_ref, lse_ref, dtype): """Check combine kernel output against reference for a single (seqlen, nheads, d) chunk.""" out_pt = out_ref.to(dtype) print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}, " f"Output max diff: {(out - out_ref).abs().max().item()}, " f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) assert ( (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() ) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float32]) # @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) @pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) # @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 
113, 108, 640, 1024]) # @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) # @pytest.mark.parametrize("seqlen", [15]) @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) # @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) # @pytest.mark.parametrize("num_splits", [11]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_combine(num_splits, seqlen, d, dtype): device = "cuda" # set seed torch.random.manual_seed(1) batch_size = 5 nheads = 16 # batch_size = 1 # nheads = 1 # Create tensors in the expected format: (num_splits, batch_size, seqlen, nheads, d) and (num_splits, batch_size, seqlen, nheads) out_partial = torch.randn( num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32, ).transpose(2, 3)[:num_splits] # To test non-contiguous tensor lse_partial = torch.randn( num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32 ).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor # To test short-circuiting based on num_splits lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") # Test with LSE returned (default behavior) out, lse = flash_attn_combine( out_partial, lse_partial, out_dtype=dtype, return_lse=True ) if is_fake_mode(): return out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) check_combine_results(out, lse, out_ref, lse_ref, dtype) # Test with LSE not returned out_no_lse, lse_no_lse = flash_attn_combine( out_partial, lse_partial, out_dtype=dtype, return_lse=False ) assert lse_no_lse is None, "LSE should be None when return_lse=False" assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), ( "Output should be the same regardless of return_lse" ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("d", [64, 96, 128, 256]) # @pytest.mark.parametrize("d", [128]) 
@pytest.mark.parametrize("seqlen", [1, 32, 113, 256, 1024]) # @pytest.mark.parametrize("seqlen", [113]) @pytest.mark.parametrize("num_splits", [2, 5, 17, 55]) # @pytest.mark.parametrize("num_splits", [5]) @pytest.mark.parametrize( "varlen_mode", ["cu_seqlens", "seqused", "cu_seqlens_seqused"], ) # @pytest.mark.parametrize("varlen_mode", ["cu_seqlens"]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_combine_varlen(varlen_mode, num_splits, seqlen, d, dtype): device = "cuda" torch.random.manual_seed(1) batch_size = 3 nheads = 8 use_cu_seqlens = "cu_seqlens" in varlen_mode use_seqused = "seqused" in varlen_mode # Generate variable-length sequences seqlens = torch.randint(1, seqlen + 1, (batch_size,), device=device, dtype=torch.int32) # For cu_seqlens+seqused mode, seqused < seqlen (kernel processes fewer tokens) seqused_vals = ( torch.clamp( seqlens - torch.randint(0, max(1, seqlen // 4), (batch_size,), device=device, dtype=torch.int32), min=1, ) if use_cu_seqlens and use_seqused else seqlens ) if use_cu_seqlens: # Packed varlen layout: (num_splits, total_q, nheads, d) total_q = seqlens.sum().item() cu_seqlens_q = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) cu_seqlens_q[1:] = torch.cumsum(seqlens, dim=0) out_partial = torch.randn( num_splits * 2, total_q, nheads, d, device=device, dtype=torch.float32, )[:num_splits] # Non-contiguous in splits dim # lse_partial needs stride(-2)==1 (seqlen dim contiguous) lse_partial = torch.randn( num_splits, nheads, total_q, device=device, dtype=torch.float32 ).transpose(-1, -2) lse_partial[num_splits // 2:, :total_q // 3] = -float("inf") out, lse = flash_attn_combine( out_partial, lse_partial, out_dtype=dtype, cu_seqlens=cu_seqlens_q, seqused=seqused_vals if use_seqused else None, return_lse=True, ) if is_fake_mode(): return # Reference on full packed tensor out_ref, lse_ref = attention_combine_ref( out_partial.unsqueeze(1), lse_partial.unsqueeze(1) ) out_ref = out_ref.squeeze(0) lse_ref = 
lse_ref.squeeze(0) # Validate per-batch (only seqused_vals tokens are guaranteed correct) for i in range(batch_size): start = cu_seqlens_q[i].item() sl = seqused_vals[i].item() check_combine_results( out[start:start + sl], lse[start:start + sl], out_ref[start:start + sl], lse_ref[start:start + sl], dtype, ) # Also test return_lse=False out_no_lse, lse_no_lse = flash_attn_combine( out_partial, lse_partial, out_dtype=dtype, cu_seqlens=cu_seqlens_q, seqused=seqused_vals if use_seqused else None, return_lse=False, ) assert lse_no_lse is None # Only compare valid positions (beyond seqused, output is undefined) for i in range(batch_size): start = cu_seqlens_q[i].item() sl = seqused_vals[i].item() assert torch.allclose(out_no_lse[start:start + sl], out[start:start + sl], atol=1e-5, rtol=1e-5) else: # seqused only — batched layout: (num_splits, batch, max_seqlen, nheads, d) max_seqlen = seqlens.max().item() out_partial = torch.randn( num_splits, batch_size, max_seqlen, nheads, d, device=device, dtype=torch.float32, ) # lse_partial needs stride(-2)==1 (seqlen dim contiguous) lse_partial = torch.randn( num_splits, batch_size, nheads, max_seqlen, device=device, dtype=torch.float32, ).transpose(-1, -2) lse_partial[num_splits // 2:, :batch_size // 2] = -float("inf") # Zero out / -inf beyond seqused so reference matches kernel for i in range(batch_size): out_partial[:, i, seqlens[i]:] = 0 lse_partial[:, i, seqlens[i]:] = -float("inf") out, lse = flash_attn_combine( out_partial, lse_partial, out_dtype=dtype, seqused=seqlens, return_lse=True, ) if is_fake_mode(): return out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) # Validate per-batch (only seqused tokens) for i in range(batch_size): sl = seqlens[i].item() check_combine_results( out[i, :sl], lse[i, :sl], out_ref[i, :sl], lse_ref[i, :sl], dtype, ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("d", [64, 128, 
256]) # @pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize("seqlen", [32, 113, 256]) # @pytest.mark.parametrize("seqlen", [113]) @pytest.mark.parametrize("num_splits", [2, 5, 17]) # @pytest.mark.parametrize("num_splits", [5]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_combine_varlen_batch_idx(num_splits, seqlen, d, dtype): """Test that varlen_batch_idx correctly remaps virtual batch indices to real batch indices. varlen_batch_idx maps blockIdx.z (virtual batch) -> real batch index. The kernel reads AND writes using the remapped batch_idx, so with a permutation the output should match running without varlen_batch_idx (each real batch is processed once). We also test with seqused to verify interaction with variable-length sequences. """ device = "cuda" torch.random.manual_seed(42) batch_size = 4 nheads = 8 # Create batched input data out_partial = torch.randn( num_splits, batch_size, seqlen, nheads, d, device=device, dtype=torch.float32, ) lse_partial = torch.randn( num_splits, batch_size, nheads, seqlen, device=device, dtype=torch.float32, ).transpose(-1, -2) # stride(-2)==1 lse_partial[num_splits // 2:, :batch_size // 2] = -float("inf") # Create a permuted batch index mapping: virtual batch -> real batch perm = torch.tensor([2, 0, 3, 1], device=device, dtype=torch.int32) assert perm.shape[0] == batch_size # Also test with seqused to verify interaction with varlen_batch_idx seqused = torch.randint(1, seqlen + 1, (batch_size,), device=device, dtype=torch.int32) # Zero out / -inf beyond seqused so reference matches kernel for i in range(batch_size): out_partial[:, i, seqused[i]:] = 0 lse_partial[:, i, seqused[i]:] = -float("inf") # Run with varlen_batch_idx and seqused via public API out, lse = flash_attn_combine( out_partial, lse_partial, out_dtype=dtype, seqused=seqused, varlen_batch_idx=perm, return_lse=True, ) if is_fake_mode(): return # Reference: standard combine (no remapping needed since perm is a bijection # and both reads and 
writes use the remapped batch_idx) out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) # The kernel reads from input[perm[v]] and writes to output[perm[v]], # so the net result is output[b] = combine(input[b]) for all b. for b in range(batch_size): sl = seqused[b].item() check_combine_results( out[b, :sl], lse[b, :sl], out_ref[b, :sl], lse_ref[b, :sl], dtype, ) ================================================ FILE: tests/cute/test_flash_attn_fast.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # Fast subset of test_flash_attn.py for quick iteration. # Covers: causal/noncausal, varlen/not varlen, MHA/GQA, split/not split, fwd+bwd. import os import random import pytest import torch from einops import rearrange from flash_attn.cute.testing import ( attention_ref, generate_random_padding_mask, generate_qkv, maybe_fake_tensor_mode, is_fake_mode, ) from flash_attn.cute.interface import ( flash_attn_func, flash_attn_varlen_func, flash_attn_combine, ) USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 IS_SM90 = torch.cuda.get_device_capability()[0] == 9 # --------------------------------------------------------------------------- # Forward + backward (non-varlen) # --------------------------------------------------------------------------- @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "gqa"]) @pytest.mark.parametrize("num_splits", [1, 3]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (128, 128), (256, 256), (113, 203), (1024, 1024), ], ) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_output(seqlen_q, seqlen_k, d, causal, num_splits, mha_type, dtype): if IS_SM90 and num_splits > 1: pytest.skip("SM90 fwd doesn't support num_splits > 1") device = "cuda" torch.random.manual_seed(0)
random.seed(0) torch.cuda.empty_cache() batch_size = 4 nheads = 6 nheads_kv = nheads if mha_type == "mha" else 3 q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype).to(dtype).requires_grad_() k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() q = q_ref.detach().to(dtype).requires_grad_() k = k_ref.detach().to(dtype).requires_grad_() v = v_ref.detach().to(dtype).requires_grad_() out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal) out_pt, _ = attention_ref( q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True, ) out, lse = flash_attn_func(q, k, v, causal=causal, num_splits=num_splits) if is_fake_mode(): return fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + fwd_atol # Backward (only for num_splits == 1, d <= 128, and not causal with seqlen_k < seqlen_q) can_bwd = ( num_splits == 1 and d <= 128 and not (causal and seqlen_k < seqlen_q) ) if IS_SM90 and d == 64 and not causal: can_bwd = False # SM90 d=64 non-causal xfail if not can_bwd: return g = torch.randn_like(out) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + dq_atol assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + dk_atol assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + dv_atol # 
--------------------------------------------------------------------------- # Forward + backward (varlen with cu_seqlens) # --------------------------------------------------------------------------- @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "gqa"]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize("seqlen", [128, 256, 1024]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_varlen_output(seqlen, d, causal, mha_type, dtype): """Varlen test with cu_seqlens (packed): equal seqlens so we can compare with non-varlen ref.""" device = "cuda" seed = seqlen + d + int(causal) * 2 torch.random.manual_seed(seed) random.seed(seed) batch_size = 9 nheads = 6 nheads_kv = nheads if mha_type == "mha" else 3 q_ref = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype).to(dtype).requires_grad_() k_ref = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() v_ref = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype).to(dtype).requires_grad_() out_ref, _ = attention_ref(q_ref, k_ref, v_ref, None, None, causal=causal) out_pt, _ = attention_ref( q_ref, k_ref, v_ref, None, None, causal=causal, upcast=False, reorder_ops=True, ) cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen, device=device, dtype=torch.int32) q_varlen = rearrange(q_ref.detach(), "b s h d -> (b s) h d").requires_grad_() k_varlen = rearrange(k_ref.detach(), "b s h d -> (b s) h d").requires_grad_() v_varlen = rearrange(v_ref.detach(), "b s h d -> (b s) h d").requires_grad_() out_varlen, lse = flash_attn_varlen_func( q_varlen, k_varlen, v_varlen, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, ) if is_fake_mode(): return out_reshaped = rearrange(out_varlen, "(b s) h d -> b s h d", b=batch_size) fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() assert (out_reshaped - 
out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + fwd_atol # Backward can_bwd = d <= 128 if not can_bwd: return g = torch.randn_like(out_varlen) dq_varlen, dk_varlen, dv_varlen = torch.autograd.grad(out_varlen, (q_varlen, k_varlen, v_varlen), g) assert dq_varlen.isfinite().all(), "dq contains non-finite values" assert dk_varlen.isfinite().all(), "dk contains non-finite values" assert dv_varlen.isfinite().all(), "dv contains non-finite values" assert dq_varlen.abs().max().item() > 0, "dq is all zeros" assert dk_varlen.abs().max().item() > 0, "dk is all zeros" assert dv_varlen.abs().max().item() > 0, "dv is all zeros" # --------------------------------------------------------------------------- # Forward + backward (varlen with padding masks — all unpad combinations) # Covers 4 compile-key-distinct paths: # (unpad_q, unpad_kv) = (T,T): cu_seqlens for both Q and K # (unpad_q, unpad_kv) = (F,F): seqused for both Q and K # (unpad_q, unpad_kv) = (T,F): cu_seqlens_q + seqused_k # (unpad_q, unpad_kv) = (F,T): seqused_q + cu_seqlens_k # --------------------------------------------------------------------------- @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "gqa"]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize("seqlen", [128, 256]) @pytest.mark.parametrize( "unpad_q,unpad_kv", [(True, True), (False, False), (True, False), (False, True)], ) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_varlen_unpad_output(seqlen, d, causal, mha_type, unpad_q, unpad_kv, dtype): """Varlen test with all 4 (unpad_q, unpad_kv) combos: cu_seqlens vs seqused.""" device = "cuda" seed = seqlen + d + int(causal) * 2 + int(unpad_q) * 7 + int(unpad_kv) * 13 torch.random.manual_seed(seed) random.seed(seed) batch_size = 9 nheads = 6 nheads_kv = nheads if mha_type == "mha" else 3 q = torch.randn(batch_size, seqlen, nheads, d, device=device, 
dtype=dtype) k = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype) v = torch.randn(batch_size, seqlen, nheads_kv, d, device=device, dtype=dtype) q_ref = q.detach().to(dtype).requires_grad_() k_ref = k.detach().to(dtype).requires_grad_() v_ref = v.detach().to(dtype).requires_grad_() query_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random") key_padding_mask = query_padding_mask if causal else generate_random_padding_mask( seqlen, batch_size, device, mode="random" ) ( q_unpad_t, k_unpad_t, v_unpad_t, _qv_unpad, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, q_padded, k_padded, v_padded, _qv_padded, output_pad_fn, dq_pad_fn, dk_pad_fn, ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask) out_ref, _ = attention_ref( q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, ) out_pt, _ = attention_ref( q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, upcast=False, reorder_ops=True, ) # Select Q input: packed (unpad) or padded (seqused) if unpad_q: q_in = q_unpad_t.detach().to(dtype).requires_grad_() else: q_in = q.detach().to(dtype).requires_grad_() # Select KV input: packed (unpad) or padded (seqused) if unpad_kv: k_in = k_unpad_t.detach().to(dtype).requires_grad_() v_in = v_unpad_t.detach().to(dtype).requires_grad_() else: k_in = k.detach().to(dtype).requires_grad_() v_in = v.detach().to(dtype).requires_grad_() out_unpad, lse = flash_attn_varlen_func( q_in, k_in, v_in, cu_seqlens_q=cu_seqlens_q if unpad_q else None, cu_seqlens_k=cu_seqlens_k if unpad_kv else None, max_seqlen_q=seqlen, max_seqlen_k=seqlen, seqused_q=seqused_q if not unpad_q else None, seqused_k=seqused_k if not unpad_kv else None, causal=causal, ) if is_fake_mode(): return # Reshape output to (batch, seqlen, nheads, d) for comparison out = output_pad_fn(out_unpad) if unpad_q else out_unpad # Mask out padding positions — kernel output at padding positions is 
undefined q_mask = rearrange(query_padding_mask, "b s -> b s 1 1") out_masked = out.clone().masked_fill_(~q_mask, 0.0) out_ref_masked = out_ref.clone().masked_fill_(~q_mask, 0.0) out_pt_masked = out_pt.clone().masked_fill_(~q_mask, 0.0) fwd_atol = 2 * (out_ref_masked + 0.3 - 0.3 - out_ref_masked).abs().max().item() assert (out_masked - out_ref_masked).abs().max().item() <= 2 * (out_pt_masked - out_ref_masked).abs().max().item() + fwd_atol # Backward (original test skips all SM90 varlen backward) can_bwd = d <= 128 and not IS_SM90 if not can_bwd: return g = torch.randn_like(out_unpad) dq_in, dk_in, dv_in = torch.autograd.grad(out_unpad, (q_in, k_in, v_in), g) assert dq_in.isfinite().all(), "dq contains non-finite values" assert dk_in.isfinite().all(), "dk contains non-finite values" assert dv_in.isfinite().all(), "dv contains non-finite values" assert dq_in.abs().max().item() > 0, "dq is all zeros" assert dk_in.abs().max().item() > 0, "dk is all zeros" assert dv_in.abs().max().item() > 0, "dv is all zeros" # --------------------------------------------------------------------------- # Combine kernel # --------------------------------------------------------------------------- def attention_combine_ref(out_partial, lse_partial): lse = torch.logsumexp(lse_partial, dim=0) scale = torch.exp(lse_partial - lse) scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) out = (scale.unsqueeze(-1) * out_partial).sum(0) return out, lse @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("d", [64, 128]) @pytest.mark.parametrize("seqlen", [32, 256]) @pytest.mark.parametrize("num_splits", [2, 5, 17]) @maybe_fake_tensor_mode(USE_FAKE_TENSOR) def test_flash_attn_combine(num_splits, seqlen, d, dtype): device = "cuda" torch.random.manual_seed(1) batch_size = 3 nheads = 8 # out_partial: (num_splits, batch, seqlen, nheads, d) with stride(-1)==1 # lse_partial: (num_splits, batch, seqlen, nheads) with stride(-2)==1 (seqlen 
contiguous) out_partial = torch.randn( num_splits, batch_size, seqlen, nheads, d, device=device, dtype=torch.float32, ) lse_partial = torch.randn( num_splits, batch_size, nheads, seqlen, device=device, dtype=torch.float32, ).transpose(-1, -2) lse_partial[num_splits // 2 :, : batch_size // 3] = -float("inf") out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype, return_lse=True) if is_fake_mode(): return out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) out_pt = out_ref.to(dtype) assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) ================================================ FILE: tests/cute/test_flash_attn_race_condition.py ================================================ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. import math import itertools import os import pytest import torch from einops import rearrange, repeat try: from flash_attn.layers.rotary import apply_rotary_emb except ImportError: apply_rotary_emb = None from flash_attn.cute.testing import ( attention_ref, generate_qkv, generate_random_padding_mask, pad_input, unpad_input, ) from flash_attn.cute.interface import ( flash_attn_func, flash_attn_varlen_func, flash_attn_combine, _flash_attn_bwd, ) DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" IS_SM90 = torch.cuda.get_device_capability()[0] == 9 INCREASED_TRIALS = False # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["gqa"]) # @pytest.mark.parametrize("has_learnable_sink", [False, True]) @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) 
@pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [True]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) # @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) @pytest.mark.parametrize("local_enum", [0, 1]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("d", [64, 128]) # @pytest.mark.parametrize("d", [128, 192]) @pytest.mark.parametrize("d", [64, 128, 192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (4224, 4224), (2000, 4000), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( seqlen_q, seqlen_k, d, causal, local_enum, softcap, deterministic, has_qv, has_learnable_sink, mha_type, dtype, ): local = local_enum > 0 if local and causal: pytest.skip() is_sm90 = torch.cuda.get_device_capability()[0] == 9 if is_sm90 and d == 192: pytest.xfail("headdim 192 not supported on sm90") device = "cuda" # set seed torch.random.manual_seed(0) torch.cuda.empty_cache() torch.cuda.synchronize() batch_size = 9 if seqlen_k <= 2048 else 2 # batch_size = 1 nheads = 6 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) dv_vals = [128] if d == 192 else [d] if dtype == torch.float8_e4m3fn: dv_vals = [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): q_ref = torch.randn( batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref ) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. 
q_ref = q_ref * softcap / 4 q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = ( torch.randn( batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref ) .to(dtype) .to(dtype_ref) .requires_grad_() ) v_ref = ( torch.randn( batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref ) .to(dtype) .to(dtype_ref) .requires_grad_() ) if has_qv: qv_ref = ( torch.randn( batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref ) .to(dtype) .to(dtype_ref) ) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = ( (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() ) if local_enum == 2: window_size = (None, -window_size[1]) elif local_enum == 3: window_size = (-window_size[0], None) if local: print("window size = ", window_size) # window_size = (-1, -1) if not local else (16, 0) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: learnable_sink = None if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [ torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3) ] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None out_ref, attn_ref = attention_ref( q_ref, k_ref, v_ref, None, None, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, softcap=softcap, ) out_pt, attn_pt = attention_ref( q_ref, k_ref, v_ref, None, None, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, intermediate_dtype=dtype if dtype == 
torch.float8_e4m3fn else None, ) # k_extended = repeat(k_ref, "b s h d -> b s (h k) d", k=nheads // nheads_kv) # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_extended).float() # # if qv is not None: # # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() # m = qk.amax(-1, keepdim=True) # s_tmp = torch.exp((qk - m) / math.sqrt(d)) # exp_sum = s_tmp.sum(-1) # # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) # # lse_ref = torch.logsumexp(qk, dim=-1) # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # num_splits_vals = [1, 3] # pack_gqa_vals = [False, True, None] # SplitKV is not supported for hdim >= 192 pack_gqa_vals = [False] # num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] num_splits_vals = [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out, lse = flash_attn_func( q, k, v, causal=causal, # qv=qv, # q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, # attention_chunk=attention_chunk, softcap=softcap, learnable_sink=learnable_sink, pack_gqa=pack_gqa, num_splits=num_splits, deterministic=deterministic, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") # if not causal: # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. 
assert (out - out_ref).abs().max().item() <= rtol * ( out_pt - out_ref ).abs().max().item() + fwd_atol if ( dtype != torch.float8_e4m3fn and not has_qv and not dv > 256 and not attention_chunk != 0 and softcap == 0.0 and ((dv == d and d <= 128) or (d == 192 and dv == 128)) and learnable_sink is None # and False ): if IS_SM90 and mha_type != "mha": pytest.xfail("SM90 backward: GQA/MQA has tensor layout issue (qhead_per_kvhead > 1)") if IS_SM90 and local: pytest.xfail("SM90 backward: local attention not supported yet") g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") # assert (softmax_d - do_o).abs().max().item() <= 1e-5 # assert dq_accum.abs().max().item() == 0.0 # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) # P = torch.softmax(qk, -1) # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) # breakpoint() # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_ref, dk_ref, dv_ref = torch.autograd.grad( out_ref, (q_ref, k_ref, v_ref), g ) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK 
Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 ) assert (dq - dq_ref).abs().max().item() <= rtol * ( dq_pt - dq_ref ).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 ) assert (dk - dk_ref).abs().max().item() <= rtol * ( dk_pt - dk_ref ).abs().max().item() + dk_atol dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 ) assert (dv - dv_ref).abs().max().item() <= rtol * ( dv_pt - dv_ref ).abs().max().item() + dv_atol num_iters = 10_000 if INCREASED_TRIALS else 1000 for i in range(num_iters): dq2, dk2, dv2, = _flash_attn_bwd( q, k, v, out, g, lse, causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], deterministic=True, ) diff_dq = (dq - dq2).abs() max_idx = diff_dq.argmax() print(f"dQ max diff: {diff_dq.max().item()}") print(f" at index {max_idx.item()}: dQ={dq.flatten()[max_idx].item()}, dQ2={dq2.flatten()[max_idx].item()}") diff_dk = (dk - dk2).abs() max_idx = diff_dk.argmax() print(f"dK max diff: {diff_dk.max().item()}") print(f" at index {max_idx.item()}: dK={dk.flatten()[max_idx].item()}, dK2={dk2.flatten()[max_idx].item()}") diff_dv = (dv - dv2).abs() max_idx = diff_dv.argmax() print(f"dV max diff: {diff_dv.max().item()}") print(f" at index {max_idx.item()}: dV={dv.flatten()[max_idx].item()}, dV2={dv2.flatten()[max_idx].item()}") # print(f"dQ max diff with myself: {(dq - dq2).abs().max().item()}") # print(f"dK max diff with myself: {(dk - dk2).abs().max().item()}") # print(f"dV max diff with myself: {(dv - dv2).abs().max().item()}") # print(f"dQ mean diff with myself: {(dq - dq2).abs().mean().item()}") # print(f"dK mean diff with myself: {(dk - dk2).abs().mean().item()}") # print(f"dV mean diff with myself: {(dv - 
dv2).abs().mean().item()}") assert torch.equal(dq, dq2) assert torch.equal(dk, dk2) assert torch.equal(dv, dv2) print(f"✅ Iteration {i} passed!") # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["gqa"]) # @pytest.mark.parametrize("has_learnable_sink", [False, True]) @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("has_qv", [False, True]) @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [True]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) # @pytest.mark.parametrize("local_enum", [0, 1, 2, 3]) @pytest.mark.parametrize("local_enum", [0, 1]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @pytest.mark.parametrize("add_unused_qkv", [False]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", [64, 128, 192]) # @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1024, 1024), (2048, 2048), ], ) @pytest.mark.parametrize("varlen_mode", ["random", "third", "full"]) # @pytest.mark.parametrize("varlen_mode", ["random"]) @pytest.mark.parametrize( "zero_lengths_q, zero_lengths_k", [ (False, False), (True, False), (False, True), (True, True), ], ) def test_flash_attn_varlen_output( seqlen_q, seqlen_k, d, add_unused_qkv, causal, local_enum, softcap, 
deterministic, has_qv, has_learnable_sink, mha_type, dtype, varlen_mode, zero_lengths_q, zero_lengths_k, ): local = local_enum > 0 if local and causal: pytest.skip() is_sm90 = torch.cuda.get_device_capability()[0] == 9 if is_sm90 and local: pytest.xfail("bwd local attention not supported on sm90") if is_sm90 and d == 192: pytest.xfail("headdim 192 not supported on sm90") if ( causal or local ): # Right now reference only supports causal attention with seqlen_k == seqlen_q seqlen_k = seqlen_q device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) batch_size = 49 if seqlen_q <= 1024 else 7 nheads = 6 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype # dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) # dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d]) dv_vals = [128] if d == 192 else [d] # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0] attention_chunk_vals = [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): q_ref = torch.randn( batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref ) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. 
q_ref = (q_ref * softcap / 4).detach().requires_grad_() q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = ( torch.randn( batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref ) .to(dtype) .to(dtype_ref) .requires_grad_() ) v_ref = ( torch.randn( batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref ) .to(dtype) .to(dtype_ref) .requires_grad_() ) if has_qv: qv_ref = ( torch.randn( batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref ) .to(dtype) .to(dtype_ref) ) else: qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = ( (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist() ) if local_enum == 2: window_size = (None, window_size[1]) elif local_enum == 3: window_size = (window_size[0], None) if local: print("window size = ", window_size) if has_learnable_sink: learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device) else: learnable_sink = None if dtype == torch.float8_e4m3fn: q_descale, k_descale, v_descale = [ torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3) ] else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] qv = qv_ref.detach() if has_qv else None query_padding_mask = generate_random_padding_mask( seqlen_q, batch_size, device, mode=varlen_mode, zero_lengths=zero_lengths_q, ) key_padding_mask = generate_random_padding_mask( seqlen_k, batch_size, device, mode=varlen_mode, zero_lengths=zero_lengths_k, ) def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): if add_unused: another_mask = generate_random_padding_mask(max_seq_len, bs, device) attn_mask = torch.logical_and(padding_mask, another_mask) unused_mask = torch.logical_xor( torch.logical_or(padding_mask, another_mask), attn_mask ) else: attn_mask = padding_mask unused_mask = None return attn_mask, unused_mask query_padding_mask, 
query_unused_mask = _gen_unused_masks( query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device ) # query_padding_mask[:] = True # query_unused_mask = None key_padding_mask, key_unused_mask = _gen_unused_masks( key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device ) if causal or local: key_padding_mask = query_padding_mask ( q_unpad, k_unpad, v_unpad, qv_unpad, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, q, k, v, qv, output_pad_fn, dq_pad_fn, dk_pad_fn, ) = generate_qkv( q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask, ) print("cu_seqlens_q = ", cu_seqlens_q) print("cu_seqlens_k = ", cu_seqlens_k) q_unpad, k_unpad, v_unpad = [ x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) ] out_ref, attn_ref = attention_ref( q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, softcap=softcap, ) out_pt, attn_pt = attention_ref( q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, causal=causal, qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, attention_chunk=attention_chunk, learnable_sink=learnable_sink, softcap=softcap, upcast=False, reorder_ops=True, intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") if query_unused_mask is not None: q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 out_unpad, lse = flash_attn_varlen_func( q_unpad, k_unpad, v_unpad, 
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, # max_seqlen_k, # seqused_q=seqused_q, # seqused_k=seqused_k, max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k, causal=causal, # qv=qv_unpad, # q_descale=q_descale, # k_descale=k_descale, v_descale=v_descale, window_size=window_size, # attention_chunk=attention_chunk, learnable_sink=learnable_sink, softcap=softcap, num_splits=1, pack_gqa=False, deterministic=deterministic, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: out.masked_fill_(q_zero_masking, 0.0) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") # if not causal: # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") # breakpoint() # Check that FlashAttention's numerical error is at most rtol times (2x, or 3x with softcap) the numerical error # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= rtol * ( out_pt - out_ref ).abs().max().item() + fwd_atol if ( dtype != torch.float8_e4m3fn and not has_qv and not dv > 256 and not attention_chunk != 0 and ((dv == d and d <= 128) or (d == 192 and dv == 128)) and not has_learnable_sink and not is_sm90 # and False ): g_unpad = torch.randn_like(out_unpad) # do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) # import flash_attn_3_cuda # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( # g_unpad, # q_unpad, # k_unpad, # v_unpad, # out_unpad, # lse, # None, # None, # None, # cu_seqlens_q, # cu_seqlens_k, # None, None, # max_seqlen_q, # max_seqlen_k, # d ** (-0.5), # causal, # window_size[0], window_size[1], # softcap, # deterministic, # 0, # sm_margin # ) dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad ) dq = dq_pad_fn(dq_unpad) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) if key_unused_mask is not None: k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") dk.masked_fill_(k_zero_masking, 0.0)
dv.masked_fill_(k_zero_masking, 0.0) if query_unused_mask is not None: dq.masked_fill_(q_zero_masking, 0.0) # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") # assert (softmax_d - do_o).abs().max().item() <= 1e-5 # assert dq_accum.abs().max().item() == 0.0 g = output_pad_fn(g_unpad) # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) # P = torch.softmax(qk, -1) # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) dq_ref, dk_ref, dv_ref = torch.autograd.grad( out_ref, (q_ref, k_ref, v_ref), g ) dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 ) assert (dq - dq_ref).abs().max().item() <= rtol * ( dq_pt - dq_ref ).abs().max().item() + dq_atol dk_atol = 2 
* (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 ) assert (dk - dk_ref).abs().max().item() <= rtol * ( dk_pt - dk_ref ).abs().max().item() + dk_atol dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( 0 if softcap == 0 else 3e-4 ) assert (dv - dv_ref).abs().max().item() <= rtol * ( dv_pt - dv_ref ).abs().max().item() + dv_atol num_iters = 10_000 if INCREASED_TRIALS else 1000 for i in range(num_iters): dq_unpad2, dk_unpad2, dv_unpad2 = _flash_attn_bwd( q_unpad, k_unpad, v_unpad, out_unpad, g_unpad, lse, causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], deterministic=True, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=seqlen_q, max_seqlen_k=seqlen_k, ) diff_dq = (dq_unpad - dq_unpad2).abs() max_idx = diff_dq.argmax() if i % 100 == 0: print(f"dQ max diff: {diff_dq.max().item()}") print(f" at index {max_idx.item()}: dQ={dq_unpad.flatten()[max_idx].item()}, dQ2={dq_unpad2.flatten()[max_idx].item()}") diff_dk = (dk_unpad - dk_unpad2).abs() max_idx = diff_dk.argmax() if i % 100 == 0: print(f"dK max diff: {diff_dk.max().item()}") print(f" at index {max_idx.item()}: dK={dk_unpad.flatten()[max_idx].item()}, dK2={dk_unpad2.flatten()[max_idx].item()}") diff_dv = (dv_unpad - dv_unpad2).abs() max_idx = diff_dv.argmax() if i % 100 == 0: print(f"dV max diff: {diff_dv.max().item()}") print(f" at index {max_idx.item()}: dV={dv_unpad.flatten()[max_idx].item()}, dV2={dv_unpad2.flatten()[max_idx].item()}") assert torch.equal(dq_unpad, dq_unpad2) assert torch.equal(dk_unpad, dk_unpad2) assert torch.equal(dv_unpad, dv_unpad2) if i % 100 == 0: print(f"✅ Iteration {i} passed!") ================================================ FILE: tests/cute/test_flash_attn_varlen.py ================================================ from typing import Optional import pytest import torch import torch.nn.functional as F from flash_attn.cute import flash_attn_varlen_func @pytest.mark.parametrize("B", [1, 7, 
20]) @pytest.mark.parametrize("H", [1, 4, 6]) @pytest.mark.parametrize("D", [64, 128]) @pytest.mark.parametrize("min_seq_len", [1, 32, 128]) @pytest.mark.parametrize("max_seq_len", [8, 64, 2048]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("softmax_scale", [None, 0.1]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) def test_varlen( B, H, D, min_seq_len, max_seq_len, causal, softmax_scale, dtype, mha_type, ): if min_seq_len > max_seq_len: pytest.skip("Skipping min_seq_len > max_seq_len") q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k = generate_varlen_args( batch_size=B, n_heads=H, d_head=D, min_len=min_seq_len, max_len=max_seq_len, mha_type=mha_type, dtype=dtype ) ok = check_varlen_vs_torch_flash( q, k, v, cu_seqlens_q, cu_seqlens_k, total_q=total_q, total_k=total_k, softmax_scale=softmax_scale, causal=causal, mha_type=mha_type, ) assert ok def check_varlen_vs_torch_flash( q, k, v, cu_seqlens_q=None, cu_seqlens_k=None, seqused_q=None, seqused_k=None, total_q=None, total_k=None, softmax_scale=None, causal=True, mha_type='mha', softcap=0.0, atol=3e-2, rtol=3e-2, ): assert q.requires_grad and k.requires_grad and v.requires_grad, "Set requires_grad=True on inputs" def clone_like(t): c = t.clone().detach().requires_grad_(True) return c q_fa, k_fa, v_fa = map(clone_like, (q, k, v)) q_t, k_t, v_t = map(clone_like, (q, k, v)) if cu_seqlens_q is not None: cu_seqlens_q_fa = cu_seqlens_q.clone() cu_seqlens_q_t = cu_seqlens_q.clone() else: cu_seqlens_q_fa = None cu_seqlens_q_t = None if cu_seqlens_k is not None: cu_seqlens_k_fa = cu_seqlens_k.clone() cu_seqlens_k_t = cu_seqlens_k.clone() else: cu_seqlens_k_fa = None cu_seqlens_k_t = None out_fa, lse_fa = flash_attn_varlen_func( q_fa, k_fa, v_fa, cu_seqlens_q=cu_seqlens_q_fa, cu_seqlens_k=cu_seqlens_k_fa, seqused_q=seqused_q, seqused_k=seqused_k, softmax_scale=(1.0 / q.shape[-1]**0.5) if softmax_scale is None 
        else softmax_scale,
        causal=causal,
        window_size=(None, None),
        learnable_sink=None,
        softcap=softcap,
        pack_gqa=None,
    )

    out_t = torch_flash_ref(
        q_t,
        k_t,
        v_t,
        cu_seqlens_q=cu_seqlens_q_t,
        cu_seqlens_k=cu_seqlens_k_t,
        seqused_q=seqused_q,
        seqused_k=seqused_k,
        total_q=total_q,
        total_k=total_k,
        softmax_scale=softmax_scale,
        causal=causal,
        mha_type=mha_type,
    )

    ok_fwd = torch.allclose(out_fa.float(), out_t.float(), atol=atol, rtol=rtol)
    if not ok_fwd:
        return False

    # Use the same upstream gradient to compare backward paths
    grad_out = torch.randn_like(out_fa)
    grad_fa = clone_like(grad_out)
    grad_t = clone_like(grad_out)

    # Cute bwd
    out_fa.backward(grad_fa, retain_graph=False)
    dq_fa, dk_fa, dv_fa = q_fa.grad, k_fa.grad, v_fa.grad

    # Ref bwd
    out_t.backward(grad_t, retain_graph=False)
    dq_t, dk_t, dv_t = q_t.grad, k_t.grad, v_t.grad

    # mean_ok_q = _stats("dQ", dq_fa, dq_t, atol=atol, rtol=rtol)
    # mean_ok_k = _stats("dK", dk_fa, dk_t, atol=atol, rtol=rtol)
    # mean_ok_v = _stats("dV", dv_fa, dv_t, atol=atol, rtol=rtol)
    # return mean_ok_q and mean_ok_k and mean_ok_v

    ok_q = torch.allclose(dq_fa.float(), dq_t.float(), atol=atol, rtol=rtol)
    ok_k = torch.allclose(dk_fa.float(), dk_t.float(), atol=atol, rtol=rtol)
    ok_v = torch.allclose(dv_fa.float(), dv_t.float(), atol=atol, rtol=rtol)
    # print(f"Close? dQ={ok_q}, dK={ok_k}, dV={ok_v}")

    return ok_q and ok_k and ok_v


def generate_varlen_args(
    batch_size=8,
    n_heads=16,
    d_head=128,
    min_len=32,
    max_len=64,
    mha_type="mha",
    dtype=torch.bfloat16,
):
    torch.manual_seed(0)
    device = "cuda"

    assert mha_type in ["mha", "mqa", "gqa"]

    lens_q = torch.randint(low=min_len, high=max_len + 1, size=(batch_size,))
    lens_k = lens_q.clone()

    cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32), lens_q.cumsum(0)])
    cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32), lens_k.cumsum(0)])

    total_q = cu_seqlens_q[-1]
    total_k = cu_seqlens_k[-1]

    cu_seqlens_q = cu_seqlens_q.contiguous().to(dtype=torch.int32, device=device)
    cu_seqlens_k = cu_seqlens_k.contiguous().to(dtype=torch.int32, device=device)

    if mha_type == "gqa":
        H = 3 * n_heads
        H_kv = n_heads
    elif mha_type == "mha":
        H = H_kv = n_heads
    else:  # MQA
        H = n_heads
        H_kv = 1

    d_head_v = d_head

    q = torch.randn(total_q, H, d_head, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(total_k, H_kv, d_head, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(total_k, H_kv, d_head_v, device=device, dtype=dtype, requires_grad=True)

    return q, k, v, cu_seqlens_q, cu_seqlens_k, total_q, total_k


# Simple for loop over batch dim implementation
def torch_flash_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor = None,
    cu_seqlens_k: torch.Tensor = None,
    total_q: int = 0,
    total_k: int = 0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    **kwargs
):
    """
    q: (total_q, H, d) if cu_seqlens_q is not None, otherwise (B, L, H, d)
    k: (total_k, H_kv, d) if cu_seqlens_k is not None, otherwise (B, L, H_kv, d)
    v: (total_k, H_kv, d_v) if cu_seqlens_k is not None, otherwise (B, L, H_kv, d_v)
    cu_seqlens_q: (B+1,) int32, cumulative
    cu_seqlens_k: (B+1,) int32, cumulative
    seqused_q, seqused_k: (B,) int32; accepted through **kwargs but ignored by this reference
    Returns:
        out packed like q: (total_q, H, d_v) if cu_seqlens_q is not None, otherwise (B, L, H, d_v)
    """
    if cu_seqlens_q is not None:
        assert cu_seqlens_q.dim() == 1
        assert total_q == q.shape[0]
        assert q.dim() == 3
        H = q.shape[1]
        B = cu_seqlens_q.shape[0] - 1
    else:
        assert q.dim() == 4
        H = q.shape[2]
        B = q.shape[0]

    if cu_seqlens_k is not None:
        assert cu_seqlens_k.dim() == 1
        assert total_k == k.shape[0] == v.shape[0]
        assert k.dim() == v.dim() == 3
        H_kv = k.shape[1]
        B_kv = cu_seqlens_k.shape[0] - 1
    else:
        assert k.dim() == v.dim() == 4
        assert k.shape[0] == v.shape[0]
        H_kv = k.shape[2]
        B_kv = k.shape[0]

    d = q.shape[-1]
    d_v = v.shape[-1]

    assert H_kv == v.shape[-2]
    assert d == k.shape[-1]
    assert B == B_kv

    assert q.device == k.device == v.device
    assert q.is_floating_point() and k.is_floating_point() and v.is_floating_point()

    device = q.device
    dtype = q.dtype

    # Host copies of the offsets; None selects the padded (B, L, ...) layout below
    hcseq_q = cu_seqlens_q.to(device='cpu') if cu_seqlens_q is not None else None
    hcseq_k = cu_seqlens_k.to(device='cpu') if cu_seqlens_k is not None else None

    outs = []
    for b in range(B):
        if hcseq_q is not None:
            q_start, q_end = int(hcseq_q[b]), int(hcseq_q[b+1])
            qb = q[q_start:q_end]
        else:
            qb = q[b]

        if hcseq_k is not None:
            k_start, k_end = int(hcseq_k[b]), int(hcseq_k[b+1])
            kb = k[k_start:k_end]
            vb = v[k_start:k_end]
        else:
            kb = k[b]
            vb = v[b]

        qb = qb.permute(1, 0, 2).unsqueeze(0)
        kb = kb.permute(1, 0, 2).unsqueeze(0)
        vb = vb.permute(1, 0, 2).unsqueeze(0)

        ob = F.scaled_dot_product_attention(
            qb, kb, vb,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=causal,
            scale=softmax_scale,
            enable_gqa=H_kv != H
        )

        ob = ob.squeeze(0).permute(1, 0, 2).contiguous()
        outs.append(ob)

    if cu_seqlens_q is not None:
        out = torch.cat(outs, dim=0).to(device=device, dtype=dtype)
    else:
        out = torch.stack(outs, dim=0).to(device=device, dtype=dtype)
    return out


@torch.no_grad()
def _stats(name, a, b, atol, rtol):
    diff = (a - b).float()
    mean_abs = diff.abs().mean().item()
    mean_rel = diff.abs().mean().item() / b.abs().clamp_min(1e-6).mean().item()
    print(f"{name}: mean_abs={mean_abs:.4e}, mean_rel={mean_rel:.4e}, sum_fa={a.sum()}, sum_ref={b.sum()}")
    return mean_abs < atol and mean_rel < rtol


================================================
FILE: tests/cute/test_mask_mod.py
================================================
# mask mod test script
# REFACTORED to use _flash_attn_fwd as the kernel entrypoint
#
# Test Organization:
# - test_static_masks: Fast tests for masks that don't need per-seqlen compilation
#   (block_diagonal, mini_causal) with comprehensive seqlen coverage
# - test_parameterized_masks: Slower tests with reduced seqlen coverage; causal,
#   block_causal and sliding_window require recompilation per seqlen pair, and
#   document is slower to check
#
# Usage:
#   pytest test_mask_mod.py::test_static_masks         # Run only fast tests
#   pytest test_mask_mod.py::test_parameterized_masks  # Run only slow tests
#   pytest test_mask_mod.py                            # Run all tests

import math
from unittest import mock

import pytest
import torch

import cutlass
import cutlass.cute as cute
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
import torch.nn.functional as F

from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd
from flash_attn.cute.block_sparsity import (
    BlockSparseTensorsTorch,
    fast_sampling,
    normalize_block_sparse_config,
)
from flash_attn.cute.cache_utils import get_jit_cache
from flash_attn.cute import utils
from mask_mod_definitions import get_mask_pair, random_doc_id_tensor

COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0]


@pytest.fixture(autouse=True)
def reset_torch_state():
    """Reset torch dynamo/compile state between tests to avoid state pollution."""
    torch._dynamo.reset()
    torch.cuda.empty_cache()

    yield

    torch._dynamo.reset()
    torch.cuda.empty_cache()


def create_tensors(
    batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype
):
    device = "cuda"
    q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype)
    k = torch.randn(
        batch_size, seqlen_k, nheads_kv, headdim, device=device, dtype=dtype
    )
    v = torch.randn(
        batch_size, seqlen_k, nheads_kv, headdim_v, device=device, dtype=dtype
    )
    out = torch.empty(
        batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype
    )
    lse = torch.empty(batch_size, nheads,
        seqlen_q, device=device, dtype=torch.float32)

    return {
        "q": q,
        "k": k,
        "v": v,
        "out": out,
        "lse": lse,
    }


def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tuple[int, int] | None = None):
    """Compute reference using flex_attention for custom mask_mods"""
    batch_size, seqlen_q, nheads, headdim = tensors["q"].shape
    _, seqlen_k, nheads_kv, _ = tensors["k"].shape

    q = tensors["q"].transpose(1, 2)
    k = tensors["k"].transpose(1, 2)
    v = tensors["v"].transpose(1, 2)

    if nheads != nheads_kv:
        repeat_factor = nheads // nheads_kv
        k = k.repeat_interleave(repeat_factor, dim=1)
        v = v.repeat_interleave(repeat_factor, dim=1)

    scale = 1.0 / math.sqrt(headdim)

    # Handle identity (no masking) case
    if mask_mod_flex is None:
        out_ref = F.scaled_dot_product_attention(q, k, v, scale=scale)
        return out_ref.transpose(1, 2).contiguous()

    block_mask_kwargs = {}
    if block_size is not None:
        block_mask_kwargs["BLOCK_SIZE"] = block_size
    block_mask = create_block_mask(
        mask_mod_flex,
        B=batch_size,
        H=nheads,
        Q_LEN=seqlen_q,
        KV_LEN=seqlen_k,
        device=q.device,
        **block_mask_kwargs,
    )
    out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale, enable_gqa=True)
    return out_ref.transpose(1, 2).contiguous()


def get_coarse_block_mask_pair(sparse_tile_m: int, tile_n: int, last_block: int):
    @fast_sampling
    @cute.jit
    def _cute_coarse_block_mask(
        batch: cute.TensorSSA,
        head: cute.TensorSSA,
        m_idx: cute.TensorSSA,
        n_idx: cute.TensorSSA,
        seqlen_info,
        aux_tensors,
    ) -> cute.TensorSSA:
        sparse_tile_m_ssa = utils.scalar_to_ssa(sparse_tile_m, cutlass.Int32)
        tile_n_ssa = utils.scalar_to_ssa(tile_n, cutlass.Int32)
        q_block = m_idx // sparse_tile_m_ssa
        n_block = n_idx // tile_n_ssa
        zero = utils.scalar_to_ssa(0, cutlass.Int32)
        one = utils.scalar_to_ssa(1, cutlass.Int32)
        last = utils.scalar_to_ssa(last_block, cutlass.Int32)
        return ((q_block == zero) & (n_block == zero)) | ((q_block == one) & (n_block == last))

    def _flex_coarse_block_mask(b, h, q_idx, kv_idx):
        q_block = q_idx // sparse_tile_m
        n_block = kv_idx // tile_n
        return ((q_block == 0) & (n_block == 0)) | ((q_block == 1) & (n_block == last_block))

    return _cute_coarse_block_mask, _flex_coarse_block_mask


SEQLEN_PAIRS_COMPREHENSIVE = [
    (1, 1),
    (64, 128),
    (128, 192),
    (256, 256),
    (239, 1),
    (799, 3),
    (113, 203),
    (113, 128),
    (128, 217),
    (113, 211),
    (108, 256),
    (256, 512),
    (384, 256),
    (640, 128),
    (512, 256),
    (1024, 1024),
    (1023, 1024),
    (1024, 1023),
    (4096, 4096),
    (4224, 4224),
]

SEQLEN_PAIRS_SMOKE = [
    (128, 128),
    (256, 256),
    (113, 203),
    (1024, 1024),
    (128, 8192)
]


def _run_mask_test(
    seqlen_q,
    seqlen_k,
    nheads,
    kv_mode,
    headdim,
    dtype,
    mask_name,
    window_size,
    window_left,
    window_right,
    tile_m,
    tile_n,
    use_block_sparsity,
    needs_backward=False,
):
    torch.manual_seed(42)

    if mask_name == "sliding_window":
        assert window_size is not None, (
            "window_size must be specified for sliding_window"
        )
        if seqlen_q > seqlen_k:
            pytest.skip(
                f"seqlen_q={seqlen_q} > seqlen_k={seqlen_k} not supported for sliding_window"
            )

    # Determine nheads_kv based on mode
    if kv_mode == "mha":
        nheads_kv = nheads
        pack_gqa = False
    elif kv_mode == "gqa":
        if COMPUTE_CAPABILITY < 9:
            pytest.xfail("pack_gqa requires SM90+")
        nheads_kv = nheads // 4
        pack_gqa = True
    elif kv_mode == "mqa":
        nheads_kv = 1
        pack_gqa = False
    else:
        raise ValueError(f"Unknown kv_mode: {kv_mode}")

    batch_size = 1
    headdim_v = headdim
    aux_tensors_arg = None
    mask_mod_cute, mask_mod_flex = get_mask_pair(
        mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size
    )
    if mask_name == "document":
        doc_len = max(seqlen_q, seqlen_k)
        doc_ids = random_doc_id_tensor(nheads, batch_size, doc_len, device="cuda").to(
            dtype=torch.int32, device="cuda"
        )
        original_flex_mask = mask_mod_flex

        def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids):
            return original_flex_mask(b, h, q_idx, kv_idx, doc_ids)

        aux_tensors_arg = [doc_ids]
    elif mask_name == "ima":
        bias_threshold = (seqlen_k // 4) * 3
        bias = torch.full((seqlen_k,), bias_threshold, dtype=torch.int32, device="cuda")
        original_flex_mask = mask_mod_flex
        def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):
            return original_flex_mask(b, h, q_idx, kv_idx, bias)

        aux_tensors_arg = [bias]

    causal = False

    if causal and seqlen_k < seqlen_q:
        pytest.skip("causal masking requires seqlen_k >= seqlen_q")

    tensors = create_tensors(
        batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype
    )

    # SM100 uses sparse_tile_m = 2*tile_m to match forward q_stage=2 pipelining
    if COMPUTE_CAPABILITY == 10:
        sparse_tile_m = 2 * tile_m
    else:
        sparse_tile_m = tile_m

    block_mask_nheads = 1 if pack_gqa else nheads
    bm = create_block_mask(
        mask_mod_flex,
        batch_size,
        block_mask_nheads,
        seqlen_q,
        seqlen_k,
        device="cuda",
        BLOCK_SIZE=(sparse_tile_m, tile_n),
    )
    (
        _seq_q, _seq_k,
        kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx,
        q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx,
        *_,
    ) = bm.as_tuple()

    # SM90 block-sparse backward expects BlockMask granularity (128, 128) regardless of fwd tiling.
    sparse_tile_m_bwd = sparse_tile_m
    tile_n_bwd = tile_n
    if COMPUTE_CAPABILITY == 9 and use_block_sparsity and (sparse_tile_m, tile_n) != (128, 128):
        bm_bwd = create_block_mask(
            mask_mod_flex,
            batch_size,
            nheads,
            seqlen_q,
            seqlen_k,
            device="cuda",
            BLOCK_SIZE=(128, 128),
        )
        (
            _seq_q, _seq_k,
            _kv_mask_cnt, _kv_mask_idx, _full_kv_cnt, _full_kv_idx,
            q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx,
            *_,
        ) = bm_bwd.as_tuple()
        sparse_tile_m_bwd = 128
        tile_n_bwd = 128

    softmax_scale = 1.0 / math.sqrt(headdim)

    block_sparse_mask_fwd = (
        BlockSparseTensorsTorch(
            mask_block_cnt=kv_mask_cnt,
            mask_block_idx=kv_mask_idx,
            full_block_cnt=full_kv_cnt,
            full_block_idx=full_kv_idx,
            block_size=(sparse_tile_m, tile_n),
        )
        if use_block_sparsity
        else None
    )

    # Backward uses Q-direction (transposed) sparse tensors
    block_sparse_mask_bwd = (
        BlockSparseTensorsTorch(
            mask_block_cnt=q_mask_cnt,
            mask_block_idx=q_mask_idx,
            full_block_cnt=full_q_cnt,
            full_block_idx=full_q_idx,
            block_size=(sparse_tile_m_bwd, tile_n_bwd),
        )
        if use_block_sparsity
        else None
    )

    out_tuple = _flash_attn_fwd(
        q=tensors["q"],
        k=tensors["k"],
        v=tensors["v"],
        out=tensors["out"],
        lse=tensors["lse"],
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        seqused_q=None,
        seqused_k=None,
        page_table=None,
        softmax_scale=softmax_scale,
        causal=causal,
        softcap=None,
        window_size_left=window_left,
        window_size_right=window_right,
        learnable_sink=None,
        tile_mn=(tile_m, tile_n),
        pack_gqa=pack_gqa,
        _arch=None,
        score_mod=None,
        mask_mod=mask_mod_cute,
        block_sparse_tensors=block_sparse_mask_fwd,
        return_lse=True,
        aux_tensors=aux_tensors_arg,
    )

    out_cute = out_tuple[0]
    lse_cute = out_tuple[1]
    tensors_fp32 = {
        k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v
        for k, v in tensors.items()
    }

    block_size = (tile_m, tile_n)
    out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, mask_mod_flex, block_size)
    out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, block_size)
    out_pt = out_ref.clone()

    # Check for invalid values
    assert out_cute.shape == out_ref_fp32.shape == out_ref.shape
    assert not torch.isnan(out_cute).any()
    assert not torch.isnan(out_ref_fp32).any()
    assert torch.isfinite(out_cute).all()
    assert torch.isfinite(out_ref_fp32).all()

    # Compute numerical tolerance (matching flash attention tests)
    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()
    rtol = 2

    ref_error = (out_ref - out_ref_fp32).abs().max().item()
    pt_error = (out_pt - out_ref_fp32).abs().max().item()
    cute_error = (out_cute - out_ref_fp32).abs().max().item()

    mask_desc = f"mask_mod={mask_name}"
    if mask_name == "sliding_window" and window_size is not None:
        mask_desc += f"(w={window_size})"

    print(
        f"\n{mask_desc} @ Q={seqlen_q}, K={seqlen_k}, H={nheads}/{nheads_kv} ({kv_mode}), "
        f"D={headdim}, M={tile_m}, N={tile_n}"
    )
    print(" Reference implementation: FlexAttention")
    print(f" Reference vs FP32: {ref_error:.2e}")
    print(f" PyTorch vs FP32: {pt_error:.2e}")
    print(f" Kernel vs FP32: {cute_error:.2e}")
    print(f" Tolerance: rtol={rtol} * {pt_error:.2e} + {fwd_atol:.2e}")
    print(f" Error ratio: {cute_error / max(pt_error, 1e-10):.2f}")

    # Debug: show some sample values if error is large
    if cute_error > 1e-2:
        print(f" DEBUG: Sample kernel output: {out_cute[0, 0, 0, :5]}")
        print(f" DEBUG: Sample reference output: {out_ref_fp32[0, 0, 0, :5]}")
        print(f" DEBUG: Max diff location: {(out_cute - out_ref_fp32).abs().argmax()}")
        max_diff_idx = (out_cute - out_ref_fp32).abs().argmax()
        max_diff_coords = torch.unravel_index(max_diff_idx, out_cute.shape)
        print(f" DEBUG: Max diff at coords: {max_diff_coords}")
        print(f" DEBUG: Kernel value: {out_cute[max_diff_coords]:.6f}")
        print(f" DEBUG: Reference value: {out_ref_fp32[max_diff_coords]:.6f}")

    # Use the same assertion logic as FlashAttention tests
    assert cute_error <= rtol * pt_error + fwd_atol, (
        f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}"
    )

    if needs_backward:
        q = tensors["q"]
        k = tensors["k"]
        v = tensors["v"]

        # Create grad_out once and reuse
        grad_out = torch.randn_like(out_cute)

        # Create block_mask for flex reference
        flex_block_mask = create_block_mask(
            mask_mod_flex,
            batch_size,
            nheads,
            seqlen_q,
            seqlen_k,
            device="cuda",
            BLOCK_SIZE=(tile_m, tile_n),
        )

        dq_cute, dk_cute, dv_cute = run_cute_mask_bwd(
            q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute,
            block_sparse_mask_bwd=block_sparse_mask_bwd,
            tile_m=tile_m,
            tile_n=tile_n,
            aux_tensors=aux_tensors_arg,
        )

        _, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd(
            q, k, v, flex_block_mask, grad_out, dtype=torch.float32
        )

        _, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(
            q, k, v, flex_block_mask, grad_out
        )

        # Check for invalid values
        assert not torch.isnan(dq_cute).any(), "dQ contains NaN"
        assert not torch.isnan(dk_cute).any(), "dK contains NaN"
        assert not torch.isnan(dv_cute).any(), "dV contains NaN"

        bwd_rtol = 2
        min_seqlen = min(seqlen_q, seqlen_k)
        bwd_atol_floor = 1e-5 if min_seqlen >= 64 else 3e-5
        dq_atol = max(bwd_atol_floor, 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item())
        dk_atol = max(bwd_atol_floor, 2 *
            (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item())
        dv_atol = max(bwd_atol_floor, 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item())

        dq_ref = dq_ref_fp32.to(dtype)
        dk_ref = dk_ref_fp32.to(dtype)
        dv_ref = dv_ref_fp32.to(dtype)

        pt_dq_err = (dq_pt - dq_ref).abs().max().item()
        pt_dk_err = (dk_pt - dk_ref).abs().max().item()
        pt_dv_err = (dv_pt - dv_ref).abs().max().item()

        cute_dq_err = (dq_cute - dq_ref).abs().max().item()
        cute_dk_err = (dk_cute - dk_ref).abs().max().item()
        cute_dv_err = (dv_cute - dv_ref).abs().max().item()

        print(" Backward comparison:")
        print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}")
        print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}")
        print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}")

        assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}"
        assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}"
        assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}"


def test_mask_mod_ima_partial_block():
    _run_mask_test(
        seqlen_q=257,
        seqlen_k=257,
        nheads=1,
        kv_mode="mha",
        headdim=128,
        dtype=torch.bfloat16,
        mask_name="ima",
        window_size=None,
        window_left=None,
        window_right=None,
        tile_m=128,
        tile_n=128,
        use_block_sparsity=True,
        needs_backward=True,
    )


# Q boundary seqlens: NOT multiples of tile_m (128)
# These exercise the fix for is_full_block tiles not masking OOB Q rows in backward
Q_BOUNDARY_SEQLEN_PAIRS = [
    (200, 200),  # Last m_block: rows 128-199 valid, 200-255 should be masked
    (300, 300),  # Last m_block: rows 256-299 valid, 300-383 should be masked
    (129, 129),  # Just 1 element into second tile
    (255, 255),  # Just 1 element short of 2 full tiles
    (500, 512),  # Q boundary only (K aligned)
    (512, 500),  # K boundary only (Q aligned)
    (333, 444),  # Both non-aligned
]


@pytest.mark.parametrize("seqlen_q,seqlen_k",
                         Q_BOUNDARY_SEQLEN_PAIRS)
@pytest.mark.parametrize("mask_name", ["block_diagonal", "document"])
def test_q_boundary_masking_block_sparse_bwd(seqlen_q, seqlen_k, mask_name):
    """Test Q boundary masking for the block-sparse backward pass.

    This test specifically exercises the fix for the bug where Q rows beyond seqlen_q
    were not masked in the backward pass for is_full_block=True tiles.

    The bug occurred because:
    - In forward, apply_mask_sm100 always checks both Q and K bounds
    - In backward, apply_mask_sm100_transposed with is_full_block=True only checked K bounds
    - Result: partial last m_blocks had unmasked garbage Q rows contributing to gradients

    Key conditions:
    - seqlen_q NOT a multiple of tile_m (128): creates a partial last m_block
    - Block-sparse with mask_mod: exercises the is_full_block=True path
    - Backward pass: where the bug manifested
    """
    _run_mask_test(
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        nheads=4,
        kv_mode="mha",
        headdim=128,
        dtype=torch.bfloat16,
        mask_name=mask_name,
        window_size=None,
        window_left=None,
        window_right=None,
        tile_m=128,
        tile_n=128,
        use_block_sparsity=True,
        needs_backward=True,
    )


@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Test uses SM100 block mask conventions (2*tile_m)")
def test_single_doc_bwd_minimal():
    """Minimal test to isolate a single-document backward pass bug.

    This test uses batch=1, nheads=1, and a single document (all same doc_id)
    to make debugging easier. The bug manifests as large numerical errors in
    dQ, dK, dV when blocks are classified as "full blocks" because the mask
    returns True for all positions.
    Run with: pytest tests/cute/test_mask_mod.py::test_single_doc_bwd_minimal -v -s
    """
    import random
    random.seed(42)
    torch.manual_seed(42)

    seqlen_q = 384
    seqlen_k = 300
    batch_size = 1
    nheads = 1
    headdim = 128
    tile_m = 128
    tile_n = 128
    dtype = torch.bfloat16

    # Create single-document doc_ids (all same doc_id = 0)
    doc_ids = torch.zeros(batch_size, nheads, max(seqlen_q, seqlen_k), dtype=torch.int32, device="cuda")

    from mask_mod_definitions import get_mask_pair
    mask_mod_cute, mask_mod_flex = get_mask_pair("document", seqlen_q=seqlen_q, seqlen_k=seqlen_k)

    original_flex_mask = mask_mod_flex

    def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids):
        return original_flex_mask(b, h, q_idx, kv_idx, doc_ids)

    aux_tensors_arg = [doc_ids]

    # Create tensors
    q = torch.randn(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype)
    k = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype)
    v = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype)
    out = torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype)
    lse = torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32)

    sparse_tile_m = 2 * tile_m
    bm = create_block_mask(
        mask_mod_flex,
        batch_size,
        nheads,
        seqlen_q,
        seqlen_k,
        device="cuda",
        BLOCK_SIZE=(sparse_tile_m, tile_n),
    )
    (
        _seq_q, _seq_k,
        kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx,
        q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx,
        *_,
    ) = bm.as_tuple()

    block_sparse_mask_fwd = BlockSparseTensorsTorch(
        mask_block_cnt=kv_mask_cnt,
        mask_block_idx=kv_mask_idx,
        full_block_cnt=full_kv_cnt,
        full_block_idx=full_kv_idx,
        block_size=(sparse_tile_m, tile_n),
    )
    block_sparse_mask_bwd = BlockSparseTensorsTorch(
        mask_block_cnt=q_mask_cnt,
        mask_block_idx=q_mask_idx,
        full_block_cnt=full_q_cnt,
        full_block_idx=full_q_idx,
        block_size=(sparse_tile_m, tile_n),
    )

    out_tuple = _flash_attn_fwd(
        q=q,
        k=k,
        v=v,
        out=out,
        lse=lse,
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        seqused_q=None,
        seqused_k=None,
        page_table=None,
        causal=False,
        softcap=None,
        window_size_left=-1,
        window_size_right=-1,
        tile_mn=(tile_m, tile_n),
        pack_gqa=False,
        _arch=None,
        score_mod=None,
        mask_mod=mask_mod_cute,
        block_sparse_tensors=block_sparse_mask_fwd,
        return_lse=True,
        aux_tensors=aux_tensors_arg,
    )
    out_cute = out_tuple[0]
    lse_cute = out_tuple[1]

    # Backward pass
    grad_out = torch.randn_like(out_cute)

    dq_cute, dk_cute, dv_cute = run_cute_mask_bwd(
        q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute,
        block_sparse_mask_bwd=block_sparse_mask_bwd,
        tile_m=tile_m,
        tile_n=tile_n,
        aux_tensors=aux_tensors_arg,
    )

    flex_block_mask = create_block_mask(
        mask_mod_flex,
        batch_size,
        nheads,
        seqlen_q,
        seqlen_k,
        device="cuda",
        BLOCK_SIZE=(tile_m, tile_n),
    )
    out_ref, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(
        q, k, v, flex_block_mask, grad_out, dtype=torch.float32
    )

    # Compare
    dq_err = (dq_cute - dq_ref.to(dtype)).abs().max().item()
    dk_err = (dk_cute - dk_ref.to(dtype)).abs().max().item()
    dv_err = (dv_cute - dv_ref.to(dtype)).abs().max().item()

    print(f"dQ error: {dq_err:.2e}")
    print(f"dK error: {dk_err:.2e}")
    print(f"dV error: {dv_err:.2e}")

    # Assert gradients are correct (a regression of the full-block bug shows up as large errors here)
    assert dq_err < 0.05, f"dQ error too large: {dq_err:.2e}"
    assert dk_err < 0.05, f"dK error too large: {dk_err:.2e}"
    assert dv_err < 0.05, f"dV error too large: {dv_err:.2e}"


@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_COMPREHENSIVE)
@pytest.mark.parametrize("nheads", [16])
@pytest.mark.parametrize("kv_mode", ["mha", "gqa", "mqa"])
@pytest.mark.parametrize("headdim", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("use_block_sparsity", [True, False])
@pytest.mark.parametrize(
    "mask_name",
    ["block_diagonal", "mini_causal"],
)
@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112)])
def test_static_masks(
    seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, tile_m, tile_n
):
    """Test static masks that don't
    require recompilation per seqlen pair.

    Known good masks:
    - block_diagonal: Masks by 64-element diagonal blocks
    - mini_causal: Local causal within 128-element tiles
    """
    if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128):
        pytest.skip("TODO: Non-128x128 tiles currently not supported on SM 10.0 due to TMEM")
    _run_mask_test(
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        nheads=nheads,
        kv_mode=kv_mode,
        headdim=headdim,
        dtype=dtype,
        mask_name=mask_name,
        window_size=None,
        window_left=None,
        window_right=None,
        tile_m=tile_m,
        tile_n=tile_n,
        use_block_sparsity=use_block_sparsity,
        needs_backward=True,
    )


@pytest.mark.parametrize("seqlen_q,seqlen_k", SEQLEN_PAIRS_SMOKE)
@pytest.mark.parametrize("nheads", [16])
@pytest.mark.parametrize("kv_mode", ["mha", "gqa"])
@pytest.mark.parametrize("headdim", [128])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("use_block_sparsity", [True, False])
@pytest.mark.parametrize(
    "mask_name,window_size",
    [
        ("causal", None),
        ("block_causal", None),
        ("sliding_window", 128),
        ("sliding_window", 256),
        ("sliding_window", 512),
        ("document", None),
    ],
)
@pytest.mark.parametrize("tile_m,tile_n", [(128, 128), (128, 112), (64, 128)])
def test_parameterized_masks(
    seqlen_q, seqlen_k, nheads, kv_mode, headdim, dtype, use_block_sparsity, mask_name, window_size, tile_m, tile_n
):
    """Test parameterized masks that require recompilation per seqlen pair.

    Uses fewer seqlen combinations to reduce test time.

    Masks tested:
    - causal, block_causal: Require offset = seqlen_k - seqlen_q
    - sliding_window: Requires window size and offset parameters
    - document: Slower to check
    """
    if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128):
        pytest.skip("TODO: Non-128x128 tiles currently not supported on SM 10.0 due to TMEM")

    _run_mask_test(
        seqlen_q=seqlen_q,
        seqlen_k=seqlen_k,
        nheads=nheads,
        kv_mode=kv_mode,
        headdim=headdim,
        dtype=dtype,
        mask_name=mask_name,
        window_size=window_size,
        window_left=None,
        window_right=None,
        tile_m=tile_m,
        tile_n=tile_n,
        use_block_sparsity=use_block_sparsity,
        needs_backward=True,
    )


def test_sm100_block_sparse_sink_all_masked():
    """Block-sparse regression for the sink path"""
    if torch.cuda.get_device_capability()[0] != 10:
        pytest.skip("SM100-only test")
    device = "cuda"
    dtype = torch.bfloat16
    batch_size = 1
    seqlen_q = 256
    seqlen_k = 128
    nheads = 8
    headdim = 128
    q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device)
    k = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device)
    v = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device)
    learnable_sink = torch.full((nheads,), 0.5, dtype=torch.bfloat16, device=device)
    zero_cnt = torch.zeros((batch_size, nheads, 1), dtype=torch.int32, device=device)
    zero_idx = torch.zeros((batch_size, nheads, 1, 1), dtype=torch.int32, device=device)
    sparse = BlockSparseTensorsTorch(
        mask_block_cnt=zero_cnt,
        mask_block_idx=zero_idx,
        full_block_cnt=zero_cnt,
        full_block_idx=zero_idx,
        block_size=(256, 128),
    )
    softmax_scale = 1.0 / math.sqrt(headdim)
    _, lse = _flash_attn_fwd(
        q=q,
        k=k,
        v=v,
        softmax_scale=softmax_scale,
        causal=False,
        window_size_left=None,
        window_size_right=None,
        learnable_sink=learnable_sink,
        tile_mn=(128, 128),
        num_threads=384,
        pack_gqa=False,
        block_sparse_tensors=sparse,
        return_lse=True,
    )
    # Fully masked tile ⇒ probability mass sits entirely on the sink, so LSE equals sink logit.
    expected = learnable_sink.float()[None, :, None].expand_as(lse)
    assert torch.allclose(lse, expected, atol=0.0, rtol=0.0)


@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="SM100-only test")
def test_sm100_block_sparse_q_stage1():
    from flash_attn.cute import flash_fwd_sm100
    from flash_attn.cute.interface import _flash_attn_fwd

    observed = {}
    original_init = flash_fwd_sm100.FlashAttentionForwardSm100.__init__

    def wrapped_init(self, *args, **kwargs):
        observed["q_stage"] = kwargs.get("q_stage")
        return original_init(self, *args, **kwargs)

    with mock.patch.object(
        flash_fwd_sm100.FlashAttentionForwardSm100,
        "__init__",
        wrapped_init,
    ):
        compile_cache = _flash_attn_fwd.compile_cache
        _flash_attn_fwd.compile_cache = get_jit_cache("test_mask_mod.fwd")
        try:
            _run_mask_test(
                seqlen_q=128,
                seqlen_k=128,
                nheads=4,
                kv_mode="mha",
                headdim=128,
                dtype=torch.bfloat16,
                mask_name="block_diagonal",
                window_size=None,
                window_left=None,
                window_right=None,
                tile_m=128,
                tile_n=128,
                use_block_sparsity=True,
                needs_backward=False,
            )
        finally:
            _flash_attn_fwd.compile_cache.clear()
            _flash_attn_fwd.compile_cache = compile_cache

    assert observed.get("q_stage") == 1


@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="SM100-only test")
def test_sm100_block_sparse_coarse_blocks():
    torch.manual_seed(42)
    seqlen_q = 512
    seqlen_k = 512
    nheads = 4
    headdim = 128
    dtype = torch.bfloat16
    tile_m = 128
    tile_n = 128
    sparse_tile_m = 512
    batch_size = 1

    mask_mod_cute, mask_mod_flex = get_mask_pair(
        "block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=None
    )
    tensors = create_tensors(
        batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype
    )
    bm = create_block_mask(
        mask_mod_flex,
        batch_size,
        nheads,
        seqlen_q,
        seqlen_k,
        device="cuda",
        BLOCK_SIZE=(sparse_tile_m, tile_n),
    )
    (
        _seq_q, _seq_k,
        kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx,
        *_,
    ) = bm.as_tuple()
    block_sparse_mask_fwd = BlockSparseTensorsTorch(
        mask_block_cnt=kv_mask_cnt,
        mask_block_idx=kv_mask_idx,
        full_block_cnt=full_kv_cnt,
        full_block_idx=full_kv_idx,
        block_size=(sparse_tile_m, tile_n),
    )

    out_cute, _ = _flash_attn_fwd(
        q=tensors["q"],
        k=tensors["k"],
        v=tensors["v"],
        out=tensors["out"],
        lse=tensors["lse"],
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        seqused_q=None,
        seqused_k=None,
        page_table=None,
        softmax_scale=1.0 / math.sqrt(headdim),
        causal=False,
        softcap=None,
        window_size_left=None,
        window_size_right=None,
        learnable_sink=None,
        tile_mn=(tile_m, tile_n),
        pack_gqa=False,
        _arch=None,
        score_mod=None,
        mask_mod=mask_mod_cute,
        block_sparse_tensors=block_sparse_mask_fwd,
        return_lse=True,
    )

    tensors_fp32 = {
        k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v
        for k, v in tensors.items()
    }
    out_ref_fp32 = compute_reference_flex_attn(
        tensors_fp32, mask_mod_flex, (sparse_tile_m, tile_n)
    )
    out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, (sparse_tile_m, tile_n))

    assert out_cute.shape == out_ref_fp32.shape == out_ref.shape
    assert not torch.isnan(out_cute).any()
    assert not torch.isnan(out_ref_fp32).any()
    assert torch.isfinite(out_cute).all()
    assert torch.isfinite(out_ref_fp32).all()

    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()
    rtol = 2
    pt_error = (out_ref - out_ref_fp32).abs().max().item()
    cute_error = (out_cute - out_ref_fp32).abs().max().item()
    assert cute_error <= rtol * pt_error + fwd_atol, (
        f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}"
    )


@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="SM100-only test")
def test_sm100_block_sparse_coarse_blocks_mismatch():
    torch.manual_seed(0)
    seqlen_q = 1024
    seqlen_k = 512
    nheads = 2
    headdim = 128
    dtype = torch.bfloat16
    tile_m = 128
    tile_n = 128
    sparse_tile_m = 512
    batch_size = 1

    mask_mod_cute, mask_mod_flex = get_coarse_block_mask_pair(
        sparse_tile_m, tile_n, last_block=3
    )
    tensors = create_tensors(
        batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype
    )
    bm = create_block_mask(
        mask_mod_flex,
        batch_size,
        nheads,
        seqlen_q,
        seqlen_k,
        device="cuda",
        BLOCK_SIZE=(sparse_tile_m, tile_n),
    )
    (
        _seq_q, _seq_k,
        kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx,
        *_,
    ) = bm.as_tuple()
    block_sparse_mask_fwd = BlockSparseTensorsTorch(
        mask_block_cnt=kv_mask_cnt,
        mask_block_idx=kv_mask_idx,
        full_block_cnt=full_kv_cnt,
        full_block_idx=full_kv_idx,
        block_size=(sparse_tile_m, tile_n),
    )

    observed = {}
    original_normalize = normalize_block_sparse_config

    def wrapped_normalize(*args, **kwargs):
        normalized, pattern, q_subtile_factor = original_normalize(*args, **kwargs)
        observed["q_subtile_factor"] = q_subtile_factor
        return normalized, pattern, q_subtile_factor

    with mock.patch("flash_attn.cute.interface.normalize_block_sparse_config", wrapped_normalize):
        out_cute, _ = _flash_attn_fwd(
            q=tensors["q"],
            k=tensors["k"],
            v=tensors["v"],
            out=tensors["out"],
            lse=tensors["lse"],
            cu_seqlens_q=None,
            cu_seqlens_k=None,
            seqused_q=None,
            seqused_k=None,
            page_table=None,
            softmax_scale=1.0 / math.sqrt(headdim),
            causal=False,
            softcap=None,
            window_size_left=None,
            window_size_right=None,
            learnable_sink=None,
            tile_mn=(tile_m, tile_n),
            pack_gqa=False,
            _arch=None,
            score_mod=None,
            mask_mod=mask_mod_cute,
            block_sparse_tensors=block_sparse_mask_fwd,
            return_lse=True,
        )

    assert observed.get("q_subtile_factor") == 2

    tensors_fp32 = {
        k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v
        for k, v in tensors.items()
    }
    out_ref_fp32 = compute_reference_flex_attn(
        tensors_fp32, mask_mod_flex, (sparse_tile_m, tile_n)
    )
    out_ref = compute_reference_flex_attn(tensors, mask_mod_flex, (sparse_tile_m, tile_n))

    assert out_cute.shape == out_ref_fp32.shape == out_ref.shape
    assert not torch.isnan(out_cute).any()
    assert not torch.isnan(out_ref_fp32).any()
    assert torch.isfinite(out_cute).all()
    assert torch.isfinite(out_ref_fp32).all()

    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()
    rtol = 2
    pt_error = (out_ref - out_ref_fp32).abs().max().item()
    cute_error = (out_cute -
out_ref_fp32).abs().max().item() assert cute_error <= rtol * pt_error + fwd_atol, ( f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" ) # ============================================================================= # Backward Helper Functions # ============================================================================= def run_cute_mask_bwd( q, k, v, out, lse, grad_out, mask_mod_cute, block_sparse_mask_bwd=None, tile_m=128, tile_n=128, aux_tensors=None, ): """Run flash attention backward with mask_mod. Args: q, k, v: Input tensors in BSHD format out: Forward output tensor lse: Log-sum-exp from forward pass grad_out: Gradient of output mask_mod_cute: CuTE mask modification function block_sparse_mask_bwd: Block sparse tensors for backward pass tile_m, tile_n: Tile sizes aux_tensors: Auxiliary tensors for mask_mod (e.g., doc_ids for document masking) Returns (dq, dk, dv) all in BSHD format. """ dq, dk, dv = _flash_attn_bwd( q=q, k=k, v=v, out=out, dout=grad_out, lse=lse, causal=False, m_block_size=tile_m, n_block_size=tile_n, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_bwd, aux_tensors=aux_tensors, ) return dq, dk, dv def run_flex_reference_bwd(q, k, v, block_mask, grad_out, dtype=None): """Run flex_attention forward + backward for reference. Args: q, k, v: Input tensors in BSHD format block_mask: Pre-created block mask for flex_attention grad_out: Gradient of output in BSHD format dtype: Optional dtype to cast inputs to (e.g., torch.float32 for reference) Returns (out, dq, dk, dv) all in BSHD format. 
""" # Transpose to BHSD for flex_attention if dtype is not None: q_ref = q.transpose(1, 2).to(dtype).requires_grad_(True) k_ref = k.transpose(1, 2).to(dtype).requires_grad_(True) v_ref = v.transpose(1, 2).to(dtype).requires_grad_(True) grad_out_ref = grad_out.transpose(1, 2).to(dtype) else: q_ref = q.transpose(1, 2).requires_grad_(True) k_ref = k.transpose(1, 2).requires_grad_(True) v_ref = v.transpose(1, 2).requires_grad_(True) grad_out_ref = grad_out.transpose(1, 2) # Use flex_attention directly without torch.compile for backward tests # torch.compile can hang on certain mask patterns (e.g., mini_causal with float32) out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask, enable_gqa=True) dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_out_ref) # Transpose back to BSHD return ( out_ref.transpose(1, 2), dq_ref.transpose(1, 2), dk_ref.transpose(1, 2), dv_ref.transpose(1, 2), ) def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message(): if COMPUTE_CAPABILITY != 9: pytest.skip("SM90-only test") batch_size = 1 seqlen_q = 256 seqlen_k = 256 nheads = 4 nheads_kv = nheads headdim = 128 dtype = torch.bfloat16 tile_m = 80 tile_n = 128 tensors = create_tensors(batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype) mask_mod_cute, mask_mod_flex = get_mask_pair("block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k) bm = create_block_mask( mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k, device="cuda", BLOCK_SIZE=(tile_m, tile_n), ) ( _seq_q, _seq_k, _kv_mask_cnt, _kv_mask_idx, _full_kv_cnt, _full_kv_idx, q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_, ) = bm.as_tuple() block_sparse_mask_bwd = BlockSparseTensorsTorch( mask_block_cnt=q_mask_cnt, mask_block_idx=q_mask_idx, full_block_cnt=full_q_cnt, full_block_idx=full_q_idx, block_size=(tile_m, tile_n), ) softmax_scale = 1.0 / math.sqrt(headdim) out = torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) lse 
= torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32) grad_out = torch.randn_like(out) with pytest.raises( ValueError, match=r"Block sparsity expects sparse_block_size_q=128 for subtile_factor=2\.", ): _flash_attn_bwd( q=tensors["q"], k=tensors["k"], v=tensors["v"], out=out, dout=grad_out, lse=lse, softmax_scale=softmax_scale, causal=False, m_block_size=tile_m, n_block_size=tile_n, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_bwd, ) def test_gqa_block_sparse_broadcast_pattern_recompilation(): """Test that different block sparse broadcast patterns trigger recompilation. This is a regression test for a bug where: 1. First call with block_mask H=1 (broadcasts across all query heads) 2. Second call with block_mask H=nheads (no broadcast) 3. Second call incorrectly reused cached kernel from first call The fix adds block_sparse_broadcast_pattern to the compile key so that kernels are recompiled when broadcast patterns change. CuTe's mark_layout_dynamic() keeps stride=0 as static, so different broadcast patterns require different compiled kernels. 
""" torch.manual_seed(42) batch_size = 2 nheads = 8 nheads_kv = 2 seqlen = 257 headdim = 64 dtype = torch.bfloat16 tile_m = 128 tile_n = 128 sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m def causal_mask(b, h, q, kv): return q >= kv mask_mod_cute, _ = get_mask_pair("causal", seqlen_q=seqlen, seqlen_k=seqlen) tensors = create_tensors(batch_size, seqlen, seqlen, nheads, nheads_kv, headdim, headdim, dtype) q, k, v = tensors["q"], tensors["k"], tensors["v"] grad_out = torch.randn_like(tensors["out"]) softmax_scale = 1.0 / math.sqrt(headdim) def run_with_block_mask_nheads(block_mask_nheads: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bm = create_block_mask( causal_mask, batch_size, block_mask_nheads, seqlen, seqlen, device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n), ) ( _seq_q, _seq_k, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_, ) = bm.as_tuple() block_sparse_fwd = BlockSparseTensorsTorch( mask_block_cnt=kv_mask_cnt, mask_block_idx=kv_mask_idx, full_block_cnt=full_kv_cnt, full_block_idx=full_kv_idx, block_size=(sparse_tile_m, tile_n), ) block_sparse_bwd = BlockSparseTensorsTorch( mask_block_cnt=q_mask_cnt, mask_block_idx=q_mask_idx, full_block_cnt=full_q_cnt, full_block_idx=full_q_idx, block_size=(sparse_tile_m, tile_n), ) out = torch.empty_like(tensors["out"]) lse = torch.empty_like(tensors["lse"]) out_tuple = _flash_attn_fwd( q=q, k=k, v=v, out=out, lse=lse, softmax_scale=softmax_scale, causal=False, window_size_left=-1, window_size_right=-1, tile_mn=(tile_m, tile_n), pack_gqa=False, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_fwd, return_lse=True, ) out_cute, lse_cute = out_tuple[0], out_tuple[1] dq, dk, dv = run_cute_mask_bwd( q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute, block_sparse_mask_bwd=block_sparse_bwd, tile_m=tile_m, tile_n=tile_n, ) return dq, dk, dv flex_block_mask = create_block_mask( causal_mask, batch_size, nheads, seqlen, seqlen, 
device="cuda", BLOCK_SIZE=(tile_m, tile_n), ) _, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out, dtype=torch.float32) dq_ref, dk_ref, dv_ref = dq_ref.to(dtype), dk_ref.to(dtype), dv_ref.to(dtype) dq_broadcast, dk_broadcast, dv_broadcast = run_with_block_mask_nheads(1) dq_no_broadcast, dk_no_broadcast, dv_no_broadcast = run_with_block_mask_nheads(nheads) err_broadcast_dq = (dq_broadcast - dq_ref).abs().max().item() err_no_broadcast_dq = (dq_no_broadcast - dq_ref).abs().max().item() print("\nGQA block sparse broadcast pattern test:") print(f" dQ error (H=1 broadcast): {err_broadcast_dq:.2e}") print(f" dQ error (H={nheads} no broadcast): {err_no_broadcast_dq:.2e}") assert err_broadcast_dq < 0.1, f"Broadcast dQ error too large: {err_broadcast_dq:.2e}" assert err_no_broadcast_dq < 0.1, f"No-broadcast dQ error too large: {err_no_broadcast_dq:.2e}" def test_gqa_expand_stride_zero_bug(): """Test that GQA with expand()-created K/V tensors works correctly. This is a regression test for bugs with expand()-created tensors: Forward bug: cute.assume() fails when tensor strides are Python int 0 (from expand()) instead of MLIR values. Error: AttributeError: 'int' object has no attribute 'type' Backward bug: mark_layout_dynamic fails with expanded tensors. Error: RuntimeError: Expected strides[leading_dim] == 1, but got N. Trigger: expand() + transpose() creates stride=0 dimensions (GQA pattern). 
""" torch.manual_seed(42) batch_size = 1 seqlen = 2048 headdim = 128 n_heads = 4 n_kv_heads = 1 dtype = torch.bfloat16 device = "cuda" q = torch.randn(batch_size, seqlen, n_heads, headdim, device=device, dtype=dtype) k_orig = torch.randn(batch_size, seqlen, n_kv_heads, headdim, device=device, dtype=dtype) v_orig = torch.randn(batch_size, seqlen, n_kv_heads, headdim, device=device, dtype=dtype) k = k_orig.expand(batch_size, seqlen, n_heads, headdim) v = v_orig.expand(batch_size, seqlen, n_heads, headdim) assert k.stride()[2] == 0, "K should have stride=0 in head dim from expand()" assert v.stride()[2] == 0, "V should have stride=0 in head dim from expand()" out = torch.empty_like(q) lse = torch.empty(batch_size, n_heads, seqlen, device=device, dtype=torch.float32) softmax_scale = 1.0 / math.sqrt(headdim) out_tuple = _flash_attn_fwd( q=q, k=k, v=v, out=out, lse=lse, softmax_scale=softmax_scale, causal=True, tile_mn=(128, 128), return_lse=True, ) out_fwd, lse_fwd = out_tuple[0], out_tuple[1] assert not torch.isnan(out_fwd).any(), "Forward output contains NaN" assert torch.isfinite(out_fwd).all(), "Forward output contains non-finite values" tensors_for_ref = {"q": q, "k": k, "v": v} tensors_fp32 = {"q": q.float(), "k": k.float(), "v": v.float()} def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx out_ref = compute_reference_flex_attn(tensors_for_ref, causal_mask) out_ref_fp32 = compute_reference_flex_attn(tensors_fp32, causal_mask) fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() rtol = 2 pt_error = (out_ref - out_ref_fp32).abs().max().item() cute_error = (out_fwd - out_ref_fp32).abs().max().item() print(f"\nGQA expand stride=0 test:") print(f" Forward: kernel err={cute_error:.2e}, ref err={pt_error:.2e}, atol={fwd_atol:.2e}") assert cute_error <= rtol * pt_error + fwd_atol, ( f"Forward error {cute_error:.2e} exceeds {rtol}x ref error {pt_error:.2e} + {fwd_atol:.2e}" ) grad_out = torch.randn_like(out_fwd) dq, dk, dv = 
_flash_attn_bwd( q=q, k=k, v=v, out=out_fwd, dout=grad_out, lse=lse_fwd, softmax_scale=softmax_scale, causal=True, m_block_size=128, n_block_size=128, ) assert not torch.isnan(dq).any(), "dQ contains NaN" assert not torch.isnan(dk).any(), "dK contains NaN" assert not torch.isnan(dv).any(), "dV contains NaN" flex_block_mask = create_block_mask( causal_mask, batch_size, n_heads, seqlen, seqlen, device=device, BLOCK_SIZE=(128, 128), ) _, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out, dtype=torch.float32) bwd_rtol = 2 bwd_atol_floor = 1e-5 dq_atol = max(bwd_atol_floor, 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item()) dk_atol = max(bwd_atol_floor, 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item()) dv_atol = max(bwd_atol_floor, 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item()) _, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out) pt_dq_err = (dq_pt - dq_ref.to(dtype)).abs().max().item() pt_dk_err = (dk_pt - dk_ref.to(dtype)).abs().max().item() pt_dv_err = (dv_pt - dv_ref.to(dtype)).abs().max().item() cute_dq_err = (dq - dq_ref.to(dtype)).abs().max().item() cute_dk_err = (dk - dk_ref.to(dtype)).abs().max().item() cute_dv_err = (dv - dv_ref.to(dtype)).abs().max().item() print(f" Backward dQ: kernel err={cute_dq_err:.2e}, ref err={pt_dq_err:.2e}, atol={dq_atol:.2e}") print(f" Backward dK: kernel err={cute_dk_err:.2e}, ref err={pt_dk_err:.2e}, atol={dk_atol:.2e}") print(f" Backward dV: kernel err={cute_dv_err:.2e}, ref err={pt_dv_err:.2e}, atol={dv_atol:.2e}") assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" @pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 persistent forward only") def test_persistent_blocksparse_empty_tiles(): 
"""Regression test for persistent forward deadlock with highly-sparse block masks. When most Q-tiles are empty (no active KV blocks), the persistent kernel deadlocked due to barrier phase desync in the empty-tile paths of both the softmax and correction warp groups. """ torch.manual_seed(5) batch_size, nheads_q, nheads_kv = 2, 16, 1 seqlen_q, seqlen_k, headdim = 8192, 128, 128 tile_m, tile_n = 128, 128 dtype = torch.bfloat16 sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m window_size = 64 mask_mod_cute, mask_mod_flex = get_mask_pair( "sliding_window", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size, ) bm = create_block_mask( mask_mod_flex, batch_size, nheads_q, seqlen_q, seqlen_k, device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n), ) (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple() block_sparse_mask_fwd = BlockSparseTensorsTorch( mask_block_cnt=kv_mask_cnt, mask_block_idx=kv_mask_idx, full_block_cnt=full_kv_cnt, full_block_idx=full_kv_idx, block_size=(sparse_tile_m, tile_n), ) q = torch.randn(batch_size, seqlen_q, nheads_q, headdim, device="cuda", dtype=dtype) k = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device="cuda", dtype=dtype) v = torch.randn(batch_size, seqlen_k, nheads_kv, headdim, device="cuda", dtype=dtype) out, lse = _flash_attn_fwd( q=q, k=k, v=v, out=torch.empty(batch_size, seqlen_q, nheads_q, headdim, device="cuda", dtype=dtype), lse=torch.empty(batch_size, nheads_q, seqlen_q, device="cuda", dtype=torch.float32), cu_seqlens_q=None, cu_seqlens_k=None, seqused_q=None, seqused_k=None, page_table=None, softmax_scale=1.0 / math.sqrt(headdim), causal=False, softcap=None, window_size_left=None, window_size_right=None, learnable_sink=None, tile_mn=(tile_m, tile_n), pack_gqa=False, _arch=None, score_mod=None, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_fwd, return_lse=True, aux_tensors=None, ) torch.cuda.synchronize() assert out.shape == (batch_size, seqlen_q, 
nheads_q, headdim) assert not out.isnan().any() if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) ================================================ FILE: tests/cute/test_score_mod.py ================================================ import pytest import torch import cutlass import cutlass.cute as cute from cutlass._mlir.dialects import math as mlir_math import operator from torch.nn.attention.flex_attention import flex_attention from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] from score_mod_definitions import ( # TensorSSA-based score mods score_mod_identity as score_mod_1, score_mod_causal as score_mod_2, score_mod_rel_bias as score_mod_3, score_mod_rel_bias_x2 as score_mod_4, score_mod_times_two as score_mod_5, score_mod_alibi as score_mod_6, score_mod_sliding_window as score_mod_7, score_mod_block_diagonal as score_mod_8, score_mod_causal_v2 as score_mod_9, score_mod_batch_bias as score_mod_10, score_mod_dual_buffer as score_mod_11, ) # isort: split from score_mod_definitions import ( score_mod_identity_vectorized as score_mod_1_vectorized, score_mod_causal_vectorized as score_mod_2_vectorized, score_mod_rel_bias as score_mod_3_vectorized, score_mod_rel_bias_x2_vectorized as score_mod_4_vectorized, score_mod_times_two_vectorized as score_mod_5_vectorized, score_mod_alibi_vectorized as score_mod_6_vectorized, score_mod_batch_bias_vectorized as score_mod_10_vectorized, score_mod_dual_buffer_vectorized as score_mod_11_vectorized, ) # isort: split from score_mod_definitions import ( # Eager (torch) reference score mods identity_eager, causal_eager as causal_mask_eager, rel_bias_eager as relative_bias_eager, rel_bias_x2_eager as relative_bias_v2_eager, times_two_eager, alibi_eager as alibi_bias_eager, sliding_window_eager, block_diagonal_eager, causal_v2_eager as causal_mask_v2_eager, batch_bias_factory as batch_bias, dual_buffer_factory as dual_buffer_bias, ) 
# Test pairs: (cute_jit_function, eager_reference_function) TEST_PAIRS = [ (score_mod_1, None), (score_mod_2, causal_mask_eager), (score_mod_3, relative_bias_eager), (score_mod_4, relative_bias_v2_eager), (score_mod_5, times_two_eager), (score_mod_6, alibi_bias_eager), (score_mod_7, sliding_window_eager), (score_mod_8, block_diagonal_eager), (score_mod_9, causal_mask_v2_eager), ] # Test pairs with aux_tensors: (cute_jit_function, eager_reference_function_factory) TEST_PAIRS_WITH_AUX_TENSORS = [ (score_mod_10, batch_bias), (score_mod_11, dual_buffer_bias), ] # Test pairs to compare vectorized score_mods: (cute_jit_function, cute_jit_function_vectorized) TEST_PAIRS_VECTORIZED = [ (score_mod_1, score_mod_1_vectorized), (score_mod_2, score_mod_2_vectorized), (score_mod_3, score_mod_3_vectorized), (score_mod_4, score_mod_4_vectorized), (score_mod_5, score_mod_5_vectorized), (score_mod_6, score_mod_6_vectorized), ] TEST_PAIRS_WITH_AUX_TENSORS_VECTORIZED = [ (score_mod_10, score_mod_10_vectorized), (score_mod_11, score_mod_11_vectorized), ] SEQLEN_CONFIGS = [ (1, 1), (64, 128), (128, 192), (256, 256), (239, 1), (799, 3), (113, 203), (113, 128), (128, 217), (113, 211), (108, 256), (256, 512), (384, 256), (640, 128), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (4096, 4096), (4224, 4224), ] VEC_SIZES_TO_CHECK_EQUALITY = [1, 2, 4] if COMPUTE_CAPABILITY == 10 else [1, 2] def create_tensors( batch_size=2, num_heads=4, seqlen_q=64, seqlen_kv=64, dim=128, dtype=torch.bfloat16 ): q = torch.randn(batch_size, num_heads, seqlen_q, dim, device="cuda", dtype=dtype) k = torch.randn(batch_size, num_heads, seqlen_kv, dim, device="cuda", dtype=dtype) v = torch.randn(batch_size, num_heads, seqlen_kv, dim, device="cuda", dtype=dtype) return q, k, v def run_cute_flash(q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=False) -> torch.Tensor: q_transposed, k_transposed, v_transposed = map(lambda x: x.transpose(1, 2), (q, k,
v)) out = torch.empty_like(q_transposed) _flash_attn_fwd( q_transposed, k_transposed, v_transposed, return_lse=True, score_mod=cute_score_mod, out=out, lse=None, aux_tensors=aux_tensors, pack_gqa=pack_gqa, ) return out.transpose(1, 2) def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Tensor: if dtype is not None: q, k, v = q.to(dtype), k.to(dtype), v.to(dtype) return flex_attention(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1]) @pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) def test_cute_vs_flex_attention( seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair ): torch.random.manual_seed(42) cute_score_mod, eager_score_mod = score_mod_pair num_q_heads = num_kv_heads * qhead_per_kvhead pack_gqa = qhead_per_kvhead > 1 q, k, v = create_tensors( seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype ) if pack_gqa: k = k[:, :num_kv_heads, :, :].clone() v = v[:, :num_kv_heads, :, :].clone() out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) out_pt = run_flex_reference(q, k, v, eager_score_mod) out_cute = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa) # Basic shape and NaN checks assert out_cute.shape == out_ref_fp32.shape == out_pt.shape assert not torch.isnan(out_cute).any() assert not torch.isnan(out_ref_fp32).any() assert not torch.isnan(out_pt).any() assert torch.isfinite(out_cute).all() assert torch.isfinite(out_ref_fp32).all() assert torch.isfinite(out_pt).all() # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() rtol = 2 # Calculate actual errors pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - 
out_ref_fp32).abs().max().item() print(f"\nNumerical comparison for {cute_score_mod.__name__}:") print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") # Assert that CuTE's error is at most rtol times PyTorch's error + fwd_atol assert cute_error <= rtol * pt_error + fwd_atol, ( f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" ) @pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("score_mod_vec_pair", TEST_PAIRS_VECTORIZED) def test_cute_score_mod_vectorized( seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_vec_pair, ): """Tests equality between original and vectorized versions of score mods""" torch.random.manual_seed(42) cute_score_mod, cute_vectorized_score_mod = score_mod_vec_pair num_q_heads = num_kv_heads * qhead_per_kvhead pack_gqa = qhead_per_kvhead > 1 q, k, v = create_tensors( seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype ) if pack_gqa: k = k[:, :num_kv_heads, :, :].clone() v = v[:, :num_kv_heads, :, :].clone() out_ref = run_cute_flash(q, k, v, cute_score_mod, pack_gqa=pack_gqa) for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: cute_vectorized_score_mod.__vec_size__ = vec_size out = run_cute_flash(q, k, v, cute_vectorized_score_mod, pack_gqa=pack_gqa) assert torch.equal(out, out_ref) @pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS) def test_cute_vs_flex_attention_with_aux_tensors( 
seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_pair ): torch.random.manual_seed(42) cute_score_mod, eager_score_mod_factory = score_mod_pair batch_size = 2 num_q_heads = num_kv_heads * qhead_per_kvhead pack_gqa = qhead_per_kvhead > 1 q, k, v = create_tensors( batch_size=batch_size, seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype, ) if pack_gqa: k = k[:, :num_kv_heads, :, :].clone() v = v[:, :num_kv_heads, :, :].clone() if cute_score_mod == score_mod_10: buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 aux_tensors = [buffer] eager_score_mod = eager_score_mod_factory(buffer) assert buffer.shape == (batch_size,) elif cute_score_mod == score_mod_11: head_bias = torch.randn(num_q_heads, device="cuda", dtype=dtype) * 0.2 pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 aux_tensors = [head_bias, pos_scale] eager_score_mod = eager_score_mod_factory(head_bias, pos_scale) assert head_bias.shape == (num_q_heads,) assert pos_scale.shape == (seqlen_q,) out_ref_fp32 = run_flex_reference(q, k, v, eager_score_mod, dtype=torch.float32) out_pt = run_flex_reference(q, k, v, eager_score_mod) out_cute = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa) # Basic shape and NaN checks assert out_cute.shape == out_ref_fp32.shape == out_pt.shape assert not torch.isnan(out_cute).any() assert not torch.isnan(out_ref_fp32).any() assert not torch.isnan(out_pt).any() assert torch.isfinite(out_cute).all() assert torch.isfinite(out_ref_fp32).all() assert torch.isfinite(out_pt).all() # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() rtol = 2 # Calculate actual errors pt_error = (out_pt - out_ref_fp32).abs().max().item() cute_error = (out_cute - out_ref_fp32).abs().max().item() print(f"\nNumerical comparison for {cute_score_mod.__name__}:") print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}") 
print(f" CuTE vs FP32 ref max error: {cute_error:.2e}") print(f" Dynamic absolute tolerance: {fwd_atol:.2e}") print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}") # Assert that CuTE's error is at most rtol times PyTorch's error + fwd_atol assert cute_error <= rtol * pt_error + fwd_atol, ( f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" ) @pytest.mark.parametrize("seqlen_q,seqlen_kv", SEQLEN_CONFIGS) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("score_mod_vec_pair", TEST_PAIRS_WITH_AUX_TENSORS_VECTORIZED) def test_cute_score_mod_with_aux_tensors_vectorized( seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, dtype, score_mod_vec_pair, ): """Tests equality between original and vectorized versions of score mods""" torch.random.manual_seed(42) cute_score_mod, cute_vectorized_score_mod = score_mod_vec_pair batch_size = 2 num_q_heads = num_kv_heads * qhead_per_kvhead pack_gqa = qhead_per_kvhead > 1 q, k, v = create_tensors( batch_size=batch_size, seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dtype=dtype, ) if pack_gqa: k = k[:, :num_kv_heads, :, :].clone() v = v[:, :num_kv_heads, :, :].clone() if cute_score_mod == score_mod_10: buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1 aux_tensors = [buffer] assert buffer.shape == (batch_size,) elif cute_score_mod == score_mod_11: head_bias = torch.randn(num_q_heads, device="cuda", dtype=dtype) * 0.2 pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01 aux_tensors = [head_bias, pos_scale] assert head_bias.shape == (num_q_heads,) assert pos_scale.shape == (seqlen_q,) out_ref = run_cute_flash(q, k, v, cute_score_mod, aux_tensors=aux_tensors, pack_gqa=pack_gqa) for vec_size in VEC_SIZES_TO_CHECK_EQUALITY: cute_vectorized_score_mod.__vec_size__ = vec_size out = run_cute_flash( q, k, v, cute_vectorized_score_mod,
aux_tensors=aux_tensors, pack_gqa=pack_gqa, ) assert torch.equal(out, out_ref) def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, device, dtype): import math from einops import rearrange num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 k_cache_paged = torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype) v_cache_paged = torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", b=batch_size, ) k_cache_bshd = rearrange( k_cache_paged[page_table.flatten()], "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k] v_cache_bshd = rearrange( v_cache_paged[page_table.flatten()], "(b nblocks) block_size ... -> b (nblocks block_size) ...", b=batch_size, )[:, :seqlen_k] k_cache = k_cache_bshd.transpose(1, 2) v_cache = v_cache_bshd.transpose(1, 2) return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("page_size", [None, 1, 4, 128]) @pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 2), (4, 2)]) @pytest.mark.parametrize( "seqlen_q,seqlen_kv", [ (1, 128), (64, 256), (64, 800), (256, 256), (113, 203), ], ) @pytest.mark.parametrize("score_mod_pair", TEST_PAIRS) @pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Paged KV cache only supported on SM100") def test_score_mod_with_paged_kvcache( seqlen_q, seqlen_kv, qhead_per_kvhead, num_kv_heads, page_size, dtype, score_mod_pair, ): if COMPUTE_CAPABILITY == 9: pytest.xfail("Paged KV cache only supported on SM100") if page_size is not None and seqlen_kv % page_size != 0: pytest.skip() torch.random.manual_seed(42) cute_score_mod, eager_score_mod = score_mod_pair batch_size = 2 num_q_heads = num_kv_heads * qhead_per_kvhead pack_gqa = qhead_per_kvhead > 1 dim = 128 device = "cuda" q = 
torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype) if page_size is None: k_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) v_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype) page_table = None k_cache_paged = None v_cache_paged = None else: ( k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks, ) = _generate_block_kvcache( seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype ) cache_seqlens = torch.randint(1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device) from einops import rearrange arange = rearrange(torch.arange(seqlen_kv, device=device), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") key_padding_mask = arange < cache_seqlens_expanded if pack_gqa: k_cache_rep = k_cache.repeat_interleave(qhead_per_kvhead, dim=1) v_cache_rep = v_cache.repeat_interleave(qhead_per_kvhead, dim=1) else: k_cache_rep = k_cache v_cache_rep = v_cache def make_masked_score_mod(base_score_mod, seqused_k_tensor): seqused_k_dev = seqused_k_tensor def masked_score_mod(score, b, h, q_idx, kv_idx): if base_score_mod is not None: score = base_score_mod(score, b, h, q_idx, kv_idx) seqlen_limit = torch.gather(seqused_k_dev, 0, b.long()) valid_mask = kv_idx < seqlen_limit return torch.where(valid_mask, score, torch.full_like(score, float("-inf"))) return masked_score_mod masked_score_mod_fp32 = make_masked_score_mod(eager_score_mod, cache_seqlens) masked_score_mod = make_masked_score_mod(eager_score_mod, cache_seqlens) out_ref_fp32 = run_flex_reference( q, k_cache_rep, v_cache_rep, masked_score_mod_fp32, dtype=torch.float32 ) out_pt = run_flex_reference(q, k_cache_rep, v_cache_rep, masked_score_mod) q_bshd = q.transpose(1, 2) out_cute = torch.empty_like(q_bshd) if page_size is None: k_bshd = k_cache.transpose(1, 2) v_bshd = v_cache.transpose(1, 2) _flash_attn_fwd( q_bshd, k_bshd, v_bshd, seqused_k=cache_seqlens, 
            return_lse=True,
            score_mod=cute_score_mod,
            out=out_cute,
            lse=None,
            pack_gqa=pack_gqa,
        )
    else:
        _flash_attn_fwd(
            q_bshd,
            k_cache_paged,
            v_cache_paged,
            seqused_k=cache_seqlens,
            page_table=page_table,
            return_lse=True,
            score_mod=cute_score_mod,
            out=out_cute,
            lse=None,
            pack_gqa=pack_gqa,
        )

    out_cute = out_cute.transpose(1, 2)

    assert out_cute.shape == out_ref_fp32.shape == out_pt.shape
    assert not torch.isnan(out_cute).any()
    assert not torch.isnan(out_ref_fp32).any()
    assert not torch.isnan(out_pt).any()
    assert torch.isfinite(out_cute).all()
    assert torch.isfinite(out_ref_fp32).all()
    assert torch.isfinite(out_pt).all()

    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()
    rtol = 2

    pt_error = (out_pt - out_ref_fp32).abs().max().item()
    cute_error = (out_cute - out_ref_fp32).abs().max().item()

    print(f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):")
    print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}")
    print(f" CuTE vs FP32 ref max error: {cute_error:.2e}")
    print(f" Dynamic absolute tolerance: {fwd_atol:.2e}")
    print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}")

    assert cute_error <= rtol * pt_error + fwd_atol, (
        f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}"
    )


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("page_size", [None, 128])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_kv",
    [
        (64, 128),
        (128, 256),
        (256, 256),
    ],
)
@pytest.mark.parametrize("score_mod_pair", TEST_PAIRS_WITH_AUX_TENSORS)
@pytest.mark.skipif(COMPUTE_CAPABILITY != 10, reason="Paged KV cache only supported on SM100")
def test_score_mod_with_paged_kvcache_aux_tensors(
    seqlen_q,
    seqlen_kv,
    qhead_per_kvhead,
    num_kv_heads,
    page_size,
    dtype,
    score_mod_pair,
):
    if COMPUTE_CAPABILITY == 9:
        pytest.xfail("Paged KV cache only supported on SM100")
    if page_size is not None and seqlen_kv % page_size != 0:
        pytest.skip()

    torch.random.manual_seed(42)
    cute_score_mod, eager_score_mod_factory = score_mod_pair

    batch_size = 2
    num_q_heads = num_kv_heads * qhead_per_kvhead
    pack_gqa = qhead_per_kvhead > 1
    dim = 128
    device = "cuda"

    q = torch.randn(batch_size, num_q_heads, seqlen_q, dim, device=device, dtype=dtype)

    if page_size is None:
        k_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype)
        v_cache = torch.randn(batch_size, num_kv_heads, seqlen_kv, dim, device=device, dtype=dtype)
        page_table = None
        k_cache_paged = None
        v_cache_paged = None
    else:
        (
            k_cache,
            v_cache,
            page_table,
            k_cache_paged,
            v_cache_paged,
            num_blocks,
        ) = _generate_block_kvcache(
            seqlen_kv, page_size, batch_size, num_kv_heads, dim, device, dtype
        )

    cache_seqlens = torch.randint(1, seqlen_kv + 1, (batch_size,), dtype=torch.int32, device=device)

    if cute_score_mod == score_mod_10:
        buffer = torch.randn(batch_size, device=device, dtype=dtype) * 0.1
        aux_tensors = [buffer]
        eager_score_mod = eager_score_mod_factory(buffer)
    elif cute_score_mod == score_mod_11:
        head_bias = torch.randn(num_q_heads, device=device, dtype=dtype) * 0.2
        pos_scale = torch.arange(seqlen_q, device=device, dtype=dtype) * 0.01
        aux_tensors = [head_bias, pos_scale]
        eager_score_mod = eager_score_mod_factory(head_bias, pos_scale)

    from einops import rearrange

    arange = rearrange(torch.arange(seqlen_kv, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    key_padding_mask = arange < cache_seqlens_expanded

    if pack_gqa:
        k_cache_rep = k_cache.repeat_interleave(qhead_per_kvhead, dim=1)
        v_cache_rep = v_cache.repeat_interleave(qhead_per_kvhead, dim=1)
    else:
        k_cache_rep = k_cache
        v_cache_rep = v_cache

    def make_masked_score_mod(base_score_mod, seqused_k_tensor):
        seqused_k_dev = seqused_k_tensor

        def masked_score_mod(score, b, h, q_idx, kv_idx):
            if base_score_mod is not None:
                score = base_score_mod(score, b, h, q_idx, kv_idx)
            seqlen_limit = torch.gather(seqused_k_dev, 0, b.long())
            valid_mask = kv_idx < seqlen_limit
            return torch.where(valid_mask, score, torch.full_like(score, float("-inf")))

        return masked_score_mod

    masked_score_mod_fp32 = make_masked_score_mod(eager_score_mod, cache_seqlens)
    masked_score_mod = make_masked_score_mod(eager_score_mod, cache_seqlens)

    out_ref_fp32 = run_flex_reference(
        q, k_cache_rep, v_cache_rep, masked_score_mod_fp32, dtype=torch.float32
    )
    out_pt = run_flex_reference(q, k_cache_rep, v_cache_rep, masked_score_mod)

    q_bshd = q.transpose(1, 2)
    out_cute = torch.empty_like(q_bshd)

    if page_size is None:
        k_bshd = k_cache.transpose(1, 2)
        v_bshd = v_cache.transpose(1, 2)
        _flash_attn_fwd(
            q_bshd,
            k_bshd,
            v_bshd,
            seqused_k=cache_seqlens,
            return_lse=True,
            score_mod=cute_score_mod,
            out=out_cute,
            lse=None,
            aux_tensors=aux_tensors,
            pack_gqa=pack_gqa,
        )
    else:
        _flash_attn_fwd(
            q_bshd,
            k_cache_paged,
            v_cache_paged,
            seqused_k=cache_seqlens,
            page_table=page_table,
            return_lse=True,
            score_mod=cute_score_mod,
            out=out_cute,
            lse=None,
            aux_tensors=aux_tensors,
            pack_gqa=pack_gqa,
        )

    out_cute = out_cute.transpose(1, 2)

    assert out_cute.shape == out_ref_fp32.shape == out_pt.shape
    assert not torch.isnan(out_cute).any()
    assert not torch.isnan(out_ref_fp32).any()
    assert not torch.isnan(out_pt).any()
    assert torch.isfinite(out_cute).all()
    assert torch.isfinite(out_ref_fp32).all()
    assert torch.isfinite(out_pt).all()

    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()
    rtol = 2

    pt_error = (out_pt - out_ref_fp32).abs().max().item()
    cute_error = (out_cute - out_ref_fp32).abs().max().item()

    print(f"\nNumerical comparison for {cute_score_mod.__name__} (paged={page_size is not None}):")
    print(f" PyTorch vs FP32 ref max error: {pt_error:.2e}")
    print(f" CuTE vs FP32 ref max error: {cute_error:.2e}")
    print(f" Dynamic absolute tolerance: {fwd_atol:.2e}")
    print(f" Error ratio (CuTE/PyTorch): {cute_error / max(pt_error, 1e-10):.2f}")

    assert cute_error <= rtol * pt_error + fwd_atol, (
        f"CuTE error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}"
    )


@cute.jit
def score_mod_bwd_5(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Backward for score_mod_5 (times_two): d(score*2)/d(score) = 2."""
    return grad * cute.full_like(grad, 2.0)


@cute.jit
def score_mod_bwd_3(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Backward for score_mod_3 (relative_bias): d(score + |q-kv|)/d(score) = 1."""
    return grad


@cute.jit
def score_mod_bwd_identity(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    return grad


@cute.jit
def score_mod_bwd_causal(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Backward for causal masking: d(where(mask, score, -inf))/d(score) = where(mask, 1, 0).

    At unmasked positions (q_idx >= kv_idx), grad passes through.
    At masked positions (q_idx < kv_idx), the kernel already zeros grad because P=0.
    """
    return grad


@cute.jit
def score_mod_squared(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Forward: score ** 2."""
    return tSrS_ssa * tSrS_ssa


@cute.jit
def score_mod_bwd_squared(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors):
    """Backward for score**2: d(score**2)/d(score) = 2*score."""
    return grad * cute.full_like(grad, 2.0) * score


def score_squared_eager(score, b, h, q_idx, kv_idx):
    return score * score


BWD_TEST_PAIRS = [
    (score_mod_5, score_mod_bwd_5, times_two_eager),
    (score_mod_3, score_mod_bwd_3, relative_bias_eager),
    (score_mod_squared, score_mod_bwd_squared, score_squared_eager),
    (score_mod_2, score_mod_bwd_causal, causal_mask_eager),
]

BWD_TEST_PAIRS_WITH_AUX = [
    (score_mod_10, score_mod_bwd_identity, batch_bias),
    (score_mod_11, score_mod_bwd_identity, dual_buffer_bias),
]

BWD_TEST_PAIRS_PACK_GQA = [
    (score_mod_5, score_mod_bwd_5, times_two_eager),
    (score_mod_3, score_mod_bwd_3, relative_bias_eager),
]


def run_cute_flash_bwd(
    q,
    k,
    v,
    cute_score_mod,
    cute_score_mod_bwd,
    aux_tensors=None,
    pack_gqa=False,
):
    """Run flash attention forward + backward with score_mod."""
    q_t = q.transpose(1, 2)
    k_t = k.transpose(1, 2)
    v_t = v.transpose(1, 2)

    out, lse = _flash_attn_fwd(
        q_t,
        k_t,
        v_t,
        return_lse=True,
        score_mod=cute_score_mod,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
    )

    grad_out = torch.randn_like(out)

    dq, dk, dv = _flash_attn_bwd(
        q_t,
        k_t,
        v_t,
        out,
        grad_out,
        lse,
        score_mod=cute_score_mod,
        score_mod_bwd=cute_score_mod_bwd,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
    )

    return (
        out.transpose(1, 2),
        grad_out.transpose(1, 2),
        dq.transpose(1, 2),
        dk.transpose(1, 2),
        dv.transpose(1, 2),
    )


def run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None):
    """Run flex_attention forward + backward for reference."""
    if dtype is not None:
        q = q.to(dtype).requires_grad_(True)
        k = k.to(dtype).requires_grad_(True)
        v = v.to(dtype).requires_grad_(True)
        grad_out = grad_out.to(dtype)
    else:
        q = q.requires_grad_(True)
        k = k.requires_grad_(True)
        v = v.requires_grad_(True)

    compiled_flex = torch.compile(flex_attention)
    out = compiled_flex(q, k, v, score_mod=eager_score_mod, enable_gqa=q.shape[1] != k.shape[1])
    dq, dk, dv = torch.autograd.grad(out, (q, k, v), grad_out)

    return out, dq, dk, dv


@pytest.mark.parametrize(
    "seqlen_q,seqlen_kv",
    [
        (64, 64),
        (128, 128),
        (256, 256),
        (512, 512),
        (799, 3),
        (3, 799),
        (128, 256),
        (256, 128),
        (113, 203),
    ],
)
@pytest.mark.parametrize("dim", [64, 128])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS)
def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype, score_mod_triple):
    """Test backward pass with score_mod against flex_attention reference."""
    if COMPUTE_CAPABILITY == 9 and dim == 64:
        pytest.skip("head_dim=64 not supported on SM90 for backward")

    torch.random.manual_seed(42)
    cute_fwd, cute_bwd, eager_ref = score_mod_triple

    q, k, v = create_tensors(
        seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype
    )

    out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(q, k, v, cute_fwd, cute_bwd)
    out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd(
        q, k, v, eager_ref, grad_out, dtype=torch.float32
    )
    out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out)

    assert not torch.isnan(dq_cute).any(), "dQ contains NaN"
    assert not torch.isnan(dk_cute).any(), "dK contains NaN"
    assert not torch.isnan(dv_cute).any(), "dV contains NaN"

    rtol = 2
    dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()
    dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()
    dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()

    dq_ref = dq_ref_fp32.to(dtype)
    dk_ref = dk_ref_fp32.to(dtype)
    dv_ref = dv_ref_fp32.to(dtype)

    pt_dq_err = (dq_pt - dq_ref).abs().max().item()
    pt_dk_err = (dk_pt - dk_ref).abs().max().item()
    pt_dv_err = (dv_pt - dv_ref).abs().max().item()

    cute_dq_err = (dq_cute - dq_ref).abs().max().item()
    cute_dk_err = (dk_cute - dk_ref).abs().max().item()
    cute_dv_err = (dv_cute - dv_ref).abs().max().item()

    print(f"\nBackward comparison for {cute_fwd.__name__}:")
    print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}")
    print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}")
    print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}")

    assert cute_dq_err <= rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}"
    assert cute_dk_err <= rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}"
    assert cute_dv_err <= rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}"


def make_aux_tensors_for_bwd(cute_score_mod, eager_factory, seqlen_q, num_heads, batch_size, dtype):
    if cute_score_mod == score_mod_10:
        buffer = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1
        return [buffer], eager_factory(buffer)

    head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2
    pos_scale = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01
    return [head_bias, pos_scale], eager_factory(head_bias, pos_scale)


@pytest.mark.parametrize(
    "seqlen_q,seqlen_kv",
    [
        (64, 64),
        (128, 128),
        (256, 128),
    ],
)
@pytest.mark.parametrize("dim", [64, 128])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS_WITH_AUX)
def test_cute_vs_flex_attention_backward_with_aux(
    seqlen_q, seqlen_kv, dim, dtype, score_mod_triple
):
    if COMPUTE_CAPABILITY == 9 and dim == 64:
        pytest.skip("head_dim=64 not supported on SM90 for backward")

    torch.random.manual_seed(42)
    cute_fwd, cute_bwd, eager_factory = score_mod_triple

    q, k, v = create_tensors(
        seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=4, dim=dim, dtype=dtype
    )

    aux_tensors, eager_ref = make_aux_tensors_for_bwd(
        cute_fwd, eager_factory, seqlen_q, q.shape[1], q.shape[0], dtype
    )

    out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(
        q, k, v, cute_fwd, cute_bwd, aux_tensors=aux_tensors
    )
    out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd(
        q, k, v, eager_ref, grad_out, dtype=torch.float32
    )
    out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out)

    assert not torch.isnan(dq_cute).any()
    assert not torch.isnan(dk_cute).any()
    assert not torch.isnan(dv_cute).any()

    rtol = 3
    dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()
    dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()
    dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()

    dq_ref = dq_ref_fp32.to(dtype)
    dk_ref = dk_ref_fp32.to(dtype)
    dv_ref = dv_ref_fp32.to(dtype)

    pt_dq_err = (dq_pt - dq_ref).abs().max().item()
    pt_dk_err = (dk_pt - dk_ref).abs().max().item()
    pt_dv_err = (dv_pt - dv_ref).abs().max().item()

    cute_dq_err = (dq_cute - dq_ref).abs().max().item()
    cute_dk_err = (dk_cute - dk_ref).abs().max().item()
    cute_dv_err = (dv_cute - dv_ref).abs().max().item()

    print(f"\nBackward comparison with aux for {cute_fwd.__name__}:")
    print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}")
    print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}")
    print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}")

    assert cute_dq_err <= rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}"
    assert cute_dk_err <= rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}"
    assert cute_dv_err <= rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}"


@pytest.mark.parametrize("seqlen_q,seqlen_kv", [(128, 128), (128, 256)])
@pytest.mark.parametrize("dim", [64, 128])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)])
@pytest.mark.parametrize("score_mod_triple", BWD_TEST_PAIRS_PACK_GQA)
def test_cute_vs_flex_attention_backward_pack_gqa(
    seqlen_q, seqlen_kv, dim, dtype, qhead_per_kvhead, num_kv_heads, score_mod_triple
):
    if COMPUTE_CAPABILITY == 9:
        pytest.xfail("pack_gqa backward not yet implemented on SM90")

    torch.random.manual_seed(42)
    cute_fwd, cute_bwd, eager_ref = score_mod_triple

    num_q_heads = num_kv_heads * qhead_per_kvhead
    q, k, v = create_tensors(
        seqlen_q=seqlen_q, seqlen_kv=seqlen_kv, num_heads=num_q_heads, dim=dim, dtype=dtype
    )
    k = k[:, :num_kv_heads, :, :].clone()
    v = v[:, :num_kv_heads, :, :].clone()

    out_cute, grad_out, dq_cute, dk_cute, dv_cute = run_cute_flash_bwd(
        q, k, v, cute_fwd, cute_bwd, pack_gqa=True
    )
    out_ref_fp32, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd(
        q, k, v, eager_ref, grad_out, dtype=torch.float32
    )
    out_pt, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd(q, k, v, eager_ref, grad_out)

    assert not torch.isnan(dq_cute).any()
    assert not torch.isnan(dk_cute).any()
    assert not torch.isnan(dv_cute).any()

    rtol = 3
    dq_atol = 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()
    dk_atol = 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()
    dv_atol = 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()

    dq_ref = dq_ref_fp32.to(dtype)
    dk_ref = dk_ref_fp32.to(dtype)
    dv_ref = dv_ref_fp32.to(dtype)

    pt_dq_err = (dq_pt - dq_ref).abs().max().item()
    pt_dk_err = (dk_pt - dk_ref).abs().max().item()
    pt_dv_err = (dv_pt - dv_ref).abs().max().item()

    cute_dq_err = (dq_cute - dq_ref).abs().max().item()
    cute_dk_err = (dk_cute - dk_ref).abs().max().item()
    cute_dv_err = (dv_cute - dv_ref).abs().max().item()

    print(f"\nBackward Pack-GQA comparison for {cute_fwd.__name__}:")
    print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}")
    print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}")
    print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}")

    assert cute_dq_err <= rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}"
    assert cute_dk_err <= rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}"
    assert cute_dv_err <= rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

================================================
FILE: tests/cute/test_score_mod_varlen.py
================================================
import pytest
import torch
from torch.nn.attention.flex_attention import flex_attention

from flash_attn.cute.interface import _flash_attn_fwd
from test_score_mod import _generate_block_kvcache
from score_mod_definitions import (
    # TensorSSA-based score mods
    score_mod_alibi,
    score_mod_batch_bias,
    score_mod_block_diagonal,
    score_mod_causal,
    score_mod_causal_v2,
    score_mod_debug_global_idx,
    score_mod_dual_buffer,
    score_mod_global_kv_bias,
    score_mod_global_logical_rel_plus_kv_bias,
    score_mod_global_q_and_kv_bias,
    score_mod_global_q_bias,
    score_mod_global_rel_plus_kv_bias,
    score_mod_identity,
    score_mod_rel_bias,
    score_mod_rel_bias_x2,
    score_mod_sliding_window,
    score_mod_stress_complex_arithmetic,
    score_mod_stress_conditional_mask,
    score_mod_stress_global_offset,
    score_mod_stress_multi_buffer,
    score_mod_stress_xor_pattern,
    score_mod_times_two,
)

# isort: split
from score_mod_definitions import (
    score_mod_identity_vectorized,
    score_mod_causal_vectorized,
    score_mod_rel_bias as score_mod_rel_bias_vectorized,
    score_mod_rel_bias_x2_vectorized,
    score_mod_times_two_vectorized,
    score_mod_alibi_vectorized,
    score_mod_batch_bias_vectorized,
    score_mod_dual_buffer_vectorized,
)

# isort: split
from score_mod_definitions import (
    # Eager (torch) reference score mods
    identity_eager,
    causal_eager,
    rel_bias_eager,
    rel_bias_x2_eager,
    times_two_eager,
    alibi_eager,
    sliding_window_eager,
    block_diagonal_eager,
    causal_v2_eager,
    batch_bias_factory,
    dual_buffer_factory,
    packed_kv_bias_factory,
    packed_q_bias_factory,
    packed_rel_plus_kv_bias_factory,
    packed_q_and_kv_bias_factory,
    packed_logical_rel_plus_kv_bias_factory,
    stress_complex_arithmetic_factory,
    stress_conditional_mask_factory,
    stress_multi_buffer_factory,
    stress_global_offset_factory,
    stress_xor_pattern_factory,
    debug_global_idx_factory,
)

IS_SM90 = torch.cuda.get_device_capability()[0] == 9
IS_SM100 = torch.cuda.get_device_capability()[0] == 10

# =============================================================================
# Test pairs
# =============================================================================

# (cute_score_mod, eager_factory_or_fn, aux_type)
# aux_type: None, "batch", "dual_buffer"
# All score_mods use 7-arg signature: (tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors)
TEST_PAIRS_NO_GLOBAL = [
    (score_mod_identity, identity_eager, None),
    (score_mod_causal, causal_eager, None),
    (score_mod_rel_bias, rel_bias_eager, None),
    (score_mod_rel_bias_x2, rel_bias_x2_eager, None),
    (score_mod_times_two, times_two_eager, None),
    (score_mod_alibi, alibi_eager, None),
    (score_mod_sliding_window, sliding_window_eager, None),
    (score_mod_block_diagonal, block_diagonal_eager, None),
    (score_mod_causal_v2, causal_v2_eager, None),
    (score_mod_batch_bias, batch_bias_factory, "batch"),
    (score_mod_dual_buffer, dual_buffer_factory, "dual_buffer"),
]

# Test pairs to compare vectorized score_mods: (cute_jit_function, cute_jit_function_vectorized, aux_type)
TEST_PAIRS_VECTORIZED_NO_GLOBAL = [
    (score_mod_identity, score_mod_identity_vectorized, None),
    (score_mod_causal, score_mod_causal_vectorized, None),
    (score_mod_rel_bias, score_mod_rel_bias_vectorized, None),
    (score_mod_rel_bias_x2, score_mod_rel_bias_x2_vectorized, None),
    (score_mod_times_two, score_mod_times_two_vectorized, None),
    (score_mod_alibi, score_mod_alibi_vectorized, None),
    (score_mod_batch_bias, score_mod_batch_bias_vectorized, "batch"),
    (score_mod_dual_buffer, score_mod_dual_buffer_vectorized, "dual_buffer"),
]

# (cute_score_mod, eager_factory, aux_type, requires_global)
# aux_type: "kv", "q", "q_and_kv", "q_concat", "kv_with_cu", "multi_buffer"
# requires_global: "q" (needs varlen_q), "kv" (needs varlen_k), "both" (needs both)
# All score_mods use 7-arg signature and compute global indices from seqlen_info
TEST_PAIRS_WITH_GLOBAL = [
    (score_mod_global_kv_bias, packed_kv_bias_factory, "kv", "kv"),
    (score_mod_global_q_bias, packed_q_bias_factory, "q", "q"),
    (score_mod_global_rel_plus_kv_bias, packed_rel_plus_kv_bias_factory, "kv", "kv"),
    (score_mod_global_q_and_kv_bias, packed_q_and_kv_bias_factory, "q_and_kv", "both"),
    (
        score_mod_global_logical_rel_plus_kv_bias,
        packed_logical_rel_plus_kv_bias_factory,
        "kv",
        "kv",
    ),
    (
        score_mod_stress_complex_arithmetic,
        stress_complex_arithmetic_factory,
        "q_concat",
        "q",
    ),
    (
        score_mod_stress_conditional_mask,
        stress_conditional_mask_factory,
        "kv_with_cu",
        "both",
    ),
    (
        score_mod_stress_multi_buffer,
        stress_multi_buffer_factory,
        "multi_buffer",
        "both",
    ),
    (score_mod_stress_global_offset, stress_global_offset_factory, "kv", "kv"),
    (score_mod_stress_xor_pattern, stress_xor_pattern_factory, "kv_with_cu", "kv"),
    (score_mod_debug_global_idx, debug_global_idx_factory, "kv", "kv"),
]

SEQLEN_CONFIGS = [
    ([1], [1]),
    ([1, 1], [1, 1]),
    ([2, 3], [2, 3]),
    ([8, 16], [8, 16]),
    ([32, 32], [32, 32]),
    ([64, 128], [64, 128]),
    ([64, 56, 128], [64, 56, 128]),
    ([256, 512], [256, 512]),
    ([113, 203], [113, 203]),
    ([239, 1], [239, 1]),
    ([64], [64]),
    ([128, 128], [128, 128]),
    ([32, 32, 32, 32], [32, 32, 32, 32]),
    ([16, 32, 64, 128, 256], [16, 32, 64, 128, 256]),
    ([1, 1024], [1, 1024]),
    ([1024, 1], [1024, 1]),
    ([1, 256, 1], [1, 256, 1]),
    ([256, 1, 256], [256, 1, 256]),
    ([17, 33, 65], [17, 33, 65]),
    ([64, 128], [32, 64]),
    ([100, 100], [50, 50]),
    ([256, 512, 256], [128, 256, 128]),
    ([2, 1], [16384, 32 * 1024]),
    ([1, 1], [128 * 1024] * 2),
    ([2, 1], [8192, 8192]),
    ([1, 3], [8192, 8192]),
    ([3, 3], [8192, 8192]),
    ([128, 128], [8192, 8192]),
    ([2, 2, 2], [8 * 1024] * 3),
    ([2, 1], [1024 * 32, 16384]),
    ([1, 2], [1024 * 32, 16384]),
    ([1, 1, 1], [128 * 1024] * 3),
    ([1, 1, 1], [256 * 1024] * 3),
]

VEC_SIZES_TO_CHECK_EQUALITY = [1, 2, 4] if IS_SM100 else [1, 2]

# =============================================================================
# Helper functions
# =============================================================================


def run_cute_flash(
    q,
    k,
    v,
    score_mod,
    aux_tensors=None,
    pack_gqa=False,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    page_table=None,
    seqused_k=None,
):
    """Run CuTE flash attention."""
    if cu_seqlens_q is not None or cu_seqlens_k is not None:
        out = torch.empty_like(q)
        _flash_attn_fwd(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_k=seqused_k,
            page_table=page_table,
            return_lse=True,
            score_mod=score_mod,
            out=out,
            lse=None,
            aux_tensors=aux_tensors,
            pack_gqa=pack_gqa,
        )
        return out

    out = torch.empty_like(q)
    _flash_attn_fwd(
        q,
        k,
        v,
        seqused_k=seqused_k,
        page_table=page_table,
        return_lse=True,
        score_mod=score_mod,
        out=out,
        lse=None,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
    )
    return out


def run_flex_varlen_ref(q, k, v, cu_seqlens_q, cu_seqlens_k, score_mod, dtype=None):
    """Run flex_attention per-sequence for varlen reference."""
    if cu_seqlens_q is not None:
        num_batches = len(cu_seqlens_q) - 1
    else:
        num_batches = len(cu_seqlens_k) - 1

    results = []
    for i in range(num_batches):
        # Get Q slice
        if cu_seqlens_q is not None:
            q_slice = (
                q[cu_seqlens_q[i] : cu_seqlens_q[i + 1]].unsqueeze(0).transpose(1, 2)
            )
        else:
            q_slice = q[i : i + 1].transpose(1, 2)

        # Get K/V slices
        if cu_seqlens_k is not None:
            k_slice = (
                k[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0).transpose(1, 2)
            )
            v_slice = (
                v[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].unsqueeze(0).transpose(1, 2)
            )
        else:
            k_slice = k[i : i + 1].transpose(1, 2)
            v_slice = v[i : i + 1].transpose(1, 2)

        if dtype is not None:
            q_slice, k_slice, v_slice = (
                q_slice.to(dtype),
                k_slice.to(dtype),
                v_slice.to(dtype),
            )

        def wrapped_mod(score, b, h, q_idx, kv_idx):
            return score_mod(score, i, h, q_idx, kv_idx)

        out = flex_attention(
            q_slice,
            k_slice,
            v_slice,
            score_mod=wrapped_mod,
            enable_gqa=q_slice.shape[1] != k_slice.shape[1],
        )
        results.append(out.transpose(1, 2).squeeze(0))

    return torch.cat(results, dim=0)


def setup_tensors(seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype):
    """Create Q, K, V tensors and cu_seqlens based on varlen flags."""
    batch_size = len(seqlens_q)

    if varlen_q:
        total_q = sum(seqlens_q)
        q = torch.randn(total_q, num_heads, head_dim, device="cuda", dtype=dtype)
        cu_seqlens_q = torch.tensor(
            [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()),
            device="cuda",
            dtype=torch.int32,
        )
    else:
        seqlen_q = seqlens_q[0]  # All sequences have the same length for non-varlen
        q = torch.randn(
            batch_size, seqlen_q, num_heads, head_dim, device="cuda", dtype=dtype
        )
        cu_seqlens_q = None

    if varlen_k:
        total_k = sum(seqlens_k)
        k = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype)
        v = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype)
        cu_seqlens_k = torch.tensor(
            [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()),
            device="cuda",
            dtype=torch.int32,
        )
    else:
        seqlen_k = seqlens_k[0]  # All sequences have the same length for non-varlen
        k = torch.randn(
            batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype
        )
        v = torch.randn(
            batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype
        )
        cu_seqlens_k = None

    return q, k, v, cu_seqlens_q, cu_seqlens_k


def prepare_ref_tensors(
    q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q
):
    """Prepare tensors for flex_attention reference (handle mixed varlen formats)."""
    num_heads = q.shape[1] if varlen_q else q.shape[2]

    if not varlen_q and varlen_k:
        seqlen_q = q.shape[1]
        q_packed = q.reshape(-1, num_heads, q.shape[-1])
        ref_cu_seqlens_q = torch.tensor(
            [seqlen_q * i for i in range(batch_size + 1)],
            device="cuda",
            dtype=torch.int32,
        )
        return q_packed, k, v, ref_cu_seqlens_q, cu_seqlens_k

    if varlen_q and not varlen_k:
        return q, k, v, cu_seqlens_q, None

    return q, k, v, cu_seqlens_q, cu_seqlens_k


def check_results(
    out_cute,
    out_ref_fp32,
    out_pt,
    test_name,
    rtol=2,
    extra_atol=1e-4,
    seqlens_q=None,
    cu_seqlens_q=None,
):
    """Compare CuTE output against references."""
    assert not torch.isnan(out_cute).any(), f"{test_name}: NaN in output"
    assert torch.isfinite(out_cute).all(), f"{test_name}: Inf in output"

    varlen_q = cu_seqlens_q is not None

    if varlen_q:
        # Unpack and compare per-sequence
        assert seqlens_q is not None, "varlen_q requires use of seqlens_q"
        num_seqs = len(seqlens_q)
        max_cute_error = 0.0
        max_pt_error = 0.0

        for i in range(num_seqs):
            # Extract sequences using cu_seqlens (all outputs are in packed format)
            start_q = cu_seqlens_q[i]
            end_q = cu_seqlens_q[i + 1]
            cute_seq = out_cute[start_q:end_q]
            ref_seq = out_ref_fp32[start_q:end_q]
            pt_seq = out_pt[start_q:end_q]

            max_cute_error = max(
                max_cute_error, (cute_seq - ref_seq).abs().max().item()
            )
            max_pt_error = max(max_pt_error, (pt_seq - ref_seq).abs().max().item())

        cute_error = max_cute_error
        pt_error = max_pt_error
    else:
        # Direct comparison
        pt_error = (out_pt - out_ref_fp32).abs().max().item()
        cute_error = (out_cute - out_ref_fp32).abs().max().item()

    fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item()

    print(f"\n{test_name}:")
    print(f" PyTorch vs FP32 ref: {pt_error:.2e}")
    print(f" CuTE vs FP32 ref: {cute_error:.2e}")

    tol = rtol * pt_error + fwd_atol + extra_atol
    assert cute_error <= tol, (
        f"{test_name}: CuTE error {cute_error:.2e} exceeds tolerance {tol:.2e}"
    )


# =============================================================================
# Tests
# =============================================================================


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("varlen_q", [True, False])
@pytest.mark.parametrize("varlen_k", [True, False])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)])
@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS)
@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_NO_GLOBAL)
def test_varlen_with_score_mod(
    seqlens_q,
    seqlens_k,
    varlen_q,
    varlen_k,
    qhead_per_kvhead,
    num_kv_heads,
    dtype,
    score_mod_tuple,
):
    """Test varlen attention with score_mod functions that don't use global indices.

    Covers: both varlen, varlen Q only, varlen K only.
    Skips: neither varlen
    """
    if not varlen_q and not varlen_k:
        pytest.skip(
            "At least one of varlen_q or varlen_k must be True for varlen tests"
        )

    # For non-varlen dimension, all sequences must have same length
    if not varlen_q:
        seqlens_q = [seqlens_q[0]] * len(seqlens_q)
    if not varlen_k:
        seqlens_k = [seqlens_k[0]] * len(seqlens_k)

    torch.random.manual_seed(42)
    cute_score_mod, eager_factory, aux_type = score_mod_tuple

    num_heads = num_kv_heads * qhead_per_kvhead
    pack_gqa = qhead_per_kvhead > 1
    head_dim = 128
    batch_size = len(seqlens_q)

    q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors(
        seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype
    )

    if pack_gqa:
        if varlen_k:
            k = k[:, :num_kv_heads, :].clone()
            v = v[:, :num_kv_heads, :].clone()
        else:
            k = k[:, :, :num_kv_heads, :].clone()
            v = v[:, :, :num_kv_heads, :].clone()

    aux_tensors = None
    if aux_type == "batch":
        bias = torch.zeros(batch_size, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias)
    elif aux_type == "dual_buffer":
        seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q)
        head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2
        pos_bias = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01
        aux_tensors = [head_bias, pos_bias]
        eager_score_mod = eager_factory(head_bias, pos_bias)
    else:
        eager_score_mod = eager_factory

    # Prepare reference tensors
    q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(
        q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q
    )

    out_ref_fp32 = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32
    )
    out_pt = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype
    )

    out_cute = run_cute_flash(
        q,
        k,
        v,
        cute_score_mod,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
    )

    if not varlen_q and varlen_k:
        seqlen_q = q.shape[1]
        out_ref_fp32 = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim)
        out_pt = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim)

    assert out_cute.shape == out_ref_fp32.shape, (
        f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}"
    )

    test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k})"
    extra_atol = 2e-3
    check_results(
        out_cute,
        out_ref_fp32,
        out_pt,
        test_name,
        extra_atol=extra_atol,
        seqlens_q=seqlens_q if varlen_q else None,
        cu_seqlens_q=cu_seqlens_q if varlen_q else None,
    )


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("varlen_q", [True, False])
@pytest.mark.parametrize("varlen_k", [True, False])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)])
@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS)
@pytest.mark.parametrize("score_mod_vec_tuple", TEST_PAIRS_VECTORIZED_NO_GLOBAL)
def test_varlen_with_score_mod_vectorized(
    seqlens_q,
    seqlens_k,
    varlen_q,
    varlen_k,
    qhead_per_kvhead,
    num_kv_heads,
    dtype,
    score_mod_vec_tuple,
):
    """Tests equality between original and vectorized versions of score mods"""
    if not varlen_q and not varlen_k:
        pytest.skip(
            "At least one of varlen_q or varlen_k must be True for varlen tests"
        )

    # For non-varlen dimension, all sequences must have same length
    if not varlen_q:
        seqlens_q = [seqlens_q[0]] * len(seqlens_q)
    if not varlen_k:
        seqlens_k = [seqlens_k[0]] * len(seqlens_k)

    torch.random.manual_seed(42)
    cute_score_mod, cute_vectorized_score_mod, aux_type = score_mod_vec_tuple

    num_heads = num_kv_heads * qhead_per_kvhead
    pack_gqa = qhead_per_kvhead > 1
    head_dim = 128
    batch_size = len(seqlens_q)

    q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors(
        seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype
    )

    aux_tensors = None
    if aux_type == "batch":
        bias = torch.zeros(batch_size, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [bias]
    elif aux_type == "dual_buffer":
        seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q)
        head_bias = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.2
        pos_bias = torch.arange(seqlen_q, device="cuda", dtype=dtype) * 0.01
        aux_tensors = [head_bias, pos_bias]

    if pack_gqa:
        if varlen_k:
            k = k[:, :num_kv_heads, :].clone()
            v = v[:, :num_kv_heads, :].clone()
        else:
            k = k[:, :, :num_kv_heads, :].clone()
            v = v[:, :, :num_kv_heads, :].clone()

    out_ref = run_cute_flash(
        q,
        k,
        v,
        cute_score_mod,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
    )
    for vec_size in VEC_SIZES_TO_CHECK_EQUALITY:
        cute_vectorized_score_mod.__vec_size__ = vec_size
        out = run_cute_flash(
            q,
            k,
            v,
            cute_vectorized_score_mod,
            aux_tensors=aux_tensors,
            pack_gqa=pack_gqa,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
        )
        assert torch.equal(out, out_ref)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("varlen_q", [True, False])
@pytest.mark.parametrize("varlen_k", [True, False])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)])
@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS)
@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_WITH_GLOBAL)
def test_varlen_with_global_idx_score_mod(
    seqlens_q,
    seqlens_k,
    varlen_q,
    varlen_k,
    qhead_per_kvhead,
    num_kv_heads,
    dtype,
    score_mod_tuple,
):
    """Test varlen attention with score_mod functions that use global indices.

    These score_mods compute q_idx_global and/or kv_idx_global from seqlen_info
    for packed tensor indexing. Skips tests where required global indices aren't available.
    """
    if not varlen_q and not varlen_k:
        pytest.skip(
            "At least one of varlen_q or varlen_k must be True for varlen tests"
        )

    cute_score_mod, eager_factory, aux_type, requires_global = score_mod_tuple

    # Skip if score_mod requires global indices we can't provide
    if requires_global == "q" and not varlen_q:
        pytest.skip(f"{cute_score_mod.__name__} requires varlen_q for q_idx_global")
    if requires_global == "kv" and not varlen_k:
        pytest.skip(f"{cute_score_mod.__name__} requires varlen_k for kv_idx_global")
    if requires_global == "both" and (not varlen_q or not varlen_k):
        pytest.skip(f"{cute_score_mod.__name__} requires both varlen_q and varlen_k")

    # For non-varlen dimension, all sequences must have same length
    if not varlen_q:
        seqlens_q = [seqlens_q[0]] * len(seqlens_q)
    if not varlen_k:
        seqlens_k = [seqlens_k[0]] * len(seqlens_k)

    torch.random.manual_seed(42)

    num_heads = num_kv_heads * qhead_per_kvhead
    pack_gqa = qhead_per_kvhead > 1
    head_dim = 128
    batch_size = len(seqlens_q)
    max_rel_pos = 512

    total_q = sum(seqlens_q)
    total_k = sum(seqlens_k)

    cu_seqlens_q = torch.tensor(
        [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()),
        device="cuda",
        dtype=torch.int32,
    )
    cu_seqlens_k = torch.tensor(
        [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()),
        device="cuda",
        dtype=torch.int32,
    )

    if varlen_q:
        q = torch.randn(total_q, num_heads, head_dim, device="cuda", dtype=dtype)
    else:
        seqlen_q = seqlens_q[0]
        q = torch.randn(
            batch_size, seqlen_q, num_heads, head_dim, device="cuda", dtype=dtype
        )

    if varlen_k:
        k = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype)
        v = torch.randn(total_k, num_heads, head_dim, device="cuda", dtype=dtype)
    else:
        seqlen_k = seqlens_k[0]
        k = torch.randn(
            batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype
        )
        v = torch.randn(
            batch_size, seqlen_k, num_heads, head_dim, device="cuda", dtype=dtype
        )

    if pack_gqa:
        if varlen_k:
            k = k[:, :num_kv_heads, :].clone()
            v = v[:, :num_kv_heads, :].clone()
        else:
            k = k[:, :, :num_kv_heads, :].clone()
            v = v[:, :, :num_kv_heads, :].clone()

    # Setup aux tensors based on indexing type
    if aux_type == "kv":
        bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias, cu_seqlens_k)
    elif aux_type == "q":
        bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias, cu_seqlens_q)
    elif aux_type == "q_and_kv":
        q_bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1
        kv_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [q_bias, kv_bias]
        eager_score_mod = eager_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k)
    elif aux_type == "q_concat":
        bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias, cu_seqlens_q)
    elif aux_type == "kv_with_cu":
        kv_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1
        aux_tensors = [kv_bias]
        eager_score_mod = eager_factory(kv_bias, cu_seqlens_q, cu_seqlens_k)
    elif aux_type == "multi_buffer":
        batch_bias = torch.randn(batch_size, device="cuda", dtype=dtype) * 0.1
        head_scale = torch.randn(num_heads, device="cuda", dtype=dtype) * 0.1 + 1.0
        q_pos_bias = torch.randn(total_q, device="cuda", dtype=dtype) * 0.1
        kv_pos_bias = torch.randn(total_k, device="cuda", dtype=dtype) * 0.1
        rel_pos_scale = (
            torch.randn(max_rel_pos * 2 + 1, device="cuda", dtype=dtype) * 0.1
        )
        aux_tensors = [batch_bias, head_scale, q_pos_bias, kv_pos_bias, rel_pos_scale]
        eager_score_mod = eager_factory(
            batch_bias,
            head_scale,
            q_pos_bias,
            kv_pos_bias,
            rel_pos_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            max_rel_pos,
        )
    else:
        raise ValueError(f"Unknown aux_type: {aux_type}")

    # Prepare reference tensors for flex_attention
    q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(
        q, k, v, cu_seqlens_q, cu_seqlens_k, varlen_q, varlen_k, batch_size, seqlens_q
    )

    out_ref_fp32 = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32
    )
    out_pt = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype
    )

    kernel_cu_seqlens_q = cu_seqlens_q if varlen_q else None
    kernel_cu_seqlens_k = cu_seqlens_k if varlen_k else None
    out_cute = run_cute_flash(
        q,
        k,
        v,
        cute_score_mod,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
        cu_seqlens_q=kernel_cu_seqlens_q,
        cu_seqlens_k=kernel_cu_seqlens_k,
    )

    if varlen_q:
        out_ref_final = out_ref_fp32
        out_pt_final = out_pt
        out_cute_final = out_cute
    else:
        seqlen_q = seqlens_q[0]
        out_ref_final = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim)
        out_pt_final = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim)
        out_cute_final = out_cute

    assert out_cute_final.shape == out_ref_final.shape, (
        f"Shape mismatch: {out_cute_final.shape} vs {out_ref_final.shape}"
    )

    test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k}, {aux_type})"
    check_results(
        out_cute_final,
        out_ref_final,
        out_pt_final,
        test_name,
        extra_atol=1e-3,
        seqlens_q=seqlens_q if varlen_q else None,
        cu_seqlens_q=cu_seqlens_q if varlen_q else None,
    )


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("page_size", [None, 128])
@pytest.mark.parametrize("varlen_q", [True, False])
@pytest.mark.parametrize("varlen_k", [True, False])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(4, 2)])
@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS)
@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_NO_GLOBAL)
def test_varlen_score_mod_kvcache(
    seqlens_q,
    seqlens_k,
    varlen_q,
    varlen_k,
    qhead_per_kvhead,
    num_kv_heads,
    page_size,
    dtype,
    score_mod_tuple,
):
    """Test varlen attention with score_mod and paged KV cache."""
    if IS_SM90 and page_size is not None:
        pytest.xfail("paged KV not supported on SM90")

    if not varlen_q and not varlen_k:
        pytest.skip(
            "At least one of varlen_q or varlen_k must be True for varlen tests"
        )

    if page_size is not None and varlen_k:
        pytest.skip("Paged KV requires batched (non-varlen) K")

    if not varlen_q:
        seqlens_q = [seqlens_q[0]] * len(seqlens_q)
    if not varlen_k:
        seqlens_k = [seqlens_k[0]] * len(seqlens_k)

    # Skip if page_size doesn't divide seqlens evenly (for simplicity)
    if page_size is not None and not varlen_k:
        if seqlens_k[0] % page_size != 0:
            pytest.skip("page_size must divide seqlen_k")

    torch.random.manual_seed(42)
    cute_score_mod, eager_factory, aux_type = score_mod_tuple

    num_heads = num_kv_heads * qhead_per_kvhead
    pack_gqa = qhead_per_kvhead > 1
    head_dim = 128
    batch_size = len(seqlens_q)
    device = "cuda"

    # Setup tensors
    q, k, v, cu_seqlens_q, cu_seqlens_k = setup_tensors(
        seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, head_dim, dtype
    )

    if pack_gqa:
        if varlen_k:
            k = k[:, :num_kv_heads, :].clone()
            v = v[:, :num_kv_heads, :].clone()
        else:
            k = k[:, :, :num_kv_heads, :].clone()
            v = v[:, :, :num_kv_heads, :].clone()

    page_table = None
    k_cache_paged = None
    v_cache_paged = None
    k_cache = k
    v_cache = v

    if page_size is not None:
        seqlen_k = seqlens_k[0]
        (
            k_cache_bhsd,
            v_cache_bhsd,
            page_table,
            k_cache_paged,
            v_cache_paged,
            num_blocks,
        ) = _generate_block_kvcache(
            seqlen_k, page_size, batch_size, num_kv_heads, head_dim, device, dtype
        )
        k_cache = k_cache_bhsd.transpose(1, 2)  # BHSD -> BSHD
        v_cache = v_cache_bhsd.transpose(1, 2)
        seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device)
    else:
        seqused_k = None

    # Setup aux tensors and eager score_mod
    aux_tensors = None
    if aux_type == "batch":
        bias = torch.zeros(batch_size, device=device, dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias)
    elif aux_type == "dual_buffer":
        seqlen_q = seqlens_q[0] if not varlen_q else max(seqlens_q)
        head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.2
        pos_bias = torch.arange(seqlen_q, device=device, dtype=dtype) * 0.01
        aux_tensors = [head_bias, pos_bias]
        eager_score_mod = eager_factory(head_bias, pos_bias)
    else:
        eager_score_mod = eager_factory

    # Prepare reference tensors
    q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(
        q,
        k_cache,
        v_cache,
        cu_seqlens_q,
        cu_seqlens_k,
        varlen_q,
        varlen_k,
        batch_size,
        seqlens_q,
    )

    out_ref_fp32 = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32
    )
    out_pt = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype
    )

    k_input = k_cache_paged if page_size is not None else k_cache
    v_input = v_cache_paged if page_size is not None else v_cache

    out_cute = run_cute_flash(
        q,
        k_input,
        v_input,
        cute_score_mod,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
        cu_seqlens_q=cu_seqlens_q if varlen_q else None,
        cu_seqlens_k=cu_seqlens_k if (varlen_k and page_size is None) else None,
        page_table=page_table if page_size is not None else None,
        seqused_k=seqused_k if page_size is not None else None,
    )

    if not varlen_q and varlen_k:
        seqlen_q = q.shape[1]
        out_ref_fp32 = out_ref_fp32.reshape(batch_size, seqlen_q, num_heads, head_dim)
        out_pt = out_pt.reshape(batch_size, seqlen_q, num_heads, head_dim)

    assert out_cute.shape == out_ref_fp32.shape, (
        f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}"
    )

    test_name = f"{cute_score_mod.__name__} (varlen_q={varlen_q}, varlen_k={varlen_k}, paged={page_size is not None})"
    extra_atol = 2e-3
    check_results(
        out_cute,
        out_ref_fp32,
        out_pt,
        test_name,
        extra_atol=extra_atol,
        seqlens_q=seqlens_q if varlen_q else None,
        cu_seqlens_q=cu_seqlens_q if varlen_q else None,
    )


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("page_size", [None, 128])
@pytest.mark.parametrize("varlen_q", [True, False])
@pytest.mark.parametrize("varlen_k", [True, False])
@pytest.mark.parametrize("qhead_per_kvhead,num_kv_heads", [(1, 1), (4, 2)])
@pytest.mark.parametrize("seqlens_q,seqlens_k", SEQLEN_CONFIGS)
@pytest.mark.parametrize("score_mod_tuple", TEST_PAIRS_WITH_GLOBAL)
def test_varlen_score_mod_with_paged_kvcache_global(
    seqlens_q,
    seqlens_k,
    varlen_q,
    varlen_k,
    qhead_per_kvhead,
    num_kv_heads,
    page_size,
    dtype,
    score_mod_tuple,
):
    """Test varlen attention with global idx score_mod and paged KV cache."""
    if IS_SM90 and page_size is not None:
        pytest.xfail("paged KV not supported on SM90")

    if page_size is not None and varlen_k:
        pytest.skip("Paged KV cache requires batched (non-varlen) K")

    if not varlen_q and not varlen_k:
        pytest.skip(
            "At least one of varlen_q or varlen_k must be True for varlen tests"
        )

    if not varlen_q:
        seqlens_q = [seqlens_q[0]] * len(seqlens_q)
    if not varlen_k:
        seqlens_k = [seqlens_k[0]] * len(seqlens_k)

    if page_size is not None and not varlen_k:
        if seqlens_k[0] % page_size != 0:
            pytest.skip("page_size must divide seqlen_k")

    cute_score_mod, eager_factory, aux_type, requires_global = score_mod_tuple

    if requires_global == "q" and not varlen_q:
        pytest.skip(f"{cute_score_mod.__name__} requires varlen_q for q_idx_global")
    if requires_global == "kv" and not varlen_k:
        pytest.skip(f"{cute_score_mod.__name__} requires varlen_k for kv_idx_global")
    if requires_global == "both" and (not varlen_q or not varlen_k):
        pytest.skip(f"{cute_score_mod.__name__} requires both varlen_q and varlen_k")

    torch.random.manual_seed(42)

    num_heads = num_kv_heads * qhead_per_kvhead
    pack_gqa = qhead_per_kvhead > 1
    head_dim = 128
    batch_size = len(seqlens_q)
    max_rel_pos = 512
    device = "cuda"

    total_q = sum(seqlens_q)
    total_k = sum(seqlens_k)

    cu_seqlens_q = torch.tensor(
        [0] + list(torch.tensor(seqlens_q).cumsum(0).tolist()),
        device=device,
        dtype=torch.int32,
    )
    cu_seqlens_k = torch.tensor(
        [0] + list(torch.tensor(seqlens_k).cumsum(0).tolist()),
        device=device,
        dtype=torch.int32,
    )

    cu_seqlens_k_for_kernel = cu_seqlens_k if varlen_k else None

    q = torch.randn(total_q, num_heads, head_dim, device=device, dtype=dtype)
    if varlen_k:
        k = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype)
        v = torch.randn(total_k, num_heads, head_dim, device=device, dtype=dtype)
    else:
        seqlen_k = seqlens_k[0]
        k = torch.randn(
            batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype
        )
        v = torch.randn(
            batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype
        )

    if pack_gqa:
        if varlen_k:
            k = k[:, :num_kv_heads, :].clone()
            v = v[:, :num_kv_heads, :].clone()
        else:
            k = k[:, :, :num_kv_heads, :].clone()
            v = v[:, :, :num_kv_heads, :].clone()

    page_table = None
    k_cache_paged = None
    v_cache_paged = None
    k_cache = k
    v_cache = v

    if page_size is not None:
        seqlen_k = seqlens_k[0]
        (
            k_cache_bhsd,
            v_cache_bhsd,
            page_table,
            k_cache_paged,
            v_cache_paged,
            num_blocks,
        ) = _generate_block_kvcache(
            seqlen_k, page_size, batch_size, num_kv_heads, head_dim, device, dtype
        )
        k_cache = k_cache_bhsd.transpose(1, 2)  # BHSD -> BSHD
        v_cache = v_cache_bhsd.transpose(1, 2)
        seqused_k = torch.tensor(seqlens_k, dtype=torch.int32, device=device)
    else:
        seqused_k = None

    if aux_type == "kv":
        bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias, cu_seqlens_k)
    elif aux_type == "q":
        bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias, cu_seqlens_q)
    elif aux_type == "q_and_kv":
        q_bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1
        kv_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1
        aux_tensors = [q_bias, kv_bias]
        eager_score_mod = eager_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqlens_k)
    elif aux_type == "q_concat":
        bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1
        aux_tensors = [bias]
        eager_score_mod = eager_factory(bias, cu_seqlens_q)
    elif aux_type == "kv_with_cu":
        kv_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1
        aux_tensors = [kv_bias]
        eager_score_mod = eager_factory(kv_bias, cu_seqlens_q, cu_seqlens_k)
    elif aux_type == "multi_buffer":
        batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.1
        head_scale = torch.randn(num_heads, device=device, dtype=dtype) * 0.1 + 1.0
        q_pos_bias = torch.randn(total_q, device=device, dtype=dtype) * 0.1
        kv_pos_bias = torch.randn(total_k, device=device, dtype=dtype) * 0.1
        rel_pos_scale = (
            torch.randn(max_rel_pos * 2 + 1, device=device, dtype=dtype) * 0.1
        )
        aux_tensors = [batch_bias, head_scale, q_pos_bias, kv_pos_bias, rel_pos_scale]
        eager_score_mod = eager_factory(
            batch_bias,
            head_scale,
            q_pos_bias,
            kv_pos_bias,
            rel_pos_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            max_rel_pos,
        )
    else:
        raise ValueError(f"Unknown aux_type: {aux_type}")

    q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k = prepare_ref_tensors(
        q,
        k_cache,
        v_cache,
        cu_seqlens_q,
        cu_seqlens_k,
        True,
        varlen_k,
        batch_size,
        seqlens_q,
    )

    out_ref_fp32 = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=torch.float32
    )
    out_pt = run_flex_varlen_ref(
        q_ref, k_ref, v_ref, ref_cu_q, ref_cu_k, eager_score_mod, dtype=dtype
    )

    # Run CuTE
    k_input = k_cache_paged if page_size is not None else k_cache
    v_input = v_cache_paged if page_size is not None else v_cache
    out_cute = torch.empty_like(q)

    _flash_attn_fwd(
        q,
        k_input,
        v_input,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k_for_kernel if page_size is None else None,
        seqused_k=seqused_k if page_size is not None else None,
        page_table=page_table,
        return_lse=True,
        score_mod=cute_score_mod,
        out=out_cute,
        lse=None,
        aux_tensors=aux_tensors,
        pack_gqa=pack_gqa,
    )

    assert out_cute.shape == out_ref_fp32.shape, (
        f"Shape mismatch: {out_cute.shape} vs {out_ref_fp32.shape}"
    )

    test_name = f"{cute_score_mod.__name__} (paged={page_size is not None}, {aux_type})"
    check_results(
        out_cute,
        out_ref_fp32,
        out_pt,
        test_name,
        extra_atol=1e-3,
        seqlens_q=seqlens_q,
        cu_seqlens_q=cu_seqlens_q,
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

================================================
FILE: tests/cute/test_utils.py
================================================
"""Unit tests for flash_attn.cute.utils module."""

import functools

from flash_attn.cute import utils as cute_utils
from flash_attn.cute.utils import hash_callable


class TestHashCallable:
    """Tests for hash_callable function."""

    def test_returns_cute_hash_when_set_on_function(self):
        """hash_callable should return __cute_hash__ immediately when set on function."""

        def my_func():
            pass

        my_func.__cute_hash__ = "precomputed-hash-123"
        result = hash_callable(my_func)
        assert result == "precomputed-hash-123"

    def test_returns_cute_hash_from_wrapped_function(self):
        """hash_callable should check __wrapped__ for __cute_hash__."""

        def inner_func():
            pass

        inner_func.__cute_hash__ = "inner-hash-456"

        # Simulate a decorator that sets __wrapped__
        @functools.wraps(inner_func)
        def wrapper_func():
            return inner_func()

        result = hash_callable(wrapper_func)
        assert result == "inner-hash-456"

    def test_prefers_wrapper_cute_hash_over_wrapped(self):
        """When both wrapper and wrapped have __cute_hash__, prefer wrapper."""

        def inner_func():
            pass

        inner_func.__cute_hash__ = "inner-hash"

        @functools.wraps(inner_func)
        def wrapper_func():
            return inner_func()

        wrapper_func.__cute_hash__ = "wrapper-hash"
        result = hash_callable(wrapper_func)
        assert result == "wrapper-hash"

    def test_fallback_to_source_hashing(self):
        """hash_callable should fall back to source hashing when no __cute_hash__."""

        def my_func():
            return 42

        result = hash_callable(my_func)
        # Should return a hex string (SHA256 hash)
        assert isinstance(result, str)
        assert len(result) == 64  # SHA256 produces 64 hex chars

    def test_same_function_produces_same_hash(self):
        """Same function should produce consistent hash."""

        def my_func():
            return 42

        hash1 = hash_callable(my_func)
        hash2 = hash_callable(my_func)
        assert hash1 == hash2

    def test_different_functions_produce_different_hashes(self):
        """Different functions should produce different hashes."""

        def func_a():
            return 1

        def func_b():
            return 2

        hash_a = hash_callable(func_a)
        hash_b = hash_callable(func_b)
        assert hash_a != hash_b

    def test_fast_path_skips_expensive_hashing(self):
        """When __cute_hash__ is set, expensive operations should be skipped."""

        def my_func():
            pass

        my_func.__cute_hash__ = "fast-hash"

        # Mock at module level since we loaded it directly
        original_getsource = cute_utils.inspect.getsource
        call_tracker = {"getsource": 0, "sha256": 0}

        def tracking_getsource(*args, **kwargs):
            call_tracker["getsource"] += 1
            return original_getsource(*args, **kwargs)

        original_sha256 = cute_utils.hashlib.sha256

        def tracking_sha256(*args, **kwargs):
            call_tracker["sha256"] += 1
            return original_sha256(*args, **kwargs)

        cute_utils.inspect.getsource = tracking_getsource
        cute_utils.hashlib.sha256 = tracking_sha256

        try:
            result = hash_callable(my_func)
        finally:
            cute_utils.inspect.getsource = original_getsource
            cute_utils.hashlib.sha256 = original_sha256

        # Neither inspect.getsource nor hashlib.sha256 should be called
        assert call_tracker["getsource"] == 0, "getsource should not be called"
        assert call_tracker["sha256"] == 0, "sha256 should not be called"
        assert result == "fast-hash"

    def test_fast_path_on_wrapped_skips_expensive_hashing(self):
        """When __cute_hash__ is on __wrapped__, expensive operations should be skipped."""

        def inner_func():
            pass

        inner_func.__cute_hash__ = "wrapped-fast-hash"

        @functools.wraps(inner_func)
        def wrapper_func():
            return inner_func()

        # Mock at module level
        original_getsource = cute_utils.inspect.getsource
        call_tracker = {"getsource": 0, "sha256": 0}

        def tracking_getsource(*args, **kwargs):
            call_tracker["getsource"] += 1
            return original_getsource(*args, **kwargs)

        original_sha256 = cute_utils.hashlib.sha256

        def tracking_sha256(*args, **kwargs):
            call_tracker["sha256"] += 1
            return original_sha256(*args, **kwargs)

        cute_utils.inspect.getsource = tracking_getsource
        cute_utils.hashlib.sha256 = tracking_sha256

        try:
            result = hash_callable(wrapper_func)
        finally:
            cute_utils.inspect.getsource = original_getsource
            cute_utils.hashlib.sha256 = original_sha256

        assert call_tracker["getsource"] == 0, "getsource should not be called"
        assert call_tracker["sha256"] == 0, "sha256 should not be called"
        assert result == "wrapped-fast-hash"

    def test_closure_values_affect_hash(self):
        """Functions with different closure values should have different hashes."""
        value1 = 10
        value2 = 20

        def make_func(val):
            def inner():
                return val

            return inner

        func1 = make_func(value1)
        func2 = make_func(value2)

        hash1 = hash_callable(func1)
        hash2 = hash_callable(func2)
        assert hash1 != hash2


class TestHashCallableIntegration:
    """Integration tests for hash_callable with flash attention."""

    def test_repeated_calls_use_cached_hash(self):
        """Repeated calls with same score_mod should use cached/fast hash path."""

        def score_mod(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, aux_tensors):
            return tSrS_ssa

        # Set __cute_hash__ to simulate Inductor-generated code
        score_mod.__cute_hash__ = "inductor-generated-hash"

        original_getsource = cute_utils.inspect.getsource
        call_count = [0]  # Use list for mutable counter in nested function

        def counting_getsource(*args, **kwargs):
            call_count[0] += 1
            return original_getsource(*args, **kwargs)

        cute_utils.inspect.getsource = counting_getsource

        try:
            # Call hash_callable multiple times
            hash1 = hash_callable(score_mod)
            hash2 = hash_callable(score_mod)
            hash3 = hash_callable(score_mod)
        finally:
            cute_utils.inspect.getsource = original_getsource

        # getsource should never be called because __cute_hash__ is set
        assert call_count[0] == 0, f"getsource was called {call_count[0]} times"
        assert hash1 == hash2 == hash3 == "inductor-generated-hash"

================================================
FILE: tests/layers/test_rotary.py
================================================
# Copyright (c) 2023, Tri Dao.
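# The two tests in this file compare flash_attn's RotaryEmbedding against two reference layouts:
# NeoX style (non-interleaved), where feature i is rotated together with feature
# i + rotary_dim // 2, and GPT-J style (interleaved, `interleaved=True`), where feature 2i is
# rotated together with feature 2i + 1. The pure-Python sketch below illustrates both layouts on
# a single head vector. It is an illustrative assumption, not flash_attn's implementation: the
# helper names are hypothetical, the frequency base of 10000 is the conventional default, and
# only the first `rotary_dim` features are rotated while the rest pass through unchanged, which
# is what the `rotary_dim:` equality checks in the tests expect.
import math


def _rotary_angles_sketch(pos, rotary_dim, base=10000.0):
    # Hypothetical helper: one angle per rotated feature pair, pos * base^(-2i / rotary_dim)
    return [pos * base ** (-2 * i / rotary_dim) for i in range(rotary_dim // 2)]


def rotary_reference_sketch(x, pos, rotary_dim, interleaved=False, seqlen_offset=0):
    """Hypothetical reference: rotate one head vector `x` (a list of floats) at position `pos`.

    `seqlen_offset` shifts the position, mirroring the `seqlen_offset` the tests pass.
    """
    half = rotary_dim // 2
    angles = _rotary_angles_sketch(pos + seqlen_offset, rotary_dim)
    out = list(x)  # features at index >= rotary_dim are copied through unchanged
    for i, theta in enumerate(angles):
        c, s = math.cos(theta), math.sin(theta)
        # Select the pair of features that share angle i in the chosen layout
        a, b = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[b] * c + x[a] * s
    return out


# The two layouts differ only in how features are paired: running the non-interleaved version
# on a de-interleaved copy of x and re-interleaving the result reproduces the interleaved output.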
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_func, apply_rotary_emb_qkv_
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import (
    apply_rotary_pos_emb as apply_rotary_pos_emb_neox,
)
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding


# NeoX-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
    device = "cuda"
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen_total = 2048
    seqlen = seqlen_total - seqlen_offset
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
    )
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, device=device)
    rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
    # Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor
    cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)
    cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)
    q_pt = (
        rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
        .detach()
        .clone()
        .requires_grad_(True)
    )
    k_pt = (
        rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
        .detach()
        .clone()
        .requires_grad_(True)
    )
    q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)
    assert torch.allclose(
        rotary._cos_cached, cos_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rotary._sin_cached, sin_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rearrange(q_neox, "b h s d -> b s h d"), out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol
    )
    assert torch.allclose(
        rearrange(k_neox, "b h s d -> b s h d"), out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol
    )
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])

    g = torch.randn_like(out)
    g_og = g.clone().detach()  # Our implementation modifies g inplace
    out.backward(g)
    q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], "b s h d -> b h s d"))
    k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], "b s h d -> b h s d"))
    assert torch.allclose(
        rearrange(q_pt.grad, "b h s d -> b s h d"),
        qkv.grad[:, :, 0, :, :rotary_dim],
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        rearrange(k_pt.grad, "b h s d -> b s h d"),
        qkv.grad[:, :, 1, :, :rotary_dim],
        rtol=rtol,
        atol=atol,
    )
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])


# GPT-J-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
    device = "cuda"
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen_total = 2048
    seqlen = seqlen_total - seqlen_offset
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
    )
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
    sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
    sincos_gptj = tuple(x.to(dtype=dtype) for x in sincos_gptj)
    q_pt = qkv[:, :, 0, :, :rotary_dim].detach().clone().requires_grad_(True)
    k_pt = qkv[:, :, 1, :, :rotary_dim].detach().clone().requires_grad_(True)
    q_gptj = apply_rotary_pos_emb_gptj(q_pt, sincos_gptj, offset=seqlen_offset)
    k_gptj = apply_rotary_pos_emb_gptj(k_pt, sincos_gptj, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)
    assert torch.allclose(rotary._cos_cached, sincos_gptj[1], rtol=rtol, atol=atol)
    assert torch.allclose(rotary._sin_cached, sincos_gptj[0], rtol=rtol, atol=atol)
    assert torch.allclose(q_gptj, out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.allclose(k_gptj, out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])

    g = torch.randn_like(out)
    g_og = g.clone().detach()  # Our implementation modifies g inplace
    out.backward(g)
    q_gptj.backward(g_og[:, :, 0, :, :rotary_dim])
    k_gptj.backward(g_og[:, :, 1, :, :rotary_dim])
    assert torch.allclose(q_pt.grad, qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.allclose(k_pt.grad, qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])

================================================
FILE: tests/losses/test_cross_entropy.py
================================================
# Copyright (c) 2024, Tri Dao.
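# The test in this file checks flash_attn's CrossEntropyLoss against a reference assembled from
# PyTorch ops: cross-entropy on logits multiplied by `logit_scale` (optionally label-smoothed),
# rows whose target is -100 ignored, plus a z-loss term
# `lse_square_scale * mean(logsumexp(scaled_logits) ** 2)` over the non-ignored rows.
# The pure-Python sketch below spells out that reference for small inputs. It is an
# illustrative assumption, not the flash_attn kernel: the helper name is hypothetical, it
# assumes at least one non-ignored row, and it follows torch.nn.CrossEntropyLoss's label
# smoothing, (1 - eps) * NLL(target) + eps * mean over classes of NLL(class).
import math


def cross_entropy_reference_sketch(
    logits,
    targets,
    label_smoothing=0.0,
    logit_scale=1.0,
    lse_square_scale=0.0,
    ignore_index=-100,
):
    """Hypothetical reference: mean loss over the rows of `logits` (lists of floats)."""
    losses, lse_squares = [], []
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored rows contribute neither to the loss nor to the mean
        z = [logit_scale * v for v in row]
        m = max(z)
        lse = m + math.log(sum(math.exp(v - m) for v in z))  # numerically stable log-sum-exp
        nll_target = lse - z[target]
        nll_uniform = lse - sum(z) / len(z)  # mean over classes of (lse - z_c)
        losses.append((1.0 - label_smoothing) * nll_target + label_smoothing * nll_uniform)
        lse_squares.append(lse * lse)
    ce = sum(losses) / len(losses)
    z_loss = lse_square_scale * sum(lse_squares) / len(lse_squares)
    return ce + z_loss


# With uniform logits every class has probability 1 / C, so both the plain and the smoothed
# loss reduce to log(C); the z-loss then adds lse_square_scale * logsumexp(row) ** 2.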
import pytest
import torch
import torch.nn.functional as F
from flash_attn.losses.cross_entropy import CrossEntropyLoss

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("precompute_lse", [False, True])
# @pytest.mark.parametrize("precompute_lse", [False])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
@pytest.mark.parametrize("return_z_loss", [False, True])
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
# @pytest.mark.parametrize("logit_scale", [1.0])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50257, 128256])  # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [12])
def test_cross_entropy_loss(
    vocab_size,
    smoothing,
    logit_scale,
    lse_square_scale,
    return_z_loss,
    inplace_backward,
    precompute_lse,
    dtype,
):
    if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):
        pytest.skip("precompute_lse only works with logit_scale=1.0 and smoothing=0.0")
    device = "cuda"
    rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1 if dtype == torch.float32 else 4  # Otherwise OOM
    seqlen = 4096 if lse_square_scale == 0.0 and logit_scale == 1.0 else 1024  # Otherwise OOM
    x_pt = torch.randn(
        batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True
    )
    x = x_pt.detach().clone().requires_grad_()
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    if batch_size * seqlen > 10:
        y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
    model = CrossEntropyLoss(
        label_smoothing=smoothing,
        logit_scale=logit_scale,
        lse_square_scale=lse_square_scale,
        return_z_loss=return_z_loss,
        inplace_backward=inplace_backward,
    )
    if precompute_lse:
        with torch.no_grad():
            lse = torch.logsumexp(x.float(), dim=-1)
    else:
        lse = None
    if return_z_loss:
        out, out_z_loss = model(x, y, precomputed_lse=lse)
    else:
        out = model(x, y, precomputed_lse=lse)
    x_pt_scaled = (x_pt.float() * logit_scale) if logit_scale != 1.0 else x_pt.float()
    out_pt = model_pt(x_pt_scaled, y)
    if lse_square_scale > 0.0:
        lse_pt = torch.logsumexp(x_pt_scaled, dim=-1)
        z_loss_pt = lse_square_scale * (lse_pt[y != -100] ** 2).mean()
        if return_z_loss:
            assert torch.allclose(out_z_loss, z_loss_pt, rtol=rtol, atol=atol)
        out_pt += z_loss_pt
    assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)

    g = torch.randn_like(out)
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)

================================================
FILE: tests/losses/test_cross_entropy_parallel.py
================================================
# Run test with:
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/losses/test_cross_entropy_parallel.py
import math

import pytest
import torch
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.losses.cross_entropy import CrossEntropyLoss

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize(
    "dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("precompute_lse", [False, True])
# @pytest.mark.parametrize("precompute_lse", [False])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
# @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
@pytest.mark.parametrize("lse_square_scale", [1e-2])
@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
# @pytest.mark.parametrize("logit_scale", [1.0])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50264, 256 * 1024])  # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [50264])  # test vocab larger than 64k for split
# @pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("world_size", [2])
def test_cross_entropy_loss_parallel(
    vocab_size,
    world_size,
    smoothing,
    logit_scale,
    lse_square_scale,
    inplace_backward,
    precompute_lse,
    dtype,
):
    if precompute_lse and (logit_scale != 1.0 or smoothing != 0.0):
        pytest.skip("precompute_lse only works with logit_scale=1.0 and smoothing=0.0")
    assert vocab_size % world_size == 0
    rtol, atol = (
        (1e-5, 2e-5)
        if dtype == torch.float32
        else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
    )
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    partition_vocab_size = vocab_size // world_size
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 128
    x_pt = (
        torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
    ).requires_grad_()
    x = (
        tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
        .detach()
        .clone()
        .requires_grad_()
    )
    y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
    y[torch.randperm(batch_size * seqlen)[:10]] = -100
    model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
    model = CrossEntropyLoss(
        label_smoothing=smoothing,
        logit_scale=logit_scale,
        reduction="none",
        lse_square_scale=lse_square_scale,
inplace_backward=inplace_backward, process_group=parallel_state.get_tensor_model_parallel_group(), ) if precompute_lse: with torch.no_grad(): lse = torch.logsumexp(x.float(), dim=-1) else: lse = None out = model(x, y, precomputed_lse=lse) out_pt = model_pt(x_pt.float() * logit_scale, y) if lse_square_scale > 0.0: lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1) out_pt += lse_square_scale * lse_pt.square() out_pt.masked_fill_(y == -100, 0.0) assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6) g = torch.randn_like(out) out_pt.backward(g) out.backward(g) assert torch.allclose( x.grad, x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size], rtol=rtol, atol=atol, ) parallel_state.destroy_model_parallel() ================================================ FILE: tests/models/test_baichuan.py ================================================ # Copyright (c) 2023, Tri Dao. import os import time from pathlib import Path import torch import pytest from einops import rearrange from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM from flash_attn.models.gpt import ( GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp, ) from flash_attn.models.baichuan import ( remap_state_dict_hf_baichuan, baichuan_config_to_gpt2_config, ) from flash_attn.utils.distributed import all_gather_raw from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.generation import update_graph_cache @pytest.mark.parametrize( "model_name", [ "baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base", "baichuan-inc/Baichuan2-7B-Base", "baichuan-inc/Baichuan2-13B-Base", ], ) def test_baichuan_state_dict(model_name): config = baichuan_config_to_gpt2_config( AutoConfig.from_pretrained(model_name, trust_remote_code=True) ) pretrained_state_dict = remap_state_dict_hf_baichuan( state_dict_from_pretrained(model_name), config ) model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow 
    state_dict = model.state_dict()
    assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize(
    "model_name",
    [
        "baichuan-inc/Baichuan-7B",
        "baichuan-inc/Baichuan-13B-Base",
        "baichuan-inc/Baichuan2-7B-Base",
        "baichuan-inc/Baichuan2-13B-Base",
    ],
)
def test_baichuan_optimized(model_name):
    """Check that our implementation of Baichuan (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = baichuan_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    pretrained_state_dict = remap_state_dict_hf_baichuan(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map={"": device},
        trust_remote_code=True,
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.model(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize(
    "model_name",
    [
        "baichuan-inc/Baichuan-7B",
        "baichuan-inc/Baichuan-13B-Base",
        "baichuan-inc/Baichuan2-7B-Base",
        "baichuan-inc/Baichuan2-13B-Base",
    ],
)
def test_baichuan_parallel_forward(model_name, world_size):
    """Check that our implementation of Baichuan (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = baichuan_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    pretrained_state_dict = remap_state_dict_hf_baichuan(
        state_dict_from_pretrained(model_name), config
    )

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.no_grad():
            out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        with torch.no_grad():
            out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
            logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (
            logits_hf - logits_ref
        ).abs().max().item()


@pytest.mark.parametrize(
    "model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]
)
def test_baichuan_generation(model_name):
    dtype = torch.float16
    device = "cuda"
    config = baichuan_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 2048
    max_length = 2048 + 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids,
        max_length=max_length,
        return_dict_in_generate=True,
        output_scores=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
    del model_ref

    pretrained_state_dict = remap_state_dict_hf_baichuan(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    model(input_ids)  # Warm up
    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")

    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "baichuan_parallel_generation"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
def test_baichuan_parallel_generation(model_name, world_size):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = baichuan_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

    pretrained_state_dict = remap_state_dict_hf_baichuan(
        state_dict_from_pretrained(model_name), config
    )

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    print("Without CUDA graph")
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        cg=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        print("HF fp16")
        torch.cuda.synchronize()
        start = time.time()
        with torch.inference_mode():
            out_hf = model_hf.generate(
                input_ids=input_ids,
                max_length=max_length,
                return_dict_in_generate=True,
                output_scores=True,
            )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        del model_hf

        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.inference_mode():
            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
        del model_ref
        logits_hf = torch.stack(out_hf.scores, dim=1)

        logits = torch.stack(out.scores, dim=1)
        logits_cg = torch.stack(out_cg.scores, dim=1)

        hf_error = (logits_hf - logits_ref).abs().max().item()
        print(f"HF fp16 logits max diff: {hf_error}")
        print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
        print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
        assert torch.equal(logits_cg, logits)


================================================
FILE: tests/models/test_bert.py
================================================
import re
from collections import OrderedDict

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF

from flash_attn.models.bert import (
    BertForPreTraining,
    BertModel,
    inv_remap_state_dict,
    remap_state_dict,
)
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
    config = BertConfig.from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict(state_dict_from_pretrained(model_name), config)
    model = BertForPreTraining(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


def get_hf_models(model_name, config, dtype):
    pretrained_state_dict = state_dict_from_pretrained(model_name)

    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
        return key

    pretrained_state_dict = OrderedDict(
        (key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()
    )
    model_hf = BertForPreTrainingHF(config)
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
    model_hf.load_state_dict(pretrained_state_dict, strict=False)
    model_hf.cuda().to(dtype=dtype)
    return model_hf


@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output

    print(f"Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}")
    print(f"Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}")
    assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation of fused_mlp assumes the activation is
    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
    # If you just want "gelu", disable fused_mlp.
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.bert(input_ids, attention_mask=attention_mask)
    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    # Need to zero out the padded tokens in the sequence before comparison.
    sequence_output_hf[~attention_mask, :] = 0.0
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    sequence_output_ref[~attention_mask, :] = 0.0

    print(
        f"BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}"
    )
    print(
        f"HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}"
    )
    assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (
        sequence_output_hf - sequence_output_ref
    ).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
        pooled_output_hf - pooled_output_ref
    ).abs().max().item()

    out = model(input_ids, attention_mask=attention_mask)
    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
    # Need to zero out the padded tokens in the sequence before comparison.
prediction_scores = prediction_scores.clone() prediction_scores[~attention_mask, :] = 0.0 out_hf = model_hf(input_ids, attention_mask=attention_mask) prediction_scores_hf, seq_relationship_scores_hf = ( out_hf.prediction_logits, out_hf.seq_relationship_logits, ) prediction_scores_hf[~attention_mask, :] = 0.0 out_ref = model_ref(input_ids, attention_mask=attention_mask) prediction_scores_ref, seq_relationship_scores_ref = ( out_ref.prediction_logits, out_ref.seq_relationship_logits, ) prediction_scores_ref[~attention_mask, :] = 0.0 print( f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}" ) print( f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}" ) print( f"HF fp16 prediction_scoresff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}" ) print( f"HF fp16 prediction_scoresiff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}" ) assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * ( prediction_scores_hf - prediction_scores_ref ).abs().max().item() assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * ( seq_relationship_scores_hf - seq_relationship_scores_ref ).abs().max().item() @pytest.mark.parametrize("last_layer_subset", [False, True]) # @pytest.mark.parametrize('last_layer_subset', [True]) @pytest.mark.parametrize("has_key_padding_mask", [True, False]) # @pytest.mark.parametrize('has_key_padding_mask', [True]) @pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"]) # @pytest.mark.parametrize('model_name', ["bert-base-uncased"]) def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset): """Check that our implementation of BERT (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in 
fp32. """ dtype = torch.float16 config = BertConfig.from_pretrained(model_name) # Our implementation of fused_mlp assumes the activation is # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh". # If you just want "gelu", disable fused_mlp. config.hidden_act = "gelu_new" config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = True config.fused_dropout_add_ln = True config.dense_seq_output = True config.last_layer_subset = last_layer_subset config.use_xentropy = True model = BertForPreTraining.from_pretrained(model_name, config) model = model.cuda().to(dtype=dtype) model_ref = get_hf_models(model_name, config, torch.float32) model_hf = get_hf_models(model_name, config, dtype) model.eval() model_ref.eval() model_hf.eval() torch.manual_seed(0) batch_size = 4 max_seqlen = 512 seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda") if has_key_padding_mask: attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None] else: attention_mask = None input_ids = torch.randint( 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda" ) labels = torch.randint( 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda" ) if attention_mask is not None: labels[~attention_mask] = 0 labels[(torch.rand(batch_size, max_seqlen, device="cuda") > 0.15)] = 0 masked_tokens_mask = labels.flatten() > 0 next_sequence_label = torch.randint(0, 2, (batch_size,), device="cuda") out = model( input_ids, attention_mask=attention_mask, labels=labels, next_sentence_label=next_sequence_label, ) prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits out_hf = model_hf( input_ids, attention_mask=attention_mask, labels=labels, next_sentence_label=next_sequence_label, ) prediction_scores_hf, seq_relationship_scores_hf = ( out_hf.prediction_logits, out_hf.seq_relationship_logits, ) prediction_scores_hf = 
rearrange(prediction_scores_hf, "b s d -> (b s) d")[masked_tokens_mask] out_ref = model_ref( input_ids, attention_mask=attention_mask, labels=labels, next_sentence_label=next_sequence_label, ) prediction_scores_ref, seq_relationship_scores_ref = ( out_ref.prediction_logits, out_ref.seq_relationship_logits, ) prediction_scores_ref = rearrange(prediction_scores_ref, "b s d -> (b s) d")[masked_tokens_mask] print( f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}" ) print( f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}" ) print( f"HF fp16 prediction_scoresff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}" ) print( f"HF fp16 prediction_scoresiff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}" ) assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * ( prediction_scores_hf - prediction_scores_ref ).abs().max().item() assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * ( seq_relationship_scores_hf - seq_relationship_scores_ref ).abs().max().item() # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0. # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item() @pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"]) def test_inv_remap_state_dict(model_name: str): """ Verify that we can convert a HF BERT model to flash_attn and back. 
""" state_dict = state_dict_from_pretrained(model_name) config = BertConfig.from_pretrained(model_name) flash_state_dict = remap_state_dict(state_dict, config) recovered_state_dict = inv_remap_state_dict(flash_state_dict, config) assert set(state_dict.keys()) == set(recovered_state_dict.keys()) for k in state_dict.keys(): assert state_dict[k].shape == recovered_state_dict[k].shape torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6) ================================================ FILE: tests/models/test_bigcode.py ================================================ import time import pytest import torch from transformers import AutoTokenizer, GPTBigCodeConfig from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM from flash_attn.models.bigcode import bigcode_config_to_gpt2_config, inv_remap_state_dict_hf_bigcode from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_bigcode from flash_attn.utils.generation import update_graph_cache from flash_attn.utils.pretrained import state_dict_from_pretrained @pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"]) def test_bigcode_state_dict(model_name): config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name)) pretrained_state_dict = remap_state_dict_hf_bigcode( state_dict_from_pretrained(model_name), config ) model = GPTLMHeadModel(config, device="meta") state_dict = model.state_dict() assert state_dict.keys() == pretrained_state_dict.keys() for k in state_dict.keys(): assert state_dict[k].shape == pretrained_state_dict[k].shape @pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"]) def test_bigcode_optimized(model_name): """Check that our implementation of BigCode (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when 
compared to the HF forward pass in fp32. """ dtype = torch.float16 device = "cuda" config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name)) config.use_flash_attn = True # FlashAttention-2 supports headdim 256 config.fused_bias_fc = True config.fused_mlp = True config.fused_dropout_add_ln = True config.residual_in_fp32 = True model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) model.eval() torch.manual_seed(0) batch_size = 2 max_seqlen = 256 input_ids = torch.randint( 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device ) with torch.no_grad(): out = model.transformer(input_ids) logits = model(input_ids).logits del model # Without device_map, the model is loaded on the CPU, which is very slow model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device}) model_ref.eval() with torch.no_grad(): out_ref = model_ref.transformer(input_ids).last_hidden_state logits_ref = model_ref(input_ids).logits del model_ref model_hf = GPTBigCodeForCausalLM.from_pretrained( model_name, torch_dtype=dtype, device_map={"": device} ) model_hf.eval() out_hf = model_hf.transformer(input_ids).last_hidden_state logits_hf = model_hf(input_ids).logits del model_hf print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") assert (logits - logits_ref).abs().max().item() < 3 * ( logits_hf - logits_ref 
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_generation(model_name):
    """Check that our implementation of BigCode generation (with all optimizations enabled)
    matches the HF implementation: the scores in fp16 should be around the same as the HF scores
    in fp16, when compared to the HF scores in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = GPTBigCodeForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
    assert (logits_cg - logits_ref).abs().max().item() < 2 * hf_error


@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_inv_remap_state_dict(model_name: str):
    """
    Verify that we can convert a HF BigCode model to flash_attn and back.
    """
    state_dict = state_dict_from_pretrained(model_name)
    config = GPTBigCodeConfig.from_pretrained(model_name)

    flash_state_dict = remap_state_dict_hf_bigcode(state_dict, config)
    recovered_state_dict = inv_remap_state_dict_hf_bigcode(flash_state_dict, config)

    assert set(state_dict.keys()) == set(recovered_state_dict.keys())

    for k in state_dict.keys():
        assert state_dict[k].shape == recovered_state_dict[k].shape
        torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)


================================================
FILE: tests/models/test_btlm.py
================================================
# Copyright (c) 2023, Tri Dao.
import time

import torch
import pytest

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.btlm import btlm_config_to_gpt2_config, remap_state_dict_hf_btlm
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache


@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
def test_btlm_state_dict(model_name):
    config = btlm_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
def test_btlm_optimized(model_name):
    """Check that our implementation of Btlm (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared
    to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = btlm_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.fused_bias_fc = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map={"": device},
        trust_remote_code=True,
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.transformer(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
def test_btlm_generation(model_name):
    dtype = torch.float16
    device = "cuda"
    config = btlm_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.fused_bias_fc = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 2048
    max_length = 2048 + 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids,
        max_length=max_length,
        return_dict_in_generate=True,
        output_scores=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
    del model_ref

    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    model(input_ids)  # Warm up
    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")

    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    assert torch.equal(logits_cg, logits)


@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
def test_btlm_init(model_name):
    dtype = torch.float32
    device = "cuda"
    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config = btlm_config_to_gpt2_config(btlm_config)
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model_ref = AutoModelForCausalLM.from_config(btlm_config, trust_remote_code=True).to(device)

    assert model.transformer.embeddings.word_embeddings.weight.mean().abs() < 1e-4
    assert (
        model.transformer.embeddings.word_embeddings.weight.std()
        - model_ref.transformer.wte.weight.std()
    ).abs() < 1e-4
    assert model.lm_head.weight.mean().abs() < 1e-4
    assert (model.lm_head.weight.std() - model_ref.lm_head.weight.std()).abs() < 1e-4
    for l in range(config.n_layer):
        assert model.transformer.layers[l].mixer.Wqkv.weight.mean().abs() < 1e-4
        assert (
            model.transformer.layers[l].mixer.Wqkv.weight.std()
            - model_ref.transformer.h[l].attn.c_attn.weight.std()
        ).abs() < 1e-4
        assert model.transformer.layers[l].mixer.Wqkv.bias.abs().max() == 0.0
        assert model.transformer.layers[l].mixer.out_proj.weight.mean().abs() < 1e-4
        assert (
            model.transformer.layers[l].mixer.out_proj.weight.std()
            - model_ref.transformer.h[l].attn.c_proj.weight.std()
        ).abs() < 1e-4
        assert model.transformer.layers[l].mixer.out_proj.bias.abs().max() == 0.0
        assert model.transformer.layers[l].mlp.fc1.weight.mean().abs() < 1e-4
        assert (
            model.transformer.layers[l].mlp.fc1.weight.std()
            - model_ref.transformer.h[l].mlp.c_fc.weight.std()
        ).abs() < 1e-4
        assert model.transformer.layers[l].mlp.fc1.bias.abs().max() == 0.0
        assert model.transformer.layers[l].mlp.fc2.weight.mean().abs() < 1e-4
        assert (
            model.transformer.layers[l].mlp.fc2.weight.std()
            - model_ref.transformer.h[l].mlp.c_proj.weight.std()
        ).abs() < 1e-4
        assert model.transformer.layers[l].mlp.fc2.bias.abs().max() == 0.0


================================================
FILE: tests/models/test_falcon.py
================================================
# Copyright (c) 2023, Tri Dao.
import os
import time
from pathlib import Path

current_dir = Path(__file__).parent.absolute()

import pytest
import torch
from einops import rearrange
from flash_attn.models.falcon import falcon_config_to_gpt2_config, remap_state_dict_hf_falcon
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b", "tiiuae/falcon-40b"])
def test_falcon_state_dict(model_name):
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_optimized(model_name):
    """Check that our implementation (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map={"": device}, trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_forward"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_forward(model_name, world_size):
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        out_hf = model_hf.transformer(input_ids).last_hidden_state.to(device=device)
        logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        # Without device_map, the model is loaded on the CPU, which is very slow
        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.no_grad():
            out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (
            logits_hf - logits_ref
        ).abs().max().item()


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_generation(model_name):
    """Check that our implementation of generation (with all optimizations enabled) matches the
    HF implementation: the scores in fp16 should be around the same as the HF scores in fp16,
    when compared to the HF scores in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map={"": device}, trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_generation"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_generation(model_name, world_size):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    print("Without CUDA graph")
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        cg=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        print("HF fp16")
        torch.cuda.synchronize()
        start = time.time()
        with torch.inference_mode():
            out_hf = model_hf.generate(
                input_ids=input_ids,
                max_length=max_length,
                return_dict_in_generate=True,
                output_scores=True,
            )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        del model_hf

        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.inference_mode():
            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
        del model_ref
        logits_hf = torch.stack(out_hf.scores, dim=1)

        logits = torch.stack(out.scores, dim=1)
        logits_cg = torch.stack(out_cg.scores, dim=1)

        hf_error = (logits_hf - logits_ref).abs().max().item()
        print(f"HF fp16 logits max diff: {hf_error}")
        print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
        print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
        assert torch.equal(logits_cg, logits)


================================================
FILE: tests/models/test_gpt.py
================================================
import re

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import (
    GPTLMHeadModel,
    remap_state_dict_hf_gpt2,
    shard_state_dict_tp,
    combine_state_dicts_tp,
)
from flash_attn.utils.generation import InferenceParams
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name):
    config = GPT2Config.from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name):
    """Check that our implementation of GPT2 (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_optimized(model_name):
    """Check that our implementation of GPT2 (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the
    HF forward pass in fp32.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)
    vocab_size_og = config.vocab_size
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits[..., :vocab_size_og]
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
@pytest.mark.parametrize("rotary", [False, True]) # @pytest.mark.parametrize('rotary', [False]) @pytest.mark.parametrize("model_name", ["gpt2"]) def test_gpt2_generation(model_name, rotary, optimized): """Check that our implementation of GPT2 generation matches the HF implementation: the scores in fp16 should be around the same as the HF scores in fp16, when compared to the HF scores in fp32. """ dtype = torch.float16 device = "cuda" rtol, atol = 3e-3, 3e-1 config = GPT2Config.from_pretrained(model_name) if rotary: config.n_positions = 0 config.rotary_emb_fraction = 0.5 config.rotary_emb_base = 24000 config.residual_in_fp32 = True if optimized: config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = True config.fused_dropout_add_ln = True # If rotary, we load the weights from HF but ignore the position embeddings (strict=False). # The model would be nonsense but it doesn't matter for the test. model = GPTLMHeadModel.from_pretrained( model_name, config, strict=not rotary, device=device, dtype=dtype ) model.eval() if not rotary: model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device) model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to( device=device ) model_ref.eval() model_hf.eval() torch.manual_seed(0) tokenizer = GPT2Tokenizer.from_pretrained("gpt2") input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to( device=device ) max_length = 25 # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda') # max_length = input_ids.shape[1] + 40 # Slow generation for reference sequences = [] scores = [] cur_input_ids = input_ids with torch.inference_mode(): scores.append(model(cur_input_ids).logits[:, -1]) sequences.append(scores[-1].argmax(dim=-1)) for _ in range(input_ids.shape[1] + 1, max_length): cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1) scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1)) sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) scores = tuple(scores) out = model.generate( input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True, enable_timing=True, ) print(out.sequences) print(tokenizer.batch_decode(out.sequences.tolist())) if getattr(config, "use_flash_attn", False): out_cg = model.generate( input_ids=input_ids, max_length=max_length, cg=True, return_dict_in_generate=True, output_scores=True, enable_timing=True, ) print(out_cg.sequences) assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1)) if not rotary: out_hf = model_hf.generate( input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True, ) out_ref = model_ref.generate( input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True, ) print( f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}" ) print( f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}" ) print( f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}" ) print( f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}" ) print(tokenizer.batch_decode(out_ref.sequences.tolist())) assert torch.all(out.sequences == sequences) assert torch.allclose( torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol ) if not rotary: assert torch.all(out.sequences == out_ref.sequences) assert torch.all(out.sequences == out_hf.sequences) assert ( torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1) ).abs().max().item() < 3 * ( torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1) ).abs().max().item() def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs): out = model.generate( 
input_ids=input_ids, max_length=max_length, teacher_outputs=teacher_outputs, return_dict_in_generate=True, output_scores=True, enable_timing=True, **kwargs, ) return torch.stack(out.scores, dim=1) @pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)]) # @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)]) @pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"]) # @pytest.mark.parametrize('rotary', [None]) @pytest.mark.parametrize("model_name", ["gpt2"]) def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen): """Check that decoding with CUDA graph is the same as decoding without CUDA graph.""" dtype = torch.float16 device = "cuda" rtol, atol = 3e-3, 3e-1 config = GPT2Config.from_pretrained(model_name) config.n_positions = 16 * 1024 assert seqlen <= maxlen <= config.n_positions if rotary is not None: config.n_positions = 0 config.rotary_emb_dim = 32 config.rotary_emb_interleaved = rotary == "interleaved" config.residual_in_fp32 = True config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = True config.fused_dropout_add_ln = True model = GPTLMHeadModel(config, device=device, dtype=dtype) model.eval() torch.manual_seed(0) batch_size = 1 input_ids = torch.randint( 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device ) teacher_outputs = torch.randint( 0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device ) logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs) logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True) assert torch.equal(logits, logits_cg) # Try increasing batch size and max length, then decrease them to see if it's still correct batch_size = 3 maxlen += 30 input_ids = torch.randint( 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device ) teacher_outputs = torch.randint( 0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, 
device=device ) logits
= get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs) logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True) assert torch.equal(logits, logits_cg) batch_size = 2 maxlen -= 35 input_ids = torch.randint( 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device ) teacher_outputs = torch.randint( 0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device ) logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs) logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True) assert torch.equal(logits, logits_cg) @pytest.mark.parametrize("optimized", [False, True]) # @pytest.mark.parametrize("optimized", [False]) @pytest.mark.parametrize("model_name", ["gpt2"]) def test_gpt2_multiple_token_generation(model_name, optimized): """Generation when we pass in multiple tokens at a time, not just one.""" dtype = torch.float16 device = "cuda" rtol, atol = 3e-3, 3e-1 config = GPT2Config.from_pretrained(model_name) config.residual_in_fp32 = True if optimized: config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = True config.fused_dropout_add_ln = True model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) model.eval() torch.manual_seed(0) input_ids = torch.randint(0, config.vocab_size, (1, 20), dtype=torch.long, device=device) # Reference logits logits_ref = model(input_ids).logits # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits inference_params = InferenceParams(max_seqlen=20, max_batch_size=1) logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits inference_params.seqlen_offset += 10 position_ids = torch.arange(10, 14, dtype=torch.long, device=device) logits_1014 = model( input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params ).logits inference_params.seqlen_offset += 4 position_ids = 
torch.arange(14, 20, dtype=torch.long, device=device) logits_1420 = model( input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params ).logits logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1) print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize("cg", [False, True]) # @pytest.mark.parametrize("cg", [True]) @pytest.mark.parametrize("optimized", [False, True]) # @pytest.mark.parametrize("optimized", [True]) # @pytest.mark.parametrize("model_name", ["gpt2-medium"]) @pytest.mark.parametrize("model_name", ["gpt2-xl"]) def test_gpt2_speculative_decoding(model_name, optimized, cg): if cg and not optimized: pytest.skip() # CG requires use_flash_attn dtype = torch.float16 device = "cuda" rtol, atol = 3e-3, 3e-1 config = GPT2Config.from_pretrained(model_name) config.residual_in_fp32 = True if optimized: config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = True config.fused_dropout_add_ln = True config_draft = GPT2Config.from_pretrained("gpt2") config_draft.residual_in_fp32 = True if optimized: config_draft.use_flash_attn = True config_draft.fused_bias_fc = True config_draft.fused_mlp = True config_draft.fused_dropout_add_ln = True model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) model.eval() model_draft = GPTLMHeadModel.from_pretrained("gpt2", config_draft, device=device, dtype=dtype) model_draft.eval() torch.manual_seed(0) tokenizer = GPT2Tokenizer.from_pretrained("gpt2") input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to( device=device ) max_length = 100 from flash_attn.utils.generation import decode_speculative torch.manual_seed(42) print(f"Speculative decoding, {optimized = }") out = decode_speculative( input_ids, model, model_draft, max_length=max_length, 
top_k=5, cg=cg, speculative_lookahead=4, enable_timing=True, # debug=True, ) print(tokenizer.batch_decode(out.sequences)) print(f"Without speculative decoding, {cg = }") out_og = model.generate( input_ids, max_length=max_length, top_k=5, cg=cg, enable_timing=True, return_dict_in_generate=True, ) print(tokenizer.batch_decode(out_og.sequences)) @pytest.mark.parametrize( "n_heads_q_kv", [ (8, 8), # Regular attention (8, 4), # GQA (8, 2), # MQA ], ) def test_gpt2_shard_unshard(n_heads_q_kv): world_size = 2 config = GPT2Config.from_pretrained("gpt2") config.vocab_size = 1024 config.n_head, config.n_head_kv = n_heads_q_kv model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16) state_dict = model.state_dict() shards = [ # NOTE: Shallow copy as `state_dict` is modified in-place shard_state_dict_tp(dict(state_dict), config, world_size, rank) for rank in range(world_size) ] state_dict2 = combine_state_dicts_tp(shards, config) assert state_dict2.keys() == state_dict.keys() for k in state_dict.keys(): ref = state_dict[k] new = state_dict2[k] assert torch.allclose(ref, new, atol=0.0, rtol=0.0) ================================================ FILE: tests/models/test_gpt_generation_parallel.py ================================================ # Run test with: # torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation_parallel.py -k "parallel" import os import re import pytest import torch from einops import rearrange from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2 from flash_attn.utils.distributed import all_gather_raw from flash_attn.utils.pretrained import state_dict_from_pretrained from transformers import GPT2Config, GPT2Tokenizer from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF # @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) @pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("rotary", [False, True]) # 
@pytest.mark.parametrize("rotary", [False])
@pytest.mark.parametrize("model_name", ["gpt2"]) def test_tensor_parallel(model_name, rotary, world_size): """Check that our tensor-parallel implementation of GPT2 generation matches the HF implementation: the scores in fp16 should be around the same as the HF scores in fp16, when compared to the HF scores in fp32. """ dtype = torch.float16 rtol, atol = 3e-3, 3e-1 config = GPT2Config.from_pretrained(model_name) if rotary: config.n_positions = 0 config.rotary_emb_dim = 64 config.residual_in_fp32 = True config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = True config.fused_dropout_add_ln = True config.pad_vocab_size_multiple = 8 * world_size config.sequence_parallel = False # Need to set this to False for generation os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl", init_method="env://") device = f"cuda:{torch.distributed.get_rank()}" assert world_size <= torch.distributed.get_world_size() # Need this, otherwise when we capture the graph the process for GPU 1 would run on both # GPU0 and GPU1 and things would hang torch.cuda.set_device(device) from apex.transformer import parallel_state parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) rank = parallel_state.get_tensor_model_parallel_rank() process_group = parallel_state.get_tensor_model_parallel_group() # If rotary, we load the weights from HF but ignore the position embeddings (strict=False). # The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained( model_name, config, strict=not rotary, device=device, dtype=dtype, process_group=process_group, world_size=world_size, rank=rank, ) model.eval() if not rotary: model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device) model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype) model_ref.eval() model_hf.eval() torch.manual_seed(0) tokenizer = GPT2Tokenizer.from_pretrained("gpt2") input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.to( device=device ) max_length = 30 # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda') # max_length = input_ids.shape[1] + 40 # Slow generation for reference sequences = [] scores = [] cur_input_ids = input_ids with torch.inference_mode(): logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group) logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[ ..., : config.vocab_size ] scores.append(logits) sequences.append(scores[-1].argmax(dim=-1)) for _ in range(input_ids.shape[1] + 1, max_length): cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1) logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group) logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[ ..., : config.vocab_size ] scores.append(logits) sequences.append(scores[-1].argmax(dim=-1)) sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) scores = tuple(scores) print(sequences) out = model.generate( input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, vocab_size=config.vocab_size, return_dict_in_generate=True, output_scores=True, enable_timing=True, ) print(out.sequences) if getattr(config, "use_flash_attn", False): out_cg = model.generate( input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, vocab_size=config.vocab_size, cg=True, return_dict_in_generate=True, 
output_scores=True, enable_timing=True, ) print(out_cg.sequences) parallel_state.destroy_model_parallel() if not rotary: out_hf = model_hf.generate( input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True, ) out_ref = model_ref.generate( input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True, ) print( f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}" ) print( f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}" ) print( f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}" ) print( f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}" ) assert torch.all(out.sequences == sequences) assert torch.allclose( torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol ) assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1)) if not rotary: assert torch.all(out.sequences == out_ref.sequences) assert torch.all(out.sequences == out_hf.sequences) assert ( torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1) ).abs().max().item() < 3 * ( torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1) ).abs().max().item() ================================================ FILE: tests/models/test_gpt_neox.py ================================================ # Copyright (c) 2023, Tri Dao. 
import time import pytest import torch from flash_attn.models.gpt import GPTLMHeadModel from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox from flash_attn.utils.pretrained import state_dict_from_pretrained from transformers import AutoTokenizer, GPTNeoXConfig from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"]) def test_gpt_neox_state_dict(model_name): config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name)) pretrained_state_dict = remap_state_dict_hf_gpt_neox( state_dict_from_pretrained(model_name), config ) model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow state_dict = model.state_dict() assert state_dict.keys() == pretrained_state_dict.keys() for k in state_dict.keys(): assert state_dict[k].shape == pretrained_state_dict[k].shape @pytest.mark.parametrize( "model_name", [ "EleutherAI/pythia-1b", "EleutherAI/pythia-2.8b", "EleutherAI/gpt-neox-20b", "togethercomputer/RedPajama-INCITE-7B-Base", ], ) def test_gpt_neox_optimized(model_name): """Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32.
""" dtype = torch.float16 device = "cuda" config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name)) config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = config.activation_function in [ "gelu_fast", "gelu_new", "gelu_approx", "gelu_pytorch_tanh", ] config.fused_dropout_add_ln = True config.residual_in_fp32 = True model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) model.eval() torch.manual_seed(0) batch_size = 2 max_seqlen = 256 seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device) input_ids = torch.randint( 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device ) with torch.no_grad(): out = model.transformer(input_ids) logits = model(input_ids).logits del model # Need at least 2 GPUs, otherwise we'll OOM for the 20B model # Without device_map, the model is loaded on the CPU, which is very slow model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto") model_ref.eval() with torch.no_grad(): out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device) logits_ref = model_ref(input_ids).logits.to(device=device) del model_ref model_hf = GPTNeoXForCausalLM.from_pretrained( model_name, torch_dtype=dtype, device_map={"": device} ) model_hf.eval() with torch.no_grad(): out_hf = model_hf.gpt_neox(input_ids).last_hidden_state logits_hf = model_hf(input_ids).logits del model_hf print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item() assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item() print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") print(f"Logits mean diff: 
{(logits - logits_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") assert (logits - logits_ref).abs().max().item() < 2 * ( logits_hf - logits_ref ).abs().max().item() assert (logits - logits_ref).abs().mean().item() < 2 * ( logits_hf - logits_ref ).abs().mean().item() ================================================ FILE: tests/models/test_gpt_parallel.py ================================================ # Run test with: # torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py import math import pytest import torch import torch.nn as nn import torch.nn.functional as F from apex.transformer import parallel_state from einops import rearrange from flash_attn.losses.cross_entropy import CrossEntropyLoss from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp from flash_attn.utils.distributed import allreduce_sequence_parallel_grad from transformers import GPT2Config is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8 @pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize("world_size", [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize("sequence_parallel", [True, False]) # @pytest.mark.parametrize('sequence_parallel', [False]) @pytest.mark.parametrize("has_pos_emb", [True, False]) # @pytest.mark.parametrize('has_pos_emb', [True]) @pytest.mark.parametrize("dim", [1024]) def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype): head_dim = 64 assert dim % head_dim == 0 num_heads = dim // head_dim assert num_heads % world_size == 0 vocab_size = 50264 assert vocab_size % world_size == 0 num_layers = 2 rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2) if not torch.distributed.is_initialized(): 
torch.distributed.init_process_group(backend="nccl", init_method="env://") device = f"cuda:{torch.distributed.get_rank()}" assert world_size <= torch.distributed.get_world_size() parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) rank = parallel_state.get_tensor_model_parallel_rank() process_group = parallel_state.get_tensor_model_parallel_group() # set seed torch.random.manual_seed(0) batch_size = 8 seqlen = 1024 assert (batch_size * seqlen) % world_size == 0 input_ids = torch.randint(0, vocab_size, (batch_size, seqlen + 1), device=device) # We need to generate g here so that all processes get the same gradient, # as rank 0 will have an extra bias that changes the RNG. g = torch.randn(batch_size * seqlen, device=device) config = GPT2Config( n_embd=dim, n_head=num_heads, n_layer=num_layers, n_positions=seqlen if has_pos_emb else 0, vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, scale_attn_by_inverse_layer_idx=True, use_flash_attn=True, fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True, residual_in_fp32=True, rotary_emb_fraction=0.0 if has_pos_emb else 0.5, pad_vocab_size_multiple=8 * world_size, sequence_parallel=sequence_parallel, ) config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size) model_pt = GPTLMHeadModel(config, device=device) def init_layer_norm(module): if isinstance(module, nn.LayerNorm): nn.init.normal_(module.weight) nn.init.normal_(module.bias) model_pt.apply(init_layer_norm) model = GPTLMHeadModel(config, process_group=process_group, device=device) total_nparams = sum(p.numel() for p in model_pt.parameters()) sharded_nparams = sum(p.numel() for p in model.parameters()) sharded_nparams_all = torch.empty(world_size, dtype=torch.long, device=device) torch.distributed.all_gather_into_tensor( sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group ) shared_nparams = sum( p.numel() for p in model.parameters() if getattr(p, 
"_shared_params", False) ) shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device) torch.distributed.all_gather_into_tensor( shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group ) assert torch.all(shared_nparams_all == shared_nparams) assert total_nparams == ( (sharded_nparams_all - shared_nparams_all).sum().item() + shared_nparams ) # vocab_size has been rounded up here partition_vocab_size = config.vocab_size // world_size partition_dim = dim // world_size partition_hidden_dim = 4 * dim // world_size with torch.no_grad(): model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank)) model.tie_weights() with torch.autocast(device_type="cuda", dtype=dtype): out = model(input_ids[:, :-1]).logits if not sequence_parallel: out = rearrange(out, "b s d -> (b s) d") out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, "b s d -> (b s) d") partition_batch_dim = batch_size * seqlen // world_size assert torch.allclose( out, out_pt[:, rank * partition_vocab_size : (rank + 1) * partition_vocab_size], rtol=rtol, atol=atol, ) loss_fn = CrossEntropyLoss(inplace_backward=True, reduction="none", process_group=process_group) loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction="none") loss = loss_fn(out, input_ids[:, 1:].flatten()) loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten()) assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol) loss_pt.backward(g) loss.backward(g) allreduce_sequence_parallel_grad(model, process_group) parallel_state.destroy_model_parallel() grad_dict = shard_state_dict_tp( {k: v.grad for k, v in model_pt.named_parameters()}, config, world_size, rank ) assert torch.allclose( model.transformer.embeddings.word_embeddings.weight.grad, grad_dict["transformer.embeddings.word_embeddings.weight"], rtol=rtol, atol=atol * 5, ) if has_pos_emb: assert torch.allclose( model.transformer.embeddings.position_embeddings.weight.grad, 
grad_dict["transformer.embeddings.position_embeddings.weight"], rtol=rtol, atol=atol, ) assert torch.allclose( model.transformer.ln_f.weight.grad, grad_dict["transformer.ln_f.weight"], rtol=rtol, atol=atol, ) assert torch.allclose( model.transformer.ln_f.bias.grad, grad_dict["transformer.ln_f.bias"], rtol=rtol, atol=atol ) for i in range(num_layers): assert torch.allclose( model.transformer.layers[i].mixer.Wqkv.weight.grad, grad_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"], rtol=rtol, atol=atol * 10, ) assert torch.allclose( model.transformer.layers[i].mixer.Wqkv.bias.grad, grad_dict[f"transformer.layers.{i}.mixer.Wqkv.bias"], rtol=rtol, atol=atol * 10, ) assert torch.allclose( model.transformer.layers[i].mixer.out_proj.weight.grad, grad_dict[f"transformer.layers.{i}.mixer.out_proj.weight"], rtol=rtol, atol=atol * 10, ) if rank == 0: assert torch.allclose( model.transformer.layers[i].mixer.out_proj.bias.grad, grad_dict[f"transformer.layers.{i}.mixer.out_proj.bias"], rtol=rtol, atol=atol * 5, ) assert torch.allclose( model.transformer.layers[i].mlp.fc1.weight.grad, grad_dict[f"transformer.layers.{i}.mlp.fc1.weight"], rtol=rtol, atol=atol * 10, ) assert torch.allclose( model.transformer.layers[i].mlp.fc1.bias.grad, grad_dict[f"transformer.layers.{i}.mlp.fc1.bias"], rtol=rtol, atol=atol * 10, ) assert torch.allclose( model.transformer.layers[i].mlp.fc2.weight.grad, grad_dict[f"transformer.layers.{i}.mlp.fc2.weight"], rtol=rtol, atol=atol * 10, ) if rank == 0: assert torch.allclose( model.transformer.layers[i].mlp.fc2.bias.grad, grad_dict[f"transformer.layers.{i}.mlp.fc2.bias"], rtol=rtol, atol=atol * 5, ) assert torch.allclose( model.transformer.layers[i].norm1.weight.grad, grad_dict[f"transformer.layers.{i}.norm1.weight"], rtol=rtol, atol=atol, ) assert torch.allclose( model.transformer.layers[i].norm1.bias.grad, grad_dict[f"transformer.layers.{i}.norm1.bias"], rtol=rtol, atol=atol, ) assert torch.allclose( model.transformer.layers[i].norm2.weight.grad, 
grad_dict[f"transformer.layers.{i}.norm2.weight"], rtol=rtol, atol=atol, ) assert torch.allclose( model.transformer.layers[i].norm2.bias.grad, grad_dict[f"transformer.layers.{i}.norm2.bias"], rtol=rtol, atol=atol, ) ================================================ FILE: tests/models/test_gptj.py ================================================ # Copyright (c) 2023, Tri Dao. import time import pytest import torch from flash_attn.models.gpt import GPTLMHeadModel from flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj from flash_attn.utils.generation import update_graph_cache from flash_attn.utils.pretrained import state_dict_from_pretrained from transformers import AutoTokenizer, GPTJConfig from transformers.models.gptj.modeling_gptj import GPTJForCausalLM @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"]) def test_gptj_state_dict(model_name): config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name)) pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config) model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow state_dict = model.state_dict() assert state_dict.keys() == pretrained_state_dict.keys() for k in state_dict.keys(): assert state_dict[k].shape == pretrained_state_dict[k].shape @pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B", "togethercomputer/GPT-JT-6B-v1"]) def test_gptj_optimized(model_name): """Check that our implementation of GPT-J (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32. 
""" dtype = torch.float16 device = "cuda" config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name)) config.use_flash_attn = True # FlashAttention-2 supports headdim 256 config.fused_bias_fc = True config.fused_mlp = True config.fused_dropout_add_ln = True config.residual_in_fp32 = True model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) model.eval() torch.manual_seed(0) batch_size = 2 max_seqlen = 256 input_ids = torch.randint( 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device ) with torch.no_grad(): out = model.transformer(input_ids) logits = model(input_ids).logits del model # Without device_map, the model is loaded on the CPU, which is very slow model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device}) model_ref.eval() with torch.no_grad(): out_ref = model_ref.transformer(input_ids).last_hidden_state logits_ref = model_ref(input_ids).logits del model_ref model_hf = GPTJForCausalLM.from_pretrained( model_name, torch_dtype=dtype, device_map={"": device} ) model_hf.eval() out_hf = model_hf.transformer(input_ids).last_hidden_state logits_hf = model_hf(input_ids).logits del model_hf print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") assert (logits - logits_ref).abs().max().item() < 3 * ( logits_hf - logits_ref ).abs().max().item() @pytest.mark.parametrize("model_name", 
["EleutherAI/gpt-j-6B"]) def test_gptj_generation(model_name): """Check that our implementation of GPT-J generation (with all optimizations enabled) matches the HF implementation: the scores in fp16 should be around the same as the HF scores in fp16, when compared to the HF scores in fp32. """ dtype = torch.float16 device = "cuda" config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name)) config.use_flash_attn = True # FlashAttention-2 supports headdim 256 config.fused_bias_fc = True config.fused_mlp = True config.fused_dropout_add_ln = True # Only prenorm supports residual_in_fp32 config.residual_in_fp32 = True tokenizer = AutoTokenizer.from_pretrained(model_name) eos_token_id = tokenizer.eos_token_id torch.manual_seed(0) batch_size = 1 seqlen = 100 max_length = 150 input_ids = torch.randint( 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device ) model_hf = GPTJForCausalLM.from_pretrained( model_name, torch_dtype=dtype, device_map={"": device} ) model_hf.eval() print("HF fp16") torch.cuda.synchronize() start = time.time() out_hf = model_hf.generate( input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True ) torch.cuda.synchronize() print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") del model_hf model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device}) model_ref.eval() with torch.no_grad(): logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1] del model_ref model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) model.eval() print("Without CUDA graph") torch.cuda.synchronize() start = time.time() out = model.generate( input_ids=input_ids, max_length=max_length, eos_token_id=eos_token_id, return_dict_in_generate=True, output_scores=True, enable_timing=True, teacher_outputs=out_hf.sequences, ) torch.cuda.synchronize() print(f"Prompt processing + decoding time:
{(time.time() - start) * 1000:.0f}ms") # Capture graph outside the timing loop batch_size, seqlen_og = input_ids.shape model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) print("With CUDA graph") torch.cuda.synchronize() start = time.time() out_cg = model.generate( input_ids=input_ids, max_length=max_length, cg=True, return_dict_in_generate=True, output_scores=True, enable_timing=True, teacher_outputs=out_hf.sequences, ) torch.cuda.synchronize() print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") with torch.no_grad(): logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1] logits_hf = torch.stack(out_hf.scores, dim=1) logits = torch.stack(out.scores, dim=1) logits_cg = torch.stack(out_cg.scores, dim=1) del model hf_error = (logits_hf - logits_ref).abs().max().item() assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error print(f"HF fp16 logits max diff: {hf_error}") print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }") assert (logits - logits_ref).abs().max().item() < 2 * hf_error print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }") assert torch.equal(logits_cg, logits) ================================================ FILE: tests/models/test_llama.py ================================================ # Copyright (c) 2023, Tri Dao. 
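# Illustrative sketch (hypothetical helper, not part of the original tests): every forward-pass
# test in these files applies the same acceptance rule inline. The HF model run in fp32 is the
# reference; our fp16 output passes if its max abs deviation from that reference is smaller
# than `multiple` times the max abs deviation of the HF fp16 output from the same reference.
# The tests use a multiple of 2 or 3, depending on the test.
import torch


def _within_hf_fp16_error(out, out_hf, out_ref, multiple=3.0):
    """Return True if max|out - out_ref| < multiple * max|out_hf - out_ref|."""
    ours_err = (out - out_ref).abs().max().item()  # our fp16 error vs. the fp32 reference
    hf_err = (out_hf - out_ref).abs().max().item()  # HF fp16 error vs. the same reference
    return ours_err < multiple * hf_err
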
# To run the huggingface implementation of LLaMa (1), we first need to convert the weights: # https://github.com/huggingface/transformers/pull/21955 # python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf # and repeat for 13B, 30B, 65B import os import time from pathlib import Path current_dir = Path(__file__).parent.absolute() import shutil import pytest import torch from einops import rearrange from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp from flash_attn.models.llama import ( config_from_checkpoint, inv_remap_state_dict_hf_llama, llama_config_to_gpt2_config, remap_state_dict_hf_llama, remap_state_dict_meta_llama, state_dicts_from_checkpoint, ) from flash_attn.utils.distributed import all_gather_raw from flash_attn.utils.generation import update_graph_cache from flash_attn.utils.pretrained import state_dict_from_pretrained from transformers import LlamaConfig, LlamaTokenizer from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers import AutoConfig def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format): if checkpoint_format == "meta": ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name) pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts] pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config) else: pretrained_state_dict = state_dict_from_pretrained( Path(checkpoint_path) / f"{model_name}-hf" ) pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config) return pretrained_state_dict @pytest.mark.parametrize("model_name", ["7B"]) def test_llama_state_dict(model_name): checkpoint_path = ( Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" ) config = 
llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name)) ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name) pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config) model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow state_dict = model.state_dict() assert state_dict.keys() == pretrained_state_dict.keys() for k in state_dict.keys(): assert state_dict[k].shape == pretrained_state_dict[k].shape # TinyLlama-1.1B is to test MQA @pytest.mark.parametrize( "model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"] ) def test_inv_remap_state_dict_hf_llama(model_name): config = llama_config_to_gpt2_config( AutoConfig.from_pretrained(model_name, trust_remote_code=True) ) state_dict = state_dict_from_pretrained(model_name) # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key} pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config) state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config) assert set(state_dict_recover.keys()) == set(state_dict.keys()) for key in state_dict_recover.keys(): torch.testing.assert_close(state_dict_recover[key], state_dict[key]) # TinyLlama-1.1B is to test MQA @pytest.mark.parametrize( "model_name", [ "7B", # Llama 1 "13B", # Llama 1 "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-7b-hf", "codellama/CodeLlama-13b-hf", "codellama/CodeLlama-34b-hf", "PY007/TinyLlama-1.1B-step-50K-105b", ], ) def test_llama_optimized(model_name): """Check that our implementation of LLaMa (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32. 
""" checkpoint_path = ( Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" ) dtype = torch.float16 device = "cuda" if "/" in model_name: # Download from HF config = llama_config_to_gpt2_config( AutoConfig.from_pretrained(model_name, trust_remote_code=True) ) else: config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta") config = llama_config_to_gpt2_config(config) config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = False # We don't have fused GatedMLP yet config.fused_dropout_add_ln = True config.residual_in_fp32 = True if "/" in model_name: # Download from HF pretrained_state_dict = remap_state_dict_hf_llama( state_dict_from_pretrained(model_name), config ) else: pretrained_state_dict = _pretrained_state_dict_from_checkpoint( checkpoint_path, model_name, config, checkpoint_format="meta" ) model = GPTLMHeadModel(config, device=device, dtype=dtype) model.load_state_dict(pretrained_state_dict) model.eval() torch.manual_seed(0) batch_size = 2 max_seqlen = 256 seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device) input_ids = torch.randint( 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device ) with torch.no_grad(): out = model.transformer(input_ids) logits = model(input_ids).logits del model # Without device_map, the model is loaded on the CPU, which is very slow # Need auto here since the 13B fp32 model doesn't fit in memory on an A100 40GB model_ref = LlamaForCausalLM.from_pretrained( model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", device_map="auto", ) model_ref.eval() with torch.no_grad(): out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device) logits_ref = model_ref(input_ids).logits.to(device=device) del model_ref model_hf = LlamaForCausalLM.from_pretrained( model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype,
device_map={"": device}, ) model_hf.eval() with torch.no_grad(): out_hf = model_hf.model(input_ids).last_hidden_state logits_hf = model_hf(input_ids).logits del model_hf print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") assert (logits - logits_ref).abs().max().item() < 3 * ( logits_hf - logits_ref ).abs().max().item() # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel" @pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize( "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"] ) def test_llama_parallel(model_name, world_size): """Check that our implementation of LLaMa (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32. 
""" from apex.transformer import parallel_state checkpoint_path = ( Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" ) dtype = torch.float16 if "/" in model_name: # Download from HF config = llama_config_to_gpt2_config( AutoConfig.from_pretrained(model_name, trust_remote_code=True) ) else: config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta") config = llama_config_to_gpt2_config(config) config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = False # We don't have fused GatedMLP yet config.fused_dropout_add_ln = True config.residual_in_fp32 = True if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl", init_method="env://") device = f"cuda:{torch.distributed.get_rank()}" assert world_size <= torch.distributed.get_world_size() parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) rank = parallel_state.get_tensor_model_parallel_rank() process_group = parallel_state.get_tensor_model_parallel_group() if "/" in model_name: # Download from HF pretrained_state_dict = remap_state_dict_hf_llama( state_dict_from_pretrained(model_name), config ) else: pretrained_state_dict = _pretrained_state_dict_from_checkpoint( checkpoint_path, model_name, config, checkpoint_format="meta" ) model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype) model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)) model.eval() torch.manual_seed(0) batch_size = 2 max_seqlen = 256 seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device) input_ids = torch.randint( 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device ) with torch.no_grad(): out = model.transformer(input_ids) out, _ = all_gather_raw(out, process_group=process_group) out = rearrange(out, "(b s) d -> b s d", b=batch_size) logits = model(input_ids).logits logits 
= rearrange(logits, "(b s) d -> b s d", b=batch_size) logits, _ = all_gather_raw(logits, process_group) logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size) del model if rank == 0: # Without device_map, the model is loaded on the CPU, which is very slow model_ref = LlamaForCausalLM.from_pretrained( model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", device_map="auto", ) model_ref.eval() with torch.no_grad(): out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device) logits_ref = model_ref(input_ids).logits.to(device=device) del model_ref model_hf = LlamaForCausalLM.from_pretrained( model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto", ) model_hf.eval() with torch.no_grad(): out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device) logits_hf = model_hf(input_ids).logits.to(device=device) del model_hf print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item() print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") assert (logits - logits_ref).abs().max().item() < 2 * ( logits_hf - logits_ref ).abs().max().item() # @pytest.mark.parametrize('model_name', ["7B", "13B"]) @pytest.mark.parametrize("model_name", ["7B"]) @pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) def test_llama_generation(model_name, checkpoint_format): checkpoint_path = ( Path(os.environ.get("CHECKPOINT_DIR", 
current_dir.parent.parent / "checkpoints")) / "llama" ) dtype = torch.float16 device = "cuda" config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format) config = llama_config_to_gpt2_config(config) config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = False # We don't have fused GatedMLP yet config.fused_dropout_add_ln = True config.residual_in_fp32 = True tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf") eos_token_id = tokenizer.eos_token_id torch.manual_seed(0) batch_size = 1 seqlen = 100 max_length = 150 input_ids = torch.randint( 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device ) model_hf = LlamaForCausalLM.from_pretrained( Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device} ) model_hf.eval() print("HF fp16") torch.cuda.synchronize() start = time.time() out_hf = model_hf.generate( input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True ) torch.cuda.synchronize() print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") del model_hf # Need auto here since the 13B fp32 model doesn't fit in memory on an A100 40GB model_ref = LlamaForCausalLM.from_pretrained( Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" ) model_ref.eval() with torch.no_grad(): logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device) del model_ref pretrained_state_dict = _pretrained_state_dict_from_checkpoint( checkpoint_path, model_name, config, checkpoint_format ) model = GPTLMHeadModel(config, device=device, dtype=dtype) model.load_state_dict(pretrained_state_dict) model.eval() print("Without CUDA graph") torch.cuda.synchronize() start = time.time() out = model.generate( input_ids=input_ids, max_length=max_length, eos_token_id=eos_token_id, return_dict_in_generate=True, output_scores=True, enable_timing=True, teacher_outputs=out_hf.sequences, )
torch.cuda.synchronize() print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") # Capture graph outside the timing loop batch_size, seqlen_og = input_ids.shape model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) print("With CUDA graph") torch.cuda.synchronize() start = time.time() out_cg = model.generate( input_ids=input_ids, max_length=max_length, cg=True, return_dict_in_generate=True, output_scores=True, enable_timing=True, teacher_outputs=out_hf.sequences, ) torch.cuda.synchronize() print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") with torch.no_grad(): logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1] logits_hf = torch.stack(out_hf.scores, dim=1) logits = torch.stack(out.scores, dim=1) logits_cg = torch.stack(out_cg.scores, dim=1) del model hf_error = (logits_hf - logits_ref).abs().max().item() print(f"HF fp16 logits max diff: {hf_error}") print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}") assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error assert (logits - logits_ref).abs().max().item() < 2 * hf_error assert torch.equal(logits_cg, logits) # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation" @pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize( "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"] ) def test_llama_parallel_generation(model_name, world_size): """Check that our implementation matches the HF implementation: the scores in fp16 should be around the same as the HF scores in fp16, when compared to the HF scores in fp32. 
""" from apex.transformer import parallel_state checkpoint_path = ( Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" ) dtype = torch.float16 if "/" in model_name: # Download from HF config = llama_config_to_gpt2_config( AutoConfig.from_pretrained(model_name, trust_remote_code=True) ) else: config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta") config = llama_config_to_gpt2_config(config) config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = False # We don't have fused GatedMLP yet config.fused_dropout_add_ln = True config.residual_in_fp32 = True config.pad_vocab_size_multiple = 8 * world_size config.sequence_parallel = False # Need to set this to False for generation os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl", init_method="env://") device = f"cuda:{torch.distributed.get_rank()}" assert world_size <= torch.distributed.get_world_size() parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) rank = parallel_state.get_tensor_model_parallel_rank() process_group = parallel_state.get_tensor_model_parallel_group() torch.manual_seed(0) batch_size = 1 seqlen = 100 max_length = 150 input_ids = torch.randint( 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device ) # Need this, otherwise when we capture the graph the process for GPU 1 would run on both # GPU0 and GPU1 and things would hang torch.cuda.set_device(device) if "/" in model_name: # Download from HF pretrained_state_dict = remap_state_dict_hf_llama( state_dict_from_pretrained(model_name), config ) else: pretrained_state_dict = _pretrained_state_dict_from_checkpoint( checkpoint_path, model_name, config, checkpoint_format="meta" ) model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype) model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, 
config, world_size, rank)) model.eval() print("Without CUDA graph") out = model.generate( input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, vocab_size=config.vocab_size, # teacher_outputs=out_hf.sequences, return_dict_in_generate=True, output_scores=True, enable_timing=True, ) # Capture graph outside the timing loop batch_size, seqlen_og = input_ids.shape model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) print("With CUDA graph") out_cg = model.generate( input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, vocab_size=config.vocab_size, cg=True, # teacher_outputs=out_hf.sequences, return_dict_in_generate=True, output_scores=True, enable_timing=True, ) del model parallel_state.destroy_model_parallel() if rank == 0: # Without device_map, the model is loaded on the CPU, which is very slow model_hf = LlamaForCausalLM.from_pretrained( model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto", ) model_hf.eval() print("HF fp16") torch.cuda.synchronize() start = time.time() with torch.inference_mode(): out_hf = model_hf.generate( input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True, ) torch.cuda.synchronize() print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") del model_hf model_ref = LlamaForCausalLM.from_pretrained( model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", device_map="auto", ) model_ref.eval() with torch.inference_mode(): logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1] del model_ref logits_hf = torch.stack(out_hf.scores, dim=1) logits = torch.stack(out.scores, dim=1) logits_cg = torch.stack(out_cg.scores, dim=1) hf_error = (logits_hf - logits_ref).abs().max().item() print(f"HF fp16 logits max diff: {hf_error}") print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") assert (logits - 
logits_ref).abs().max().item() < 2 * hf_error print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}") assert torch.equal(logits_cg, logits) @torch.no_grad() @pytest.mark.parametrize("world_size", [2]) def test_llama_parallel_uneven_num_heads(world_size): from apex.transformer import parallel_state checkpoint_path = ( Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" ) num_attention_heads = world_size + 1 model_name = f"teeny-{num_attention_heads}-heads" if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl", init_method="env://") device = f"cuda:{torch.distributed.get_rank()}" assert world_size <= torch.distributed.get_world_size() parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) rank = parallel_state.get_tensor_model_parallel_rank() process_group = parallel_state.get_tensor_model_parallel_group() dtype = torch.float16 llama_config = LlamaConfig( hidden_size=256 * num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256 intermediate_size=256 * num_attention_heads * 4, num_hidden_layers=4, num_attention_heads=num_attention_heads, initializer_range=0.5, # Set crazy init range so we don't have near zero weights implying a vacuous test. ) config = llama_config_to_gpt2_config(llama_config) config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = False # We don't have fused GatedMLP yet config.fused_dropout_add_ln = True config.residual_in_fp32 = True torch.manual_seed(0) batch_size = 2 max_seqlen = 256 seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device) input_ids = torch.randint( 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device ) # Create a shared test model. if rank == 0: LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf") torch.distributed.barrier() # Run the standard forward pass test. 
pretrained_state_dict = _pretrained_state_dict_from_checkpoint( checkpoint_path, model_name, config, checkpoint_format="hf" ) model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype) model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)) model.eval() # TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs. out = model.transformer(input_ids) out, _ = all_gather_raw(out, process_group=process_group) out = rearrange(out, "(b s) d -> b s d", b=batch_size) logits = model(input_ids).logits logits = rearrange(logits, "(b s) d -> b s d", b=batch_size) logits, _ = all_gather_raw(logits, process_group) logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size) if rank == 0: model_ref = LlamaForCausalLM.from_pretrained( Path(checkpoint_path) / f"{model_name}-hf", device_map={"": device} ) model_ref = model_ref.to(device=device) model_ref.eval() out_ref = model_ref.model(input_ids).last_hidden_state logits_ref = model_ref(input_ids).logits del model_ref model_hf = LlamaForCausalLM.from_pretrained( Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device} ) model_hf.eval() out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device) logits_hf = model_hf(input_ids).logits.to(device=device) del model_hf print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item() print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") 
assert (logits - logits_ref).abs().max().item() < 2 * ( logits_hf - logits_ref ).abs().max().item() if os.path.exists(checkpoint_path / f"{model_name}-hf"): shutil.rmtree(checkpoint_path / f"{model_name}-hf") ================================================ FILE: tests/models/test_opt.py ================================================ import re import time import pytest import torch from einops import rearrange from flash_attn.models.gpt import GPTLMHeadModel from flash_attn.models.opt import opt_config_to_gpt2_config, remap_state_dict_hf_opt from flash_attn.utils.generation import update_graph_cache from flash_attn.utils.pretrained import state_dict_from_pretrained from transformers import AutoTokenizer, OPTConfig from transformers.models.opt.modeling_opt import OPTForCausalLM @pytest.mark.parametrize( "model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"] ) # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"]) def test_opt_state_dict(model_name): config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name)) pretrained_state_dict = remap_state_dict_hf_opt(state_dict_from_pretrained(model_name), config) model = GPTLMHeadModel(config) state_dict = model.state_dict() assert state_dict.keys() == pretrained_state_dict.keys() for k in state_dict.keys(): assert state_dict[k].shape == pretrained_state_dict[k].shape @pytest.mark.parametrize( "model_name", ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"] ) # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"]) def test_opt_optimized(model_name): """Check that our implementation of OPT (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.pad_vocab_size_multiple = 8

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)

    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
    )
    if model_name != "facebook/opt-350m":  # The OPT-350m projects the embeddings to dimension 512
        out = model.transformer(input_ids)
        out_hf = model_hf.model(input_ids).last_hidden_state
        out_ref = model_ref.model(input_ids).last_hidden_state

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize(
    "model_name",
    [
        "facebook/opt-125m",
        "facebook/opt-350m",
        "facebook/opt-1.3b",
        "facebook/opt-2.7b",
        "facebook/opt-6.7b",
    ],
)
# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
def test_opt_generation(model_name):
    """Check that our implementation of OPT generation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    print(f"\nMODEL: {model_name}")
    verbose = False
    dtype = torch.float16
    device = "cuda"
    rtol, atol = 3e-3, 3e-1
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, "prenorm", True)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    # OPT tokenizer requires use_fast=False
    # https://huggingface.co/docs/transformers/model_doc/opt
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = tokenizer.eos_token_id

    input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
        device=device
    )
    max_length = 25
    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
    # max_length = input_ids.shape[1] + 40

    # Slow generation for reference
    sequences = []
    scores = []
    cur_input_ids = input_ids
    with torch.inference_mode():
        scores.append(model(cur_input_ids).logits[:, -1])
        sequences.append(scores[-1].argmax(dim=-1))
        for _ in range(input_ids.shape[1] + 1, max_length):
            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
            scores.append(model(cur_input_ids).logits[:, -1])
            sequences.append(scores[-1].argmax(dim=-1))
            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
                break
    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
    scores = tuple(scores)

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    if verbose:
        print(out.sequences)
        print(tokenizer.batch_decode(out.sequences.tolist()))
    if getattr(config, "use_flash_attn", False):
        # Capture graph outside the timing loop
        batch_size, seqlen_og = input_ids.shape
        model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
        print("With CUDA graph")
        torch.cuda.synchronize()
        start = time.time()
        out_cg = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            cg=True,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=True,
        )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        if verbose:
            print(out_cg.sequences)
            print(tokenizer.batch_decode(out_cg.sequences.tolist()))
    del model

    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_ref.eval()
    print("HF fp32")
    torch.cuda.synchronize()
    start = time.time()
    out_ref = model_ref.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_ref
    print(tokenizer.batch_decode(out_ref.sequences.tolist()))

    if verbose:
        print(
            f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )
        print(
            f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
        )
        print(
            f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
        )

    assert torch.all(out.sequences == sequences)
    assert torch.allclose(
        torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
    )
    assert torch.all(out.sequences == out_ref.sequences)
    assert torch.all(out.sequences == out_hf.sequences)

    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
    ).abs().max().item()

================================================
FILE: tests/models/test_vit.py
================================================
import re

import pytest
import torch
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
from timm.models.vision_transformer import vit_base_patch16_224


@pytest.mark.parametrize("fused_mlp", [False, True])
# @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_mlp):
    """Check that our implementation of ViT matches timm's implementation:
    the output of our forward pass in fp16 should be around the same as
    timm's forward pass in fp16, when compared to timm's forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"

    kwargs = {}
    if optimized:
        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
    kwargs["fused_mlp"] = fused_mlp
    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)

    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)

    model.load_state_dict(model_ref.state_dict())

    model.eval()
    model_ref.eval()
    model_timm.eval()

    torch.manual_seed(0)
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
    out = model(x)
    out_timm = model_timm(x)
    out_ref = model_ref(x.float())

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}")
    print(f"timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}")
    rtol = 2 if not fused_mlp else 8
    assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()

================================================
FILE: tests/modules/test_block_parallel.py
================================================
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_block_parallel.py

import math
from functools import partial

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize("dim", [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
    head_dim = 64
    assert dim % head_dim == 0
    num_heads = dim // head_dim
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
    residual_pt = torch.randn(batch_size * seqlen, dim, device=device, requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't scale g down (here by 1/32), the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
        residual = (
            tensor_parallel.scatter_to_sequence_parallel_region(residual_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()
        residual = residual_pt.detach().clone().requires_grad_()

    mixer_cls_pt = partial(
        MHA,
        num_heads=num_heads,
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        device=device,
        dtype=dtype,
    )
    mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
    norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
    model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
    with torch.no_grad():
        nn.init.normal_(model_pt.norm1.weight)
        nn.init.normal_(model_pt.norm1.bias)
        nn.init.normal_(model_pt.norm2.weight)
        nn.init.normal_(model_pt.norm2.bias)

    mixer_cls = partial(
        ParallelMHA,
        num_heads=num_heads,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    mlp_cls = partial(
        ParallelFusedMLP,
        hidden_features=4 * dim,
        process_group=parallel_state.get_tensor_model_parallel_group(),
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    model = Block(
        dim,
        mixer_cls,
        mlp_cls,
        norm_cls,
        fused_dropout_add_ln=True,
        sequence_parallel=sequence_parallel,
        mark_shared_params=True,
    )

    partition_dim = dim // world_size
    partition_hidden_dim = 4 * dim // world_size
    with torch.no_grad():
        model.mixer.Wqkv.weight.copy_(
            rearrange(
                rearrange(model_pt.mixer.Wqkv.weight, "(three o) i -> three o i", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o i -> (three o) i",
            )
        )
        model.mixer.Wqkv.bias.copy_(
            rearrange(
                rearrange(model_pt.mixer.Wqkv.bias, "(three o) -> three o", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o -> (three o)",
            )
        )
        model.mixer.out_proj.weight.copy_(
            model_pt.mixer.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.mixer.out_proj.bias.copy_(model_pt.mixer.out_proj.bias)
        model.mlp.fc1.weight.copy_(
            model_pt.mlp.fc1.weight[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
        )
        model.mlp.fc1.bias.copy_(
            model_pt.mlp.fc1.bias[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim]
        )
        model.mlp.fc2.weight.copy_(
            model_pt.mlp.fc2.weight[
                :, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
            ]
        )
        if rank == 0:
            model.mlp.fc2.bias.copy_(model_pt.mlp.fc2.bias)
        model.norm1.weight.copy_(model_pt.norm1.weight)
        model.norm1.bias.copy_(model_pt.norm1.bias)
        model.norm2.weight.copy_(model_pt.norm2.weight)
        model.norm2.bias.copy_(model_pt.norm2.bias)

    mixer_kwargs = {"seqlen": seqlen}
    out, out_residual = model(x, residual, mixer_kwargs=mixer_kwargs)
    out_pt, out_residual_pt = model_pt(
        rearrange(x_pt, "(b s) d -> b s d", s=seqlen),
        rearrange(residual_pt, "(b s) d -> b s d", s=seqlen),
    )
    out_pt, out_residual_pt = [rearrange(x, "b s d -> (b s) d") for x in [out_pt, out_residual_pt]]
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        out_residual,
        out_residual_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_residual_pt,
        rtol=rtol,
        atol=atol,
    )

    (out_pt + 2 * out_residual_pt).backward(g)
    (out + 2 * out_residual).backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    allreduce_sequence_parallel_grad(model, parallel_state.get_tensor_model_parallel_group())
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol / 10,  # magnitude of x.grad is quite small
    )
    assert torch.allclose(
        residual.grad,
        residual_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else residual_pt.grad,
        rtol=rtol,
        atol=atol,
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.mixer.Wqkv.weight.grad,
        rearrange(
            rearrange(model_pt.mixer.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o i -> (three o) i",
        ),
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.mixer.Wqkv.bias.grad,
        rearrange(
            rearrange(model_pt.mixer.Wqkv.bias.grad, "(three o) -> three o", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o -> (three o)",
        ),
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.mixer.out_proj.weight.grad,
        model_pt.mixer.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(
            model.mixer.out_proj.bias.grad,
            model_pt.mixer.out_proj.bias.grad,
            rtol=rtol,
            atol=atol * 5,
        )
    assert torch.allclose(
        model.mlp.fc1.weight.grad,
        model_pt.mlp.fc1.weight.grad[
            rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.mlp.fc1.bias.grad,
        model_pt.mlp.fc1.bias.grad[rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim],
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.mlp.fc2.weight.grad,
        model_pt.mlp.fc2.weight.grad[
            :, rank * partition_hidden_dim : (rank + 1) * partition_hidden_dim
        ],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(
            model.mlp.fc2.bias.grad, model_pt.mlp.fc2.bias.grad, rtol=rtol, atol=atol * 5
        )

    assert torch.allclose(
        model.norm1.weight.grad, model_pt.norm1.weight.grad, rtol=rtol, atol=atol * 5
    )
    assert torch.allclose(model.norm1.bias.grad, model_pt.norm1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(
        model.norm2.weight.grad, model_pt.norm2.weight.grad, rtol=rtol, atol=atol * 5
    )
    assert torch.allclose(model.norm2.bias.grad, model_pt.norm2.bias.grad, rtol=rtol, atol=atol * 5)

================================================
FILE: tests/modules/test_embedding_parallel.py
================================================
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_embedding_parallel.py

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize("dim", [1024])
def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
    vocab_size = 50264
    seqlen = 2048
    assert vocab_size % world_size == 0
    assert dim % world_size == 0
    rtol, atol = (3e-3, 5e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    input_ids_pt = torch.randint(0, vocab_size, (batch_size, seqlen), device=device)
    input_ids = input_ids_pt.detach().clone()

    model_pt = GPT2Embeddings(
        dim, vocab_size, seqlen if has_pos_emb else 0, device=device, dtype=dtype
    )
    model = ParallelGPT2Embeddings(
        dim,
        vocab_size,
        seqlen if has_pos_emb else 0,
        parallel_state.get_tensor_model_parallel_group(),
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )
    partition_vocab_size = vocab_size // world_size
    partition_dim = dim // world_size
    with torch.no_grad():
        model.word_embeddings.weight.copy_(
            model_pt.word_embeddings.weight[
                rank * partition_vocab_size : (rank + 1) * partition_vocab_size
            ]
        )
        if has_pos_emb:
            model.position_embeddings.weight.copy_(
                model_pt.position_embeddings.weight[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ]
            )

    out = model(input_ids, combine_batch_seqlen_dim=True)
    out_pt = rearrange(model_pt(input_ids), "b s d -> (b s) d")
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )

    g = torch.randn_like(out_pt)
    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        model.word_embeddings.weight.grad,
        model_pt.word_embeddings.weight.grad[
            rank * partition_vocab_size : (rank + 1) * partition_vocab_size
        ],
        rtol=rtol,
        atol=atol,
    )
    if has_pos_emb:
        assert torch.allclose(
            model.position_embeddings.weight.grad,
            model_pt.position_embeddings.weight.grad[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            rtol=rtol,
            atol=atol,
        )

================================================
FILE: tests/modules/test_mha_parallel.py
================================================
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mha_parallel.py

import math

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.mha import MHA, ParallelMHA

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("head_dim", [64, 128])
# @pytest.mark.parametrize('head_dim', [64])
@pytest.mark.parametrize("embed_dim", [1024, 4096])
# @pytest.mark.parametrize('embed_dim', [1024])
def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size, dtype):
    assert embed_dim % head_dim == 0
    num_heads = embed_dim // head_dim
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(
        batch_size * seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True
    )
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't scale g down (here by 1/32), the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt = MHA(
        embed_dim,
        num_heads,
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        device=device,
        dtype=dtype,
    )
    partition_dim = embed_dim // world_size
    model = ParallelMHA(
        embed_dim,
        num_heads,
        parallel_state.get_tensor_model_parallel_group(),
        rotary_emb_dim=int(head_dim // 2),
        use_flash_attn=True,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        model.Wqkv.weight.copy_(
            rearrange(
                rearrange(model_pt.Wqkv.weight, "(three o) i -> three o i", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o i -> (three o) i",
            )
        )
        model.Wqkv.bias.copy_(
            rearrange(
                rearrange(model_pt.Wqkv.bias, "(three o) -> three o", three=3)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "three o -> (three o)",
            )
        )
        model.out_proj.weight.copy_(
            model_pt.out_proj.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.out_proj.bias.copy_(model_pt.out_proj.bias)

    out = model(x, seqlen=seqlen)
    out_pt = rearrange(model_pt(rearrange(x_pt, "(b s) d -> b s d", s=seqlen)), "b s d -> (b s) d")
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )

    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol / 100,  # magnitude of x.grad is quite small
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.Wqkv.weight.grad,
        rearrange(
            rearrange(model_pt.Wqkv.weight.grad, "(three o) i -> three o i", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o i -> (three o) i",
        ),
        rtol=rtol,
        atol=atol * 10,
    )
    assert torch.allclose(
        model.Wqkv.bias.grad,
        rearrange(
            rearrange(model_pt.Wqkv.bias.grad, "(three o) -> three o", three=3)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "three o -> (three o)",
        ),
        rtol=rtol,
        atol=atol * 5,
    )
    assert torch.allclose(
        model.out_proj.weight.grad,
        model_pt.out_proj.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol * 10,
    )
    if rank == 0:
        assert torch.allclose(
            model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5
        )

================================================
FILE: tests/modules/test_mlp_parallel.py
================================================
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mlp_parallel.py

import pytest
import torch
import torch.nn.functional as F
from apex.transformer import parallel_state, tensor_parallel
from einops import rearrange
from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp

is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("activation", [F.silu, F.sigmoid])
# @pytest.mark.parametrize('activation', [F.silu])
@pytest.mark.parametrize("dim", [1024, 4096])
# @pytest.mark.parametrize('dim', [1024])
def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dtype):
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, dim, device=device, dtype=dtype, requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't scale g down (here by 1/32), the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    if sequence_parallel:
        x = (
            tensor_parallel.scatter_to_sequence_parallel_region(x_pt)
            .detach()
            .clone()
            .requires_grad_()
        )
    else:
        x = x_pt.detach().clone().requires_grad_()

    model_pt = GatedMlp(dim, activation=activation, device=device, dtype=dtype)
    partition_dim = model_pt.fc1.weight.shape[0] // 2 // world_size
    model = ParallelGatedMlp(
        dim,
        parallel_state.get_tensor_model_parallel_group(),
        activation=activation,
        sequence_parallel=sequence_parallel,
        device=device,
        dtype=dtype,
    )

    with torch.no_grad():
        model.fc1.weight.copy_(
            rearrange(
                rearrange(model_pt.fc1.weight, "(two o) i -> two o i", two=2)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "two o i -> (two o) i",
            )
        )
        model.fc1.bias.copy_(
            rearrange(
                rearrange(model_pt.fc1.bias, "(two o) -> two o", two=2)[
                    :, rank * partition_dim : (rank + 1) * partition_dim
                ],
                "two o -> (two o)",
            )
        )
        model.fc2.weight.copy_(
            model_pt.fc2.weight[:, rank * partition_dim : (rank + 1) * partition_dim]
        )
        if rank == 0:
            model.fc2.bias.copy_(model_pt.fc2.bias)

    out = model(x)
    out_pt = model_pt(x_pt)
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out,
        out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else out_pt,
        rtol=rtol,
        atol=atol,
    )

    out_pt.backward(g)
    out.backward(
        g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g
    )
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim]
        if sequence_parallel
        else x_pt.grad,
        rtol=rtol,
        atol=atol,
    )

    assert torch.allclose(
        model.fc1.weight.grad,
        rearrange(
            rearrange(model_pt.fc1.weight.grad, "(two o) i -> two o i", two=2)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "two o i -> (two o) i",
        ),
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc1.bias.grad,
        rearrange(
            rearrange(model_pt.fc1.bias.grad, "(two o) -> two o", two=2)[
                :, rank * partition_dim : (rank + 1) * partition_dim
            ],
            "two o -> (two o)",
        ),
        rtol=rtol,
        atol=atol,
    )
    assert torch.allclose(
        model.fc2.weight.grad,
        model_pt.fc2.weight.grad[:, rank * partition_dim : (rank + 1) * partition_dim],
        rtol=rtol,
        atol=atol,
    )
    if rank == 0:
        assert torch.allclose(model.fc2.bias.grad, model_pt.fc2.bias.grad, rtol=rtol, atol=atol)

================================================
FILE: tests/ops/test_dropout_layer_norm.py
================================================
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn.ops.layer_norm import (
    DropoutAddLayerNorm,
    dropout_add_layer_norm,
    dropout_add_layer_norm_parallel_residual,
    dropout_add_layer_norm_subset,
)
from flash_attn.ops.rms_norm import (
    DropoutAddRMSNorm,
    dropout_add_rms_norm,
    dropout_add_rms_norm_parallel_residual,
    dropout_add_rms_norm_subset,
)

try:
    from apex.normalization import FusedRMSNorm
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
except:
    FusedRMSNorm, fused_rms_norm_affine = None, None


is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8


@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
# @pytest.mark.parametrize('has_colscale', [False])
@pytest.mark.parametrize("has_rowscale", [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_residual,
    has_rowscale,
    has_colscale,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
        x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = our_layer_norm_func(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        rowscale=rowscale,
        layerscale=colscale,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 3 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
            model_pt.bias.grad - model_ref.bias.grad
        ).abs().max() + 3e-5
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4


@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4


@pytest.mark.parametrize("is_rms_norm", [False, True])
@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_rowscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_prenorm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_residual,
    has_rowscale,
    has_colscale,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and FusedRMSNorm is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    layer_norm_cls = torch.nn.LayerNorm if not is_rms_norm else FusedRMSNorm
    our_layer_norm_cls = DropoutAddLayerNorm if not is_rms_norm else DropoutAddRMSNorm
    our_layer_norm_func = dropout_add_layer_norm if not is_rms_norm else dropout_add_rms_norm
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, "... -> ... 1")
        x0_scaled_ref = x0_ref * rearrange(rowscale, "... -> ... 1")
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    if has_colscale:
        x0_scaled_pt = x0_scaled_pt * colscale_pt
        x0_scaled_ref = x0_scaled_ref * colscale_ref
    model_pt = layer_norm_cls(hidden_size).to(device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    if not is_rms_norm:
        torch.nn.init.normal_(model_pt.bias)
    model_ref = layer_norm_cls(hidden_size).to(device=device, dtype=torch.float32)
    model = our_layer_norm_cls(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_ref.weight.copy_(model_pt.weight)
        if not is_rms_norm:
            model.bias.copy_(model_pt.bias)
            model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = our_layer_norm_func(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        rowscale=rowscale,
        layerscale=colscale,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt)).backward(g)
    (out * F.sigmoid(residual)).backward(g)
    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    if not is_rms_norm:
        assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
            model_pt.bias.grad - model_ref.bias.grad
        ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4


@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
@pytest.mark.parametrize("hidden_size", [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    res = res_pt.detach().clone().requires_grad_()
    res_ref = res_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out, residual = model(x0, res)
    residual_pt = (x0_pt.float() + res_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + res_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4


@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
            ~mask_batch_seqlen, 0
        )
        return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )

    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None

    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=False, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)

    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm_subset(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        layerscale=colscale,
        x0_subset=x0_subset,
        out_subset=out_subset,
        rowscale_const=drop_path_scale,
        out_numrows=out_numrows,
        prenorm=False,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")

    x0_scaled_pt = (
        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    x0_scaled_ref = (
        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
        x0_mask_batch
    ].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
        model_pt.bias.grad - model_ref.bias.grad
    ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4


@pytest.mark.parametrize("has_colscale", [True, False])
@pytest.mark.parametrize("has_residual", [True, False])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('has_colscale', [True])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_subset_prenorm_training(
    hidden_size, input_dtype, residual_dtype, weight_dtype, dropout_p, has_residual, has_colscale
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    drop_path_rate = 0.4
    drop_path_scale = 1 / (1 - drop_path_rate)

    def generate_droppath_masks(batch_size, seqlen, drop_path_rate, device):
        # Do it on CPU so we can get the numrows (with .item()) without GPU-CPU sync
        mask_batch = torch.rand(batch_size) < 1 - drop_path_rate
        numrows = (mask_batch).sum().item() * seqlen
        mask_batch = mask_batch.to(device=device, non_blocking=True)
        mask_batch_seqlen = repeat(mask_batch, "b -> (b s)", s=seqlen)
        subset = torch.cumsum(mask_batch_seqlen, dim=0, dtype=torch.int32).masked_fill_(
            ~mask_batch_seqlen, 0
        )
        return mask_batch, numrows, rearrange(subset, "(b s) -> b s", b=batch_size)

    x0_mask_batch, x0_numrows, x0_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )
    out_mask_batch, out_numrows, out_subset = generate_droppath_masks(
        batch_size, seqlen, drop_path_rate, device
    )

    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone()[x0_mask_batch].requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_colscale:
        colscale = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        colscale_pt = colscale.detach().clone().requires_grad_()
        colscale_ref = colscale.detach().clone().float().requires_grad_()
    else:
        colscale = None
    if has_residual:
        res_pt = torch.randn_like(x0_pt, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None

    if has_colscale:
        x0_scaled_pt = x0_pt * colscale_pt
        x0_scaled_ref = x0_ref * colscale_ref
    else:
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(
        hidden_size, prenorm=True, p=dropout_p, device=device, dtype=weight_dtype
    )
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)

    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm_subset(
        x0,
        res,
        model.weight,
        model.bias,
        model.p,
        model.eps,
        layerscale=colscale,
        x0_subset=x0_subset,
        out_subset=out_subset,
        rowscale_const=drop_path_scale,
        out_numrows=out_numrows,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    print(f"Actual dropout fraction: {1 - dmask.float().mean().item()}")

    x0_scaled_pt = (
        x0_scaled_pt.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    x0_scaled_ref = (
        x0_scaled_ref.masked_fill(repeat(~x0_mask_batch, "b -> b s d", s=seqlen, d=hidden_size), 0)
        * drop_path_scale
    )
    dmask_expanded = torch.zeros_like(x0_pt, dtype=torch.uint8)
    dmask_expanded[x0_mask_batch] = dmask
    if has_residual:
        residual_pt = (
            (x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p) + res_pt.float()
        ).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p) + res_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask_expanded.float()) / (1 - dropout_p)).to(
            dtype=residual_dtype
        )
        residual_ref = (x0_scaled_ref * dmask_expanded.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)[out_mask_batch]
    out_ref = model_ref(residual_ref)[out_mask_batch]
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt[out_mask_batch]) + residual_pt.mean(0, keepdim=True)).backward(
        g
    )
    (out * F.sigmoid(residual[out_mask_batch]) + residual.mean(0, keepdim=True)).backward(g)
    (
        out_ref * F.sigmoid(residual_ref[out_mask_batch].to(dtype=residual_dtype))
        + residual_ref.mean(0, keepdim=True)
    ).backward(g)
    assert (x0.grad - x0_ref.grad[x0_mask_batch]).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad)[
        x0_mask_batch
    ].abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (
        model_pt.weight.grad - model_ref.weight.grad
    ).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (
        model_pt.bias.grad - model_ref.bias.grad
    ).abs().max() + 2e-4
    if has_colscale:
        assert (colscale.grad - colscale_ref.grad).abs().max() <= 2 * (
            colscale_pt.grad - colscale_ref.grad
        ).abs().max() + 2e-4


@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_x1,
    has_residual,
    tied_norm,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (
        dropout_add_layer_norm_parallel_residual
        if not is_rms_norm
        else dropout_add_rms_norm_parallel_residual
    )
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_x1:
        x1_pt = torch.randn(
            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
        )
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (
        torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        if not is_rms_norm
        else None
    )
    weight0_pt = weight0.detach().clone().requires_grad_()
    weight0_ref = weight0.detach().clone().float().requires_grad_()
    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (
            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
            if not is_rms_norm
            else None
        )
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().float().requires_grad_()
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
    else:
        weight1, bias1 = None, None
    epsilon = 1e-5
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32

    out0, out1, dmask0, dmask1 = our_layer_norm_func(
        x0,
        x1,
        res,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
    if has_residual:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
                + res_pt.float()
            ).to(dtype=residual_dtype)
            residual_ref = (
                (x0_ref * dmask0.float()) / (1 - dropout_p)
                + (x1_ref * dmask1.float()) / (1 - dropout_p)
            ) + res_ref
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
    else:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
            ).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
                x1_ref * dmask1.float()
            ) / (1 - dropout_p)
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if not is_rms_norm:
        out0_pt = F.layer_norm(
            residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
        if not tied_norm:
            out1_pt = F.layer_norm(
                residual_pt.to(dtype=weight_dtype),
                (hidden_size,),
                weight1_pt,
                bias1_pt,
                eps=epsilon,
            ).to(dtype=input_dtype)
            out1_ref = F.layer_norm(
                residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon
            )
    else:
        out0_pt = fused_rms_norm_affine(
            residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
        if not tied_norm:
            out1_pt = fused_rms_norm_affine(
                residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
            ).to(dtype=input_dtype)
            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)

    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
    if not tied_norm:
        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4

    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        out0.backward(g0)
        out0_pt.backward(g0)
        out0_ref.backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        (out0 * g0 + out1 * g1).sum().backward()
        (out0_pt * g0 + out1_pt * g1).sum().backward()
        (out0_ref * g0 + out1_ref * g1).sum().backward()
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_x1:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (
            x1_pt.grad - x1_ref.grad
        ).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
        weight0_pt.grad - weight0_ref.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
            bias0_pt.grad - bias0_ref.grad
        ).abs().max() + 3e-5
    if not tied_norm:
        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
            weight1_pt.grad - weight1_ref.grad
        ).abs().max() + 3e-5
        if not is_rms_norm:
            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
                bias1_pt.grad - bias1_ref.grad
            ).abs().max() + 3e-5


@pytest.mark.parametrize("is_rms_norm", [False, True])
# @pytest.mark.parametrize('is_rms_norm', [False])
@pytest.mark.parametrize("tied_norm", [False, True])
# @pytest.mark.parametrize('tied_norm', [False])
@pytest.mark.parametrize("has_residual", [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize("has_x1", [True, False])
# @pytest.mark.parametrize('has_x1', [True])
@pytest.mark.parametrize("dropout_p", [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float16])
@pytest.mark.parametrize(
    "input_dtype,residual_dtype",
    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
)
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize(
    "hidden_size",
    [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144],
)
# @pytest.mark.parametrize('hidden_size', [256])
def test_dropout_layer_norm_parallel_residual_prenorm_training(
    hidden_size,
    input_dtype,
    residual_dtype,
    weight_dtype,
    dropout_p,
    has_x1,
    has_residual,
    tied_norm,
    is_rms_norm,
):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    if is_rms_norm and fused_rms_norm_affine is None:
        pytest.skip()  # We need Apex's FusedRMSNorm to test
    our_layer_norm_func = (
        dropout_add_layer_norm_parallel_residual
        if not is_rms_norm
        else dropout_add_rms_norm_parallel_residual
    )
    device = "cuda"
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
    )
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_x1:
        x1_pt = torch.randn(
            batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
        )
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_residual:
        res_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        res = res_pt.detach().clone().requires_grad_()
        res_ref = res_pt.detach().clone().float().requires_grad_()
    else:
        res = None
    weight0 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
    bias0 = (
        torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        if not is_rms_norm
        else None
    )
    weight0_pt = weight0.detach().clone().requires_grad_()
    weight0_ref = weight0.detach().clone().float().requires_grad_()
    bias0_pt = bias0.detach().clone().requires_grad_() if bias0 is not None else None
    bias0_ref = bias0.detach().clone().float().requires_grad_() if bias0 is not None else None
    if not tied_norm:
        weight1 = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
        bias1 = (
            torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
            if not is_rms_norm
            else None
        )
        weight1_pt = weight1.detach().clone().requires_grad_()
        weight1_ref = weight1.detach().clone().float().requires_grad_()
        bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
        bias1_ref = bias1.detach().clone().float().requires_grad_() if bias1 is not None else None
    else:
        weight1, bias1 = None, None
    epsilon = 1e-5
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32

    out0, out1, residual, dmask0, dmask1 = our_layer_norm_func(
        x0,
        x1,
        res,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        prenorm=True,
        residual_in_fp32=residual_in_fp32,
        return_dropout_mask=True,
    )
    assert out0.dtype == input_dtype
    if not tied_norm:
        assert out1.dtype == input_dtype
    print(f"Actual dropout fraction: {1 - dmask0.float().mean().item()}")
    if has_residual:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
                + res_pt.float()
            ).to(dtype=residual_dtype)
            residual_ref = (
                (x0_ref * dmask0.float()) / (1 - dropout_p)
                + (x1_ref * dmask1.float()) / (1 - dropout_p)
            ) + res_ref
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p) + res_pt.float()).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + res_ref
    else:
        if has_x1:
            residual_pt = (
                (x0_pt.float() * dmask0.float()) / (1 - dropout_p)
                + (x1_pt.float() * dmask1.float()) / (1 - dropout_p)
            ).to(dtype=residual_dtype)
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p) + (
                x1_ref * dmask1.float()
            ) / (1 - dropout_p)
        else:
            residual_pt = ((x0_pt.float() * dmask0.float()) / (1 - dropout_p)).to(
                dtype=residual_dtype
            )
            residual_ref = (x0_ref * dmask0.float()) / (1 - dropout_p)
    if not is_rms_norm:
        out0_pt = F.layer_norm(
            residual_pt.to(dtype=weight_dtype), (hidden_size,), weight0_pt, bias0_pt, eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = F.layer_norm(residual_ref, (hidden_size,), weight0_ref, bias0_ref, eps=epsilon)
        if not tied_norm:
            out1_pt = F.layer_norm(
                residual_pt.to(dtype=weight_dtype),
                (hidden_size,),
                weight1_pt,
                bias1_pt,
                eps=epsilon,
            ).to(dtype=input_dtype)
            out1_ref = F.layer_norm(
                residual_ref, (hidden_size,), weight1_ref, bias1_ref, eps=epsilon
            )
    else:
        out0_pt = fused_rms_norm_affine(
            residual_pt.to(dtype=weight_dtype), weight0_pt, (hidden_size,), eps=epsilon
        ).to(dtype=input_dtype)
        out0_ref = fused_rms_norm_affine(residual_ref, weight0_ref, (hidden_size,), eps=epsilon)
        if not tied_norm:
            out1_pt = fused_rms_norm_affine(
                residual_pt.to(dtype=weight_dtype), weight1_pt, (hidden_size,), eps=epsilon
            ).to(dtype=input_dtype)
            out1_ref = fused_rms_norm_affine(residual_ref, weight1_ref, (hidden_size,), eps=epsilon)

    assert (out0 - out0_ref).abs().max() <= 4 * (out0_pt - out0_ref).abs().max() + 1e-4
    if not tied_norm:
        assert (out1 - out1_ref).abs().max() <= 4 * (out1_pt - out1_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (
        residual_pt - residual_ref
    ).abs().max() + 1e-4

    g0 = torch.randn_like(out0) / batch_size
    if tied_norm:
        (out0 * F.sigmoid(residual)).backward(g0)
        (out0_pt * F.sigmoid(residual_pt)).backward(g0)
        (out0_ref * F.sigmoid(residual_ref)).backward(g0)
    else:
        g1 = torch.randn_like(out1) / batch_size
        (out0 * F.sigmoid(residual) * g0 + out1 * g1).sum().backward()
        (out0_pt * F.sigmoid(residual_pt) * g0 + out1_pt * g1).sum().backward()
        (out0_ref * F.sigmoid(residual_ref) * g0 + out1_ref * g1).sum().backward()
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_x1:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (
            x1_pt.grad - x1_ref.grad
        ).abs().max() + 1e-4
    if has_residual:
        assert (res.grad - res_ref.grad).abs().max() <= 4 * (
            res_pt.grad - res_ref.grad
        ).abs().max() + 1e-4
    assert (weight0.grad - weight0_ref.grad).abs().max() <= 3 * (
        weight0_pt.grad - weight0_ref.grad
    ).abs().max() + 3e-5
    if not is_rms_norm:
        assert (bias0.grad - bias0_ref.grad).abs().max() <= 2 * (
            bias0_pt.grad - bias0_ref.grad
        ).abs().max() + 3e-5
    if not tied_norm:
        assert (weight1.grad - weight1_ref.grad).abs().max() <= 3 * (
            weight1_pt.grad - weight1_ref.grad
        ).abs().max() + 3e-5
        if not is_rms_norm:
            assert (bias1.grad - bias1_ref.grad).abs().max() <= 2 * (
                bias1_pt.grad - bias1_ref.grad
            ).abs().max() + 3e-5


def test_dropout_layer_norm_randomness():
    hidden_size = 256
    dtype = torch.float32
    dropout_p = 0.1
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0 = torch.randn(
        batch_size, seqlen, hidden_size, device=device, dtype=dtype, requires_grad=True
    )
    res = torch.randn_like(x0, dtype=dtype, requires_grad=True)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=dtype)
    torch.random.manual_seed(42)
    _, dmask0 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    # Subsequent call should have a different dropout mask
    _, dmask1 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    torch.random.manual_seed(42)
    # Resetting the seed, should get the same dropout mask
    _, dmask2 = dropout_add_layer_norm(
        x0, res, model.weight, model.bias, model.p, model.eps, return_dropout_mask=True
    )
    assert not torch.equal(dmask0, dmask1)
    assert torch.equal(dmask0, dmask2)

================================================
FILE: tests/ops/test_fused_dense.py
================================================
import math
from functools import partial

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from flash_attn.ops.fused_dense import FusedDense, FusedMLP


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_residual", [False, True])
@pytest.mark.parametrize("has_bias", [True, False])
@pytest.mark.parametrize("out_features", [1024, 4096])
@pytest.mark.parametrize("in_features", [1024, 4096])
def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
    device = "cuda"
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(
        batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True
    )
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    model = FusedDense(
        in_features, out_features, bias=has_bias,
return_residual=return_residual, device=device, dtype=dtype, ) with torch.no_grad(): model.weight.copy_(model_pt.weight) if has_bias: model.bias.copy_(model_pt.bias) out_pt = model_pt(x_pt) if not return_residual: out = model(x) else: out, x_copy = model(x) x_copy = ( x_copy[..., :out_features] if out_features < in_features else F.pad(x_copy, (0, out_features - in_features)) ) x_pt_copy = ( x_pt[..., :out_features] if out_features < in_features else F.pad(x_pt, (0, out_features - in_features)) ) # Just add some random function of the residual out_pt = out_pt + F.gelu(x_pt_copy) out = out + F.gelu(x_copy) # with torch.no_grad(): # out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half() assert torch.allclose(out, out_pt, rtol=rtol, atol=atol) # If we don't divide by batch_size, the gradient gets a bit too large. g = torch.randn_like(out) / 32 out_pt.backward(g) out.backward(g) assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) # The error for d_weight and d_bias is quite a bit higher assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10) if has_bias: assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize("heuristic", ["auto", -1]) # @pytest.mark.parametrize('heuristic', ['auto']) @pytest.mark.parametrize("checkpoint_lvl", [0, 1, 2]) # @pytest.mark.parametrize('checkpoint_lvl', [1]) @pytest.mark.parametrize("return_residual", [False, True]) # @pytest.mark.parametrize('return_residual', [False]) @pytest.mark.parametrize("has_bias2", [True, False]) @pytest.mark.parametrize("has_bias1", [True, False]) # @pytest.mark.parametrize('has_bias2', [True]) # @pytest.mark.parametrize('has_bias1', [True]) @pytest.mark.parametrize("activation", ["gelu_approx", "relu"]) # @pytest.mark.parametrize('activation', ['relu']) 
@pytest.mark.parametrize("out_features", [1024, 4096]) @pytest.mark.parametrize("in_features", [1024, 4096]) # @pytest.mark.parametrize('out_features', [4096]) # @pytest.mark.parametrize('in_features', [1024]) def test_fused_mlp( in_features, out_features, activation, has_bias1, has_bias2, return_residual, checkpoint_lvl, heuristic, dtype, ): device = "cuda" rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3) # set seed torch.random.manual_seed(0) batch_size = 8 seqlen = 512 x_pt = torch.randn( batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True ) x = x_pt.detach().clone().requires_grad_() model_pt_fc1 = torch.nn.Linear( in_features, out_features, bias=has_bias1, device=device, dtype=dtype ) model_pt_fc2 = torch.nn.Linear( out_features, in_features, bias=has_bias2, device=device, dtype=dtype ) model = FusedMLP( in_features, out_features, in_features, activation=activation, bias1=has_bias1, bias2=has_bias2, return_residual=return_residual, checkpoint_lvl=checkpoint_lvl, heuristic=heuristic, device=device, dtype=dtype, ) with torch.no_grad(): model.fc1.weight.copy_(model_pt_fc1.weight) if has_bias1: model.fc1.bias.copy_(model_pt_fc1.bias) model.fc2.weight.copy_(model_pt_fc2.weight) if has_bias2: model.fc2.bias.copy_(model_pt_fc2.bias) activation_fn = ( partial(F.gelu, approximate="tanh") if activation == "gelu_approx" else partial(F.relu, inplace=True) ) out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt))) if not return_residual: out = model(x) else: out, x_copy = model(x) # Just add some random function of the residual out_pt = out_pt + F.gelu(x_pt) out = out + F.gelu(x_copy) assert torch.allclose(out, out_pt, rtol=rtol, atol=atol) # If we don't divide by batch_size, the gradient gets a bit too large. 
g = torch.randn_like(out) / 32 out_pt.backward(g) out.backward(g) # The error for relu is higher still if activation == "relu": atol = 1e-1 if dtype == torch.bfloat16 else 5e-2 assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) # The error for d_weight and d_bias is quite a bit higher assert torch.allclose( model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10 ) if has_bias1: assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5) assert torch.allclose( model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10 ) if has_bias2: assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5) ================================================ FILE: tests/ops/test_fused_dense_parallel.py ================================================ # Run test with: # torchrun --no_python --nproc_per_node=8 pytest -q -s tests/ops/test_fused_dense_parallel.py import math import pytest import torch import torch.nn.functional as F from apex.transformer import parallel_state, tensor_parallel from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, FusedMLP, ParallelFusedMLP is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8 @pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize("world_size", [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize("sequence_parallel", [True, False]) # @pytest.mark.parametrize('sequence_parallel', [False]) @pytest.mark.parametrize("has_bias", [True, False]) # @pytest.mark.parametrize('has_bias', [False]) @pytest.mark.parametrize("out_features", [1024]) @pytest.mark.parametrize("in_features", [4096]) def test_fused_linear_bias( in_features, out_features, has_bias, sequence_parallel, world_size, dtype ): assert out_features % world_size == 0 rtol, atol = (3e-3, 3e-2) if dtype == 
torch.bfloat16 else (3e-3, 3e-3) if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl", init_method="env://") device = f"cuda:{torch.distributed.get_rank()}" assert world_size <= torch.distributed.get_world_size() parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) rank = parallel_state.get_tensor_model_parallel_rank() # set seed torch.random.manual_seed(0) batch_size = 2 seqlen = 512 assert batch_size * seqlen % world_size == 0 x_pt = torch.randn( batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True ) if sequence_parallel: x = ( tensor_parallel.scatter_to_sequence_parallel_region(x_pt) .detach() .clone() .requires_grad_() ) else: x = x_pt.detach().clone().requires_grad_() model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype) partition_out_features = out_features // world_size model = ColumnParallelLinear( in_features, out_features, parallel_state.get_tensor_model_parallel_group(), bias=has_bias, sequence_parallel=sequence_parallel, device=device, dtype=dtype, ) with torch.no_grad(): model.weight.copy_( model_pt.weight[rank * partition_out_features : (rank + 1) * partition_out_features] ) if has_bias: model.bias.copy_( model_pt.bias[rank * partition_out_features : (rank + 1) * partition_out_features] ) out = model(x) out_pt = model_pt(x_pt) assert torch.allclose( out, out_pt[:, rank * partition_out_features : (rank + 1) * partition_out_features], rtol=rtol, atol=atol, ) # If we don't divide by batch_size, the gradient gets a bit too large. 
g = torch.randn_like(out_pt) / 32 out_pt.backward(g) out.backward(g[:, rank * partition_out_features : (rank + 1) * partition_out_features]) parallel_state.destroy_model_parallel() partition_batch_dim = batch_size * seqlen // world_size assert torch.allclose( x.grad, x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else x_pt.grad, rtol=rtol, atol=atol, ) # The error for d_weight and d_bias is quite a bit higher assert torch.allclose( model.weight.grad, model_pt.weight.grad[rank * partition_out_features : (rank + 1) * partition_out_features], rtol=rtol, atol=atol * 10, ) if has_bias: assert torch.allclose( model.bias.grad, model_pt.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features], rtol=rtol, atol=atol * 5, ) @pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize("world_size", [1, 2, 4, 8]) # @pytest.mark.parametrize('world_size', [2]) @pytest.mark.parametrize("sequence_parallel", [True, False]) # @pytest.mark.parametrize('sequence_parallel', [False]) @pytest.mark.parametrize("has_bias2", [True, False]) # @pytest.mark.parametrize('has_bias2', [True]) @pytest.mark.parametrize("out_features", [4096]) @pytest.mark.parametrize("in_features", [1024]) def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype): assert out_features % world_size == 0 rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) if not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl", init_method="env://") device = f"cuda:{torch.distributed.get_rank()}" assert world_size <= torch.distributed.get_world_size() parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) rank = parallel_state.get_tensor_model_parallel_rank() # set seed torch.random.manual_seed(0) batch_size = 2 seqlen = 512 assert batch_size 
* seqlen % world_size == 0 x_pt = torch.randn( batch_size * seqlen, in_features, device=device, dtype=dtype, requires_grad=True ) # We need to generate g here so that all processes get the same gradient, # as rank 0 will have an extra bias that changes the RNG. # If we don't divide by batch_size, the gradient gets a bit too large. g = torch.randn_like(x_pt) / 32 if sequence_parallel: x = ( tensor_parallel.scatter_to_sequence_parallel_region(x_pt) .detach() .clone() .requires_grad_() ) else: x = x_pt.detach().clone().requires_grad_() model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype) model_pt_fc2 = torch.nn.Linear( out_features, in_features, bias=has_bias2, device=device, dtype=dtype ) partition_out_features = out_features // world_size partition_in_features = in_features // world_size model = ParallelFusedMLP( in_features, out_features, in_features, process_group=parallel_state.get_tensor_model_parallel_group(), bias2=has_bias2 and rank == 0, sequence_parallel=sequence_parallel, device=device, dtype=dtype, ) with torch.no_grad(): model.fc1.weight.copy_( model_pt_fc1.weight[rank * partition_out_features : (rank + 1) * partition_out_features] ) model.fc1.bias.copy_( model_pt_fc1.bias[rank * partition_out_features : (rank + 1) * partition_out_features] ) model.fc2.weight.copy_( model_pt_fc2.weight[ :, rank * partition_out_features : (rank + 1) * partition_out_features ] ) if has_bias2 and rank == 0: model.fc2.bias.copy_(model_pt_fc2.bias) out = model(x) out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate="tanh")) partition_batch_dim = batch_size * seqlen // world_size assert torch.allclose( out, out_pt[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else out_pt, rtol=rtol, atol=atol, ) out_pt.backward(g) out.backward( g[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else g ) parallel_state.destroy_model_parallel() assert torch.allclose( x.grad, 
x_pt.grad[rank * partition_batch_dim : (rank + 1) * partition_batch_dim] if sequence_parallel else x_pt.grad, rtol=rtol, atol=atol, ) # The error for d_weight and d_bias is quite a bit higher assert torch.allclose( model.fc1.weight.grad, model_pt_fc1.weight.grad[ rank * partition_out_features : (rank + 1) * partition_out_features ], rtol=rtol, atol=atol * 10, ) assert torch.allclose( model.fc1.bias.grad, model_pt_fc1.bias.grad[rank * partition_out_features : (rank + 1) * partition_out_features], rtol=rtol, atol=atol * 5, ) assert torch.allclose( model.fc2.weight.grad, model_pt_fc2.weight.grad[ :, rank * partition_out_features : (rank + 1) * partition_out_features ], rtol=rtol, atol=atol * 10, ) if has_bias2 and rank == 0: assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5) ================================================ FILE: tests/ops/triton/test_layer_norm.py ================================================ # Copyright (c) 2024, Tri Dao. import pytest import torch import torch.nn.functional as F from einops import rearrange, repeat from flash_attn.ops.triton.layer_norm import ( layer_norm_fn, layer_norm_ref, rms_norm_ref, layer_norm_linear_fn, ) is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8 # @pytest.mark.parametrize("zero_centered_weight", [False, True]) @pytest.mark.parametrize("zero_centered_weight", [False]) @pytest.mark.parametrize("has_weight1", [False, True]) # @pytest.mark.parametrize("has_weight1", [False]) @pytest.mark.parametrize("has_x1", [False, True]) # @pytest.mark.parametrize("has_x1", [False]) @pytest.mark.parametrize("has_rowscale", [False, True]) # @pytest.mark.parametrize("has_rowscale", [False]) @pytest.mark.parametrize("dropout_p", [0.0, 0.27]) # @pytest.mark.parametrize("dropout_p", [0.0]) @pytest.mark.parametrize("prenorm", [True, False]) # @pytest.mark.parametrize("prenorm", [True]) @pytest.mark.parametrize("is_rms_norm", [False, True]) # @pytest.mark.parametrize("is_rms_norm", 
[True]) @pytest.mark.parametrize("has_residual", [True, False]) # @pytest.mark.parametrize("has_residual", [True]) @pytest.mark.parametrize( "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else []) ) # @pytest.mark.parametrize("weight_dtype", [torch.float32]) @pytest.mark.parametrize( "input_dtype,residual_dtype", [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)] + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []), ) # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)]) @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096]) # @pytest.mark.parametrize("hidden_size", [1024]) def test_layer_norm( hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm, dropout_p, has_rowscale, has_x1, has_weight1, zero_centered_weight, ): if has_rowscale and has_x1: pytest.skip("Not supported") device = "cuda" if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]): atol = 5e-2 elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]): atol = 1e-2 else: atol = 1e-4 # set seed torch.random.manual_seed(0) batch_size = 8 seqlen = 512 layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref allclose = ( # Sometimes x0_pt.grad is NaN lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max() <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol or ( # Sometimes x_pt and x_ref are the same (e.g. 
bfloat16) so we want to perturb it a bit # by multiplying and dividing by 0.3 (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0 and (x - x_ref).abs().max() <= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol ) ) x0 = torch.randn( batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True ) x0_pt = x0.detach().clone().requires_grad_() x0_ref = x0.detach().clone().requires_grad_() if has_residual: res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True) res_pt = res.detach().clone().requires_grad_() res_ref = res.detach().clone().requires_grad_() else: res, res_pt, res_ref = None, None, None weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) if not is_rms_norm: bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) else: bias = None weight_pt = weight.detach().clone().requires_grad_() weight_ref = weight.detach().clone().requires_grad_() bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None if has_x1: x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True) x1_pt = x1.detach().clone().requires_grad_() x1_ref = x1.detach().clone().requires_grad_() else: x1, x1_pt, x1_ref = None, None, None if has_weight1: weight1 = torch.randn( hidden_size, device=device, dtype=weight_dtype, requires_grad=True ) weight1_pt = weight1.detach().clone().requires_grad_() weight1_ref = weight1.detach().clone().requires_grad_() if not is_rms_norm: bias1 = torch.randn( hidden_size, device=device, dtype=weight_dtype, requires_grad=True ) else: bias1 = None bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None else: weight1, weight1_pt, weight1_ref = None, None, None bias1, bias1_pt, bias1_ref = None, None, None rowscale = (
torch.randn(batch_size, seqlen, dtype=input_dtype, device=device) if has_rowscale else None ) residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 out, *rest = layer_norm_fn( x0, weight, bias, residual=res, x1=x1, weight1=weight1, bias1=bias1, eps=1e-6, dropout_p=dropout_p, rowscale=rowscale, prenorm=prenorm, residual_in_fp32=residual_in_fp32, zero_centered_weight=zero_centered_weight, is_rms_norm=is_rms_norm, return_dropout_mask=True, ) dropout_mask = rest[-2] if dropout_p > 0.0 else None dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None out_pt = layer_norm_ref_fn( x0_pt, weight_pt, bias_pt, residual=res_pt, x1=x1_pt, weight1=weight1_pt, bias1=bias1_pt, eps=1e-6, dropout_p=dropout_p, rowscale=rowscale, prenorm=prenorm, zero_centered_weight=zero_centered_weight, dropout_mask=dropout_mask, dropout_mask1=dropout_mask1, ) out_ref = layer_norm_ref_fn( x0_ref, weight_ref, bias_ref, residual=res_ref, x1=x1_ref, weight1=weight1_ref, bias1=bias1_ref, eps=1e-6, dropout_p=dropout_p, rowscale=rowscale, prenorm=prenorm, zero_centered_weight=zero_centered_weight, dropout_mask=dropout_mask, dropout_mask1=dropout_mask1, upcast=True, ) if not has_weight1: if prenorm: residual = rest[0] out_pt, residual_pt = out_pt out_ref, residual_ref = out_ref out1, out1_pt, out1_ref = None, None, None else: out1 = rest.pop(0) if prenorm: residual = rest[0] out_pt, out1_pt, residual_pt = out_pt out_ref, out1_ref, residual_ref = out_ref else: out_pt, out1_pt = out_pt out_ref, out1_ref = out_ref assert out.dtype == input_dtype if prenorm: assert residual.dtype == residual_dtype assert allclose(residual, residual_pt, residual_ref) assert allclose(out, out_pt, out_ref) if out1 is not None: assert out1.dtype == input_dtype assert allclose(out1, out1_pt, out1_ref) if dropout_mask is not None: dropout_fraction = 1.0 - dropout_mask.float().mean() assert abs(dropout_fraction - dropout_p) < 0.01 if dropout_mask1 is not None: dropout_fraction = 1.0 - 
dropout_mask1.float().mean() assert abs(dropout_fraction - dropout_p) < 0.01 assert not torch.equal(dropout_mask, dropout_mask1) g = torch.randn_like(out) / batch_size if has_weight1: out = out * F.gelu(out1) out_pt = out_pt * F.gelu(out1_pt) out_ref = out_ref * F.gelu(out1_ref) if not prenorm: out.backward(g) out_pt.backward(g) out_ref.backward(g) else: (out * F.sigmoid(residual)).backward(g) (out_pt * F.sigmoid(residual_pt)).backward(g) (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g) assert allclose(x0.grad, x0_pt.grad, x0_ref.grad) if has_residual: assert allclose(res.grad, res_pt.grad, res_ref.grad) if has_x1: assert allclose(x1.grad, x1_pt.grad, x1_ref.grad) assert allclose(weight.grad, weight_pt.grad, weight_ref.grad) if bias is not None: assert allclose(bias.grad, bias_pt.grad, bias_ref.grad) if has_weight1: assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad) if bias1 is not None: assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad) @pytest.mark.parametrize("prenorm", [True, False]) # @pytest.mark.parametrize("prenorm", [True]) @pytest.mark.parametrize("is_rms_norm", [False, True]) # @pytest.mark.parametrize("is_rms_norm", [True]) @pytest.mark.parametrize("has_residual", [True, False]) # @pytest.mark.parametrize("has_residual", [False]) @pytest.mark.parametrize("weight_dtype", [torch.float32]) @pytest.mark.parametrize( "input_dtype,residual_dtype", [(torch.float16, torch.float16), (torch.float16, torch.float32)] + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []), ) # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)]) @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000]) # @pytest.mark.parametrize("hidden_size", [256]) def test_layer_norm_linear( hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm ): device = "cuda" if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, 
weight_dtype]): atol = 5e-2 elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]): atol = 1e-2 else: atol = 1e-4 # set seed torch.random.manual_seed(0) batch_size = 4 seqlen = 512 # batch_size = 1 # seqlen = 1 layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref allclose = ( lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max() <= 2 * (x_pt - x_ref).abs().max() + atol ) x0 = torch.randn( batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True ) x0_pt = x0.detach().clone().requires_grad_() x0_ref = x0.detach().clone().requires_grad_() if has_residual: res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True) res_pt = res.detach().clone().requires_grad_() res_ref = res.detach().clone().requires_grad_() else: res, res_pt, res_ref = None, None, None norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) if not is_rms_norm: norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True) else: norm_bias = None norm_weight_pt = norm_weight.detach().clone().requires_grad_() norm_weight_ref = norm_weight.detach().clone().requires_grad_() norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None linear_weight = torch.empty( 2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True ) torch.nn.init.xavier_uniform_(linear_weight) if not is_rms_norm: linear_bias = torch.randn( 2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True ) else: linear_bias = None linear_weight_pt = linear_weight.detach().clone().requires_grad_() linear_weight_ref = linear_weight.detach().clone().requires_grad_() linear_bias_pt = ( linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None ) linear_bias_ref = ( 
linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None ) residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 with torch.autocast(device_type="cuda", dtype=input_dtype): out, *rest = layer_norm_linear_fn( x0, norm_weight, norm_bias, linear_weight, linear_bias, residual=res, eps=1e-6, prenorm=prenorm, residual_in_fp32=residual_in_fp32, is_rms_norm=is_rms_norm, ) out_pt, *rest_pt = layer_norm_ref_fn( x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm ) with torch.autocast(device_type="cuda", dtype=input_dtype): out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt) out_ref, *rest_ref = layer_norm_ref_fn( x0_ref, norm_weight_ref, norm_bias_ref, residual=res_ref, eps=1e-6, prenorm=prenorm, upcast=True, ) out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref) if prenorm: residual = rest[0] residual_pt = rest_pt[0] residual_ref = rest_ref[0] assert out.dtype == input_dtype if prenorm: assert residual.dtype == residual_dtype assert allclose(residual, residual_pt, residual_ref) assert allclose(out, out_pt, out_ref) g = torch.randn_like(out) / batch_size out.backward(g) out_pt.backward(g) out_ref.backward(g) assert allclose(x0.grad, x0_pt.grad, x0_ref.grad) if has_residual: assert allclose(res.grad, res_pt.grad, res_ref.grad) assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad) if norm_bias is not None: assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad) assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad) if linear_bias is not None: assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad) ================================================ FILE: tests/pyproject.toml ================================================ [tool.black] line-length = 100 target-version = ['py38'] ================================================ FILE: tests/test_flash_attn.py 
================================================ import math import pytest import torch import torch.nn.functional as F from einops import rearrange, repeat from flash_attn import ( flash_attn_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func, flash_attn_varlen_func, flash_attn_varlen_kvpacked_func, flash_attn_varlen_qkvpacked_func, flash_attn_with_kvcache, ) from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.flash_attn_interface import _get_block_size_n from flash_attn.layers.rotary import apply_rotary_emb MAX_HEADDIM_SM8x = 192 is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5) is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8 is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0) is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0) def attn_bias_from_alibi_slopes( slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None ): batch, nheads = slopes.shape device = slopes.device slopes = rearrange(slopes, "b h -> b h 1 1") if causal: return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes else: row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) if key_leftpad is not None: key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) sk = ( seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") ) sq = ( seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") ) relative_pos = torch.abs(row_idx + sk - sq - col_idx) return -slopes * relative_pos.to(dtype=slopes.dtype) def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): assert mode in ["full", "random", "third"] if mode == "full": lengths = 
torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask


def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
    key_leftpad=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            col_idx < row_idx + sk - sq - window_size[0],
        )


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if softcap > 0:
        scores = scores / softcap
        scores = scores.tanh()
        scores = scores * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


def attention_kvpacked_ref(
    q,
    kv,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    return attention_ref(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        query_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        reorder_ops=reorder_ops,
        key_leftpad=key_leftpad,
    )


def attention_qkvpacked_ref(
    qkv,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
):
    return attention_ref(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        key_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        reorder_ops=reorder_ops,
    )


def generate_sparsity_mask(seqlen, sparsity=0.3):
    repeats = seqlen // 16 // 2
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    nrow, ncol = seqlen // 16, seqlen // 256
    mask = torch.rand(nrow, ncol, device="cuda") < sparsity
    return mask


def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        blockmask: (seqlen / 16, seqlen / 256)
        attn_mask: (batch_size, seqlen)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen, seqlen)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax after dropout
    """
    q, k, v = qkv.float().unbind(dim=2)
    d = qkv.shape[-1]
    seqlen = qkv.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    blockmask = blockmask[:seqlen, :seqlen]
    scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
    attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
    output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)


def convert_flash_attn_S_to_softmax(
    S,
    seqlen_q,
    seqlen_k,
    query_padding_mask,
    key_padding_mask,
    head_dim,
    is_dropout,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """FlashAttention stores the S matrix in a different way.
    Arguments:
        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
        query_padding_mask: (batch_size, seqlen_q_rounded)
        key_padding_mask: (batch_size, seqlen_k_rounded)
    """
    if causal:
        window_size = (window_size[0], 0)
    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
    S_converted = S
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            S.device,
        )
        local_mask = F.pad(
            local_mask,
            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
            value=True,
        )
        S_converted = S_converted.masked_fill(local_mask, 0.0)

    # Need to zero out things not in attention_mask in case S was initialized with random values
    # and some of those values aren't overwritten.
    seqlen_q_og = (
        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
    )
    if query_padding_mask is not None:
        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
        S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
    if key_padding_mask is not None:
        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
    return S_converted[:, :, :seqlen_q, :seqlen_k]


def normalize_flash_attn_S(
    attn_unnorm,
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    is_dropout=False,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """
    Arguments:
        attn_unnorm: (batch_size, nheads, seqlen_q, seqlen_k), unnormalized probabilities
            recovered from FlashAttention's S output
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
    Output:
        attn_norm: (batch_size, nheads, seqlen_q, seqlen_k), attn_unnorm rescaled into
            normalized attention probabilities, in the dtype of attn_unnorm
    """
    if causal:
        window_size = (window_size[0], 0)
    q, k, v = q.float(), k.float(), v.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias.to(dtype=scores.dtype)
    block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)
    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
    lse[lse == float("-inf")] = float("inf")
    scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
    cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
    attn_norm = torch.cat(
        [
            a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
            for a, m in zip(attn_unnorm_block, cummax_block)
        ],
        dim=-1,
    )
    if query_padding_mask is not None:
        attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return attn_norm.to(dtype=attn_unnorm.dtype)


def get_dropout_fraction(
    dropout_mask,
    query_padding_mask=None,
    key_padding_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """
    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
    query_padding_mask: (batch_size, seqlen_q)
    key_padding_mask: (batch_size, seqlen_k)
    """
    if causal:
        window_size = (window_size[0], 0)
    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
    dropped = ~dropout_mask
    valid = torch.ones_like(dropout_mask)
    if query_padding_mask is not None:
        dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
        valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
    if key_padding_mask is not None:
        dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
        valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            dropout_mask.device,
        )
        dropped.masked_fill_(local_mask, False)
        valid.masked_fill_(local_mask, False)
    dropped_total = dropped.sum()
    return dropped.sum() / valid.sum()


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [512])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            None,
            None,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, None, None, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        None,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    # v = qkv[:, :, 2].float()
    # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
    # if causal:
    #     causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
    #     qk.masked_fill_(causal_mask, float('-inf'))
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # p_tmp = torch.softmax(qk / math.sqrt(d), -1)
    # p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values
    # qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values
    # qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values
    # qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values
    # o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:])
    # o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])
    # o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:])
    # o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    # do_o = (g.float() * out.float()).sum(-1)
    # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
    # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(
    seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )

    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None

    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )

    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            key_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            key_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize("kvpacked", [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Ensure the values of qk are at least within softcap range.
q = q * softcap if kvpacked: kv = torch.randn( batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True ) else: k = torch.randn( batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True ) v = torch.randn( batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True ) if alibi: alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) else: alibi_slopes, attn_bias = None, None if kvpacked: out, lse, S_dmask = flash_attn_kvpacked_func( q, kv, dropout_p, causal=causal, window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, ) else: out, lse, S_dmask = flash_attn_func( q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, ) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0, causal=causal, window_size=window_size, ) dropout_mask = S_dmask_converted >= 0 attn_unnorm = S_dmask_converted.abs() if kvpacked: kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k) k_rep, v_rep = kv_rep.unbind(dim=2) else: k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) attn = normalize_flash_attn_S( attn_unnorm, q, k_rep, v_rep, None, None, attn_bias, dropout_p > 0.0, causal=causal, window_size=window_size, ) dropout_fraction = get_dropout_fraction( dropout_mask, None, None, causal=causal, window_size=window_size ).item() print(f"Actual dropout fraction: {dropout_fraction}") else: dropout_mask = None if kvpacked: out_ref, attn_ref = attention_kvpacked_ref( q, kv, None, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size, 
            softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q, kv, None, None, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
            upcast=False, reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q, k, v, None, None, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q, k, v, None, None, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
            upcast=False, reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        if kvpacked:
            dq, dkv = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            dq_ref, dkv_ref = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            dq_pt, dkv_pt = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice (output, attention probs)
    # or three times (gradients) the numerical error of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 147), (113, 203), (128, 217), (113, 211), (108, 256), (256, 512),
        (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("softcap", [0.0, 50.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Ensure the values of qk are at least within softcap range.
        q = q * softcap
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )

    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None

    if kvpacked:
        (
            q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
            causal=causal, window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes,
            deterministic=deterministic, return_attn_probs=True,
        )
    else:
        (
            q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, causal=causal, window_size=window_size, softcap=softcap,
            alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d,
            dropout_p > 0.0, causal=causal, window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm, q, k_rep, v_rep, query_padding_mask, key_padding_mask, attn_bias,
            dropout_p > 0.0, causal=causal, window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, query_padding_mask, key_padding_mask,
            causal=causal, window_size=window_size,
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q, kv, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q, kv, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
            upcast=False, reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q, k, v, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q, k, v, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
            upcast=False, reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        if kvpacked:
            dq_unpad, dkv_unpad = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            dq_ref, dkv_ref = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            dq_pt, dkv_pt = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
                out, (q_unpad, k_unpad, v_unpad), g
            )
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)  # dV uses the same key padding as dK, so dk_pad_fn is reused
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice (output, attention probs)
    # or three times (gradients) the numerical error of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)

    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239), (3, 799), (127, 512), (127, 513), (113, 203),
        (128, 217), (113, 211), (108, 256), (256, 512), (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size,
        upcast=False, reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
    dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
    print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
    print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
    print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
    print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
    print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
    print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
    print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
    print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
    print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
    print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
    print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
    print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239), (3, 799), (127, 512), (127, 513), (113, 203),
        (128, 217), (113, 211), (108, 256), (256, 512), (1023, 1024),
    ],
)
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(
    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if paged_kv_block_size is None:
        k = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        block_table = None
    else:
        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad if paged_kv_block_size is None else k_cache_paged,
        v_unpad if paged_kv_block_size is None else v_cache_paged,
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0,
        causal=causal, window_size=window_size, block_table=block_table,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, None, 0.0, None,
        causal=causal, window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, None, 0.0, None,
        causal=causal, window_size=window_size, upcast=False, reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    test_backward = block_table is None
    if test_backward:
        dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)  # dV uses the same key padding as dK, so dk_pad_fn is reused
        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if test_backward:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (3, 1024), (1, 339), (64, 800), (3, 799), (64, 2048),
        (16, 20000), (16, 100000), (128, 128), (256, 256),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(
    seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype
):
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1
    nheads = 12
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, _ = flash_attn_func(
        q, k, v, 0.0, causal=causal, window_size=window_size, alibi_slopes=alibi_slopes,
        deterministic=deterministic, return_attn_probs=True,
    )
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size,
        upcast=False, reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
    dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
    print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
    print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
    print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
    print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
    print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
    print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
    print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
    print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
    print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
    print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
    print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
    print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice (output) or `mult` times
    # (gradients) the numerical error of a Pytorch implementation, plus a small absolute slack.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    mult = 2 if not alibi else 8
    assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
    assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
    assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
# @pytest.mark.parametrize("paged_kv_block_size", [None])
@pytest.mark.parametrize("has_leftpad", [False, True])
# @pytest.mark.parametrize("has_leftpad", [True])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128), (1, 339), (3, 1024), (64, 800), (64, 256), (3, 799),
        (64, 2048), (16, 20000), (1, 128 * 1024), (16, 128 * 1024), (128, 128),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
    seqlen_q, seqlen_k, d, has_batch_idx, has_leftpad, paged_kv_block_size, rotary_fraction,
    rotary_interleaved, seqlen_new_eq_seqlen_q, causal, local, alibi, new_kv, mha_type,
    num_splits, dtype,
):
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if has_batch_idx and paged_kv_block_size is not None:
        pytest.skip()
    if has_leftpad and paged_kv_block_size is not None:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    if paged_kv_block_size is None:
        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        block_table = None
    else:
        (
            k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks,
        ) = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
        )
    cache_seqlens = torch.randint(
        0 if new_kv else 1,
        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
        (
            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
            if new_kv
            else (seqlen_k + 1)
        ),
        (batch_size,),
        dtype=torch.int32,
        device=device,
    )
    if has_leftpad:
        cache_leftpad = torch.cat(
            [
                torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
                if cache_seqlens[i].item() > 0
                else torch.zeros(1, dtype=torch.int32, device=device)
                for i in range(batch_size)
            ]
        )
    else:
        cache_leftpad = None
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
    if has_leftpad:
        key_padding_mask = torch.logical_and(
            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
        )
    if has_batch_idx:
        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
            :batch_size
        ]
    else:
        cache_batch_idx = None
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal,
            key_leftpad=cache_leftpad,
        )
    else:
        alibi_slopes, attn_bias = None, None
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    if rotary_dim > 0:
        angle = (
            torch.rand(
                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
                rotary_dim // 2,
                device=device,
            )
            * 2
            * math.pi
        )
        cos = torch.cos(angle).to(dtype=dtype)
        sin = torch.sin(angle).to(dtype=dtype)
        if causal or local:
            q_ro = apply_rotary_emb(
                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
            )
        else:
            q_ro = rearrange(
                apply_rotary_emb(
                    rearrange(q, "b s h d -> b 1 (s h) d"),
                    cos,
                    sin,
                    seqlen_offsets=cache_seqlens,
                    interleaved=rotary_interleaved,
                ),
                "b 1 (s h) d -> b s h d",
                s=seqlen_q,
            )
        # q_ro = q
        k_ro = apply_rotary_emb(
            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
        )
    else:
        cos, sin = None, None
        q_ro, k_ro = q, k
    # k_cache[:, 64:] = -1
    k_cache_ref = (
        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    v_cache_ref = (
        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
        )
        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    out = flash_attn_with_kvcache(
        q,
        k_cache if paged_kv_block_size is None else k_cache_paged,
        v_cache if paged_kv_block_size is None else v_cache_paged,
        k,
        v,
        rotary_cos=cos,
        rotary_sin=sin,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_batch_idx,
        cache_leftpad=cache_leftpad,
        block_table=block_table,
        causal=causal,
        window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=alibi_slopes,
        num_splits=num_splits,
    )
    # out = flash_attn_with_kvcache(
    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
    # )
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    out_ref, _ = attention_ref(
        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, attn_bias, 0.0, None,
        causal=causal, window_size=window_size, key_leftpad=cache_leftpad,
    )
    out_pt, _ = attention_ref(
        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, attn_bias, 0.0, None,
        causal=causal, window_size=window_size, upcast=False, reorder_ops=True,
        key_leftpad=cache_leftpad,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that the KV cache was updated in place, and that FlashAttention's numerical error is
    # at most `mult` times the numerical error of a Pytorch implementation (plus a small slack).
    if new_kv:
        if paged_kv_block_size is None:
            k_cache_select = (
                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
            )
            v_cache_select = (
                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
            )
        else:
            k_cache_select = rearrange(
                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
            v_cache_select = rearrange(
                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
        assert torch.equal(v_cache_select, v_cache_ref)
    mult = 3 if not alibi else 5
    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5


def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
    k_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    v_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    block_table = rearrange(
        torch.randperm(num_blocks, dtype=torch.int32, device=device),
        "(b nblocks) -> b nblocks",
        b=batch_size,
    )
    k_cache = rearrange(
        # PyTorch 1.12 doesn't support indexing with int32 tensors
        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    v_cache = rearrange(
        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239), (239, 1), (3, 799), (799, 3), (1024, 128), (97, 97), (128, 128),
        (200, 200), (256, 256), (257, 257), (384, 384), (512, 512), (768, 768), (1024, 1024),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 60  # Sometimes we need a large batch size for the race conditions to trigger
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
    g = torch.randn_like(out0)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
        # Numerical error if we just do any arithmetic on dq
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(250):
        torch.random.manual_seed(42)
        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)

        if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
            if not dq_equal:
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert dq_equal


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 5
    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
    k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
        for _ in range(2)
    ]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max 
diff: {(v.grad - v_ref.grad).abs().max().item()}") print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}") print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() assert (q.grad - q_ref.grad).abs().max().item() <= 5 * ( q_pt.grad - q_ref.grad ).abs().max().item() + 1e-3 assert (k.grad - k_ref.grad).abs().max().item() <= 5 * ( k_pt.grad - k_ref.grad ).abs().max().item() + 1e-3 assert (v.grad - v_ref.grad).abs().max().item() <= 5 * ( v_pt.grad - v_ref.grad ).abs().max().item() + 1e-3 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [64, 128]) # @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 256]) # @pytest.mark.parametrize('seqlen', [128]) def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): """We previously had a bug where we were using the wrong strides of dout, which shows up when dout is not contiguous. """ device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 5 nheads = 2 q, k, v = [ torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True) for _ in range(3) ] out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...") # So g is not contiguous g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2] out.backward(g) q_pt = q.detach().clone().requires_grad_(True) k_pt = k.detach().clone().requires_grad_(True) v_pt = v.detach().clone().requires_grad_(True) out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) out_pt = rearrange(out_pt, "b s ... 
-> s b ...") out_pt.backward(g) q_ref = q.detach().clone().requires_grad_(True) k_ref = k.detach().clone().requires_grad_(True) v_ref = v.detach().clone().requires_grad_(True) out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) out_ref = rearrange(out_ref, "b s ... -> s b ...") out_ref.backward(g) print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}") print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() assert (q.grad - q_ref.grad).abs().max().item() <= 2 * ( q_pt.grad - q_ref.grad ).abs().max().item() assert (k.grad - k_ref.grad).abs().max().item() <= 2 * ( k_pt.grad - k_ref.grad ).abs().max().item() assert (v.grad - v_ref.grad).abs().max().item() <= 2 * ( v_pt.grad - v_ref.grad ).abs().max().item() @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [16, 32, 64]) # @pytest.mark.parametrize('d', [16]) def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, in the case where seqlen % 128 != 0 or varlen. 
""" device = "cuda" # set seed torch.random.manual_seed(0) nheads = 5 q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32) k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32) Mq = 256 Mk = 3 q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3 k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)] q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal) g = torch.randn_like(out) out.backward(g) assert not q.grad.isnan().any() assert not k.grad.isnan().any() assert not v.grad.isnan().any() @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize("swap_sq_sk", [False, True]) # @pytest.mark.parametrize("swap_sq_sk", [False]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1, 239), (3, 799), (127, 512), (127, 513), (113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (1023, 1024), ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 ): pytest.skip() # Reference implementation OOM if swap_sq_sk: 
seqlen_q, seqlen_k = seqlen_k, seqlen_q device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 4 nheads = 9 window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True) g = torch.randn_like(out) dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) for _ in range(50): dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) assert torch.equal(dv, dv0) assert torch.equal(dk, dk0) assert torch.equal(dq, dq0) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [True]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize("swap_sq_sk", [False, True]) # @pytest.mark.parametrize("swap_sq_sk", [True]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ (1, 239), (3, 799), (127, 512), (127, 513), (113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (1023, 1024), ], ) # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, 
dtype): if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 ): pytest.skip() # Reference implementation OOM if swap_sq_sk: seqlen_q, seqlen_k = seqlen_k, seqlen_q device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 2 nheads = 9 window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") ( q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn, ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) out = flash_attn_varlen_func( q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, causal=causal, window_size=window_size, deterministic=True, ) g = torch.randn_like(out) dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) for _ in range(50): dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) assert torch.equal(dv, dv0) assert torch.equal(dk, dk0) assert torch.equal(dq, dq0) ================================================ FILE: tests/test_flash_attn_ck.py ================================================ import math import pytest import torch import torch.nn.functional as F from einops import rearrange, repeat from flash_attn import ( flash_attn_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func, flash_attn_varlen_func, flash_attn_varlen_kvpacked_func, flash_attn_varlen_qkvpacked_func, 
flash_attn_with_kvcache, ) from test_flash_attn import ( attn_bias_from_alibi_slopes, convert_flash_attn_S_to_softmax, generate_qkv, generate_random_padding_mask, _generate_block_kvcache, attention_ref, attention_kvpacked_ref, attention_qkvpacked_ref, ) from flash_attn.layers.rotary import apply_rotary_emb def is_bwd_hdim_supported(d): return d <= 256 def ck_randval_to_dropout_mask(randval, p): # If p = 0.3, randval in 255 * (0.7, 1.0] will be dropout # randval in 255 * [0, 0.7] will be kept # If return dropout_mask >=0, value will be kept return math.floor(255.0 * (1 - p)) - randval.to(torch.float32) def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_rounded, seqlen_k_rounded): """ pad + rearrange [nheads, total_q, max_seqlen_k] into [b, nheads, seqlen_q_rounded, seqlen_k_rounded] Arguments: S_dmask: (nheads, total_q, max_seqlen_k) cu_seqlens_q: (b + 1) Output: S_dmask: (b, nheads, seqlen_q_rounded, seqlen_k_rounded) """ batch_size = cu_seqlens_q.numel() - 1 seqlens_q = torch.roll(cu_seqlens_q, shifts = -1) - cu_seqlens_q seqlens_q = seqlens_q[0:batch_size].tolist() S_dmask = torch.split(S_dmask, seqlens_q, dim=1) # [(nheads, seqlen_q0, max_seqlen_k), (nheads, seqlen_q1, max_seqlen_k), ..., (nheads, seqlen_qb, max_seqlen_k)] masks = () for mask in S_dmask: # (nheads, seqlen_qi, max_seqlen_k) -> (nheads, seqlen_q_rounded, seqlen_k_rounded) mask = F.pad(mask, (0, seqlen_k_rounded - mask.shape[2], 0, seqlen_q_rounded - mask.shape[1], 0, 0)).unsqueeze(1) masks = masks + (mask, ) S_dmask = torch.cat(masks, dim=1) S_dmask = S_dmask.transpose(0, 1) return S_dmask @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("alibi", [False, True]) @pytest.mark.parametrize("local", [False, True]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) 
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if d > 256:
        pytest.skip()

    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    if dropout_p > 0.0:
        # TODO - move to c++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        # CK does not return P. Hence, we don't test the attn here.
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        None,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

        # TODO - tolerance relaxed to 10x until CK fixes its bwd precision issue
        assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
@pytest.mark.parametrize("dropout_p", [0, 0.17])
def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if d > 256:
        pytest.skip()

    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )

    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None

    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )

    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        # TODO - move to c++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens, seqlen, seqlen)

        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            key_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )

        dropout_mask = S_dmask_converted >= 0
        # CK does not return P. Hence, we don't test the attn here.
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

        # TODO - tolerance relaxed to 10x until CK fixes its bwd precision issue
        assert (dqkv - dqkv_ref).abs().max().item() <= 10 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None

    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q,
            kv,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q,
            k,
            v,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    if dropout_p > 0.0:
        # TODO - move to c++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        # CK does not return P. Hence, we don't test the attn here.
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        if kvpacked:
            (
                dq,
                dkv,
            ) = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            (
                dq_ref,
                dkv_ref,
            ) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (
                dq_pt,
                dkv_pt,
            ) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (
                dq,
                dk,
                dv,
            ) = torch.autograd.grad(out, (q, k, v), g)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

        # TODO - tolerance relaxed to 10x until CK fixes its bwd precision issue
        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 147),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )

    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None

    if kvpacked:
        (
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            kv,
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad,
            kv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            causal=causal,
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        # TODO - move to c++ mha_varlen_fwd()
        S_dmask = ck_randval_to_dropout_mask(S_dmask, dropout_p)
        S_dmask = pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q, seqlen_k)

        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            query_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        # CK does not return P. Hence, we don't test the attn here.
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            upcast=False,
            reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most 4 times the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item()

    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        if kvpacked:
            (
                dq_unpad,
                dkv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            (
                dq_ref,
                dkv_ref,
            ) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (
                dq_pt,
                dkv_pt,
            ) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (
                dq_unpad,
                dk_unpad,
                dv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
            dk = dk_pad_fn(dk_unpad)
            # dV uses the key-side padding too, so dk_pad_fn applies to it as well
            dv = dk_pad_fn(dv_unpad)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

        # TODO - tolerance relaxed to 10x until CK fixes its bwd precision issue
        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        # (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    if max(seqlen_q, seqlen_k) >= 2048:
        pytest.skip()
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most 4 times the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item() + 1e-5

    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        do_o = (g.float() * out.float()).sum(-1)
        (
            dq,
            dk,
            dv,
        ) = torch.autograd.grad(out, (q, k, v), g)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

        # TODO - use 10 times to check, wait for ck to fix bwd precision issue
        assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item() + 1e-4
        assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item() + 1e-4
        assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item() + 1e-4


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        # (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
def test_flash_attn_varlen_causal(
    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
):
    if max(seqlen_q, seqlen_k) >= 2048:
        pytest.skip()
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)

    if paged_kv_block_size is None:
        k = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        block_table = None
    else:
        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad if paged_kv_block_size is None else k_cache_paged,
        v_unpad if paged_kv_block_size is None else v_cache_paged,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        block_table=block_table,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        query_padding_mask,
        key_padding_mask,
        None,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    g = torch.randn_like(out)
    if is_bwd_hdim_supported(d):
        do_o = (g.float() * out.float()).sum(-1)
        test_backward = block_table is None
        if test_backward:
            (
                dq_unpad,
                dk_unpad,
                dv_unpad,
            ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
            dq = dq_pad_fn(dq_unpad)
            dk = dk_pad_fn(dk_unpad)
            # V shares K's padding mask, so dk_pad_fn also re-pads dV.
            dv = dk_pad_fn(dv_unpad)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

        if test_backward:
            # TODO - use 10 times to check, wait for ck to fix bwd precision issue
            assert (dq - dq_ref).abs().max().item() <= 10 * (dq_pt - dq_ref).abs().max().item() + 1e-5
            assert (dk - dk_ref).abs().max().item() <= 10 * (dk_pt - dk_ref).abs().max().item() + 1e-5
            assert (dv - dv_ref).abs().max().item() <= 10 * (dv_pt - dv_ref).abs().max().item() + 1e-5


# TODO - support splitkv
# def test_flash_attn_splitkv


# TODO - Support has_leftpad
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("new_kv", [False, True])
@pytest.mark.parametrize("alibi", [False, True])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
@pytest.mark.parametrize("has_leftpad", [False])
@pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128),
        (1, 339),
        (3, 1024),
        (64, 800),
        (64, 256),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (1, 128 * 1024),
        (16, 128 * 1024),
        (128, 128),
    ],
)
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    has_batch_idx,
    has_leftpad,
    paged_kv_block_size,
    rotary_fraction,
    rotary_interleaved,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    alibi,
    new_kv,
    mha_type,
    num_splits,
    dtype,
):
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if has_batch_idx and paged_kv_block_size is not None:
        pytest.skip()
    if has_leftpad and paged_kv_block_size is not None:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    if paged_kv_block_size is None:
        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        block_table = None
    else:
        (
            k_cache,
            v_cache,
            block_table,
            k_cache_paged,
            v_cache_paged,
            num_blocks,
        ) = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
        )
    cache_seqlens = torch.randint(
        0 if new_kv else 1,
        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
        (
            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
            if new_kv
            else (seqlen_k + 1)
        ),
        (batch_size,),
        dtype=torch.int32,
        device=device,
    )
    if has_leftpad:
        cache_leftpad = torch.cat(
            [
                torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
                if cache_seqlens[i].item() > 0
                else torch.zeros(1, dtype=torch.int32, device=device)
                for i in range(batch_size)
            ]
        )
    else:
        cache_leftpad = None
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
    if has_leftpad:
        key_padding_mask = torch.logical_and(
            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
        )
    if has_batch_idx:
        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
            :batch_size
        ]
    else:
        cache_batch_idx = None
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad
        )
    else:
        alibi_slopes, attn_bias = None, None
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    if rotary_dim > 0:
        angle = (
            torch.rand(
                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
                rotary_dim // 2,
                device=device,
            )
            * 2
            * math.pi
        )
        cos = torch.cos(angle).to(dtype=dtype)
        sin = torch.sin(angle).to(dtype=dtype)
        if causal or local:
            q_ro = apply_rotary_emb(
                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
            )
        else:
            q_ro = rearrange(
                apply_rotary_emb(
                    rearrange(q, "b s h d -> b 1 (s h) d"),
                    cos,
                    sin,
                    seqlen_offsets=cache_seqlens,
                    interleaved=rotary_interleaved,
                ),
                "b 1 (s h) d -> b s h d",
                s=seqlen_q,
            )
        # q_ro = q
        k_ro = apply_rotary_emb(
            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
        )
    else:
        cos, sin = None, None
        q_ro, k_ro = q, k
    # k_cache[:, 64:] = -1
    k_cache_ref = (
        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    v_cache_ref = (
        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
        )
        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    out = flash_attn_with_kvcache(
        q,
        k_cache if paged_kv_block_size is None else k_cache_paged,
        v_cache if paged_kv_block_size is None else v_cache_paged,
        k,
        v,
        rotary_cos=cos,
        rotary_sin=sin,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_batch_idx,
        cache_leftpad=cache_leftpad,
        block_table=block_table,
        causal=causal,
        window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=alibi_slopes,
        num_splits=num_splits,
    )
    # out = flash_attn_with_kvcache(
    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
    # )
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    out_ref, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        attn_bias,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        key_leftpad=cache_leftpad,
    )
    out_pt, _ = attention_ref(
        q_ro,
        k_cache_rep,
        v_cache_rep,
        None,
        key_padding_mask,
        attn_bias,
        0.0,
        None,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
        key_leftpad=cache_leftpad,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that the KV cache was updated as expected, then that FlashAttention's numerical
    # error is at most `mult` times the numerical error of a Pytorch implementation.
    if new_kv:
        if paged_kv_block_size is None:
            k_cache_select = (
                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
            )
            v_cache_select = (
                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
            )
        else:
            k_cache_select = rearrange(
                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
            v_cache_select = rearrange(
                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
        assert torch.equal(v_cache_select, v_cache_ref)
    # mult = 3 if f16, bf16 need 4
    mult = 4 if not alibi else 5
    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (239, 1),
        (3, 799),
        (799, 3),
        (1024, 128),
        (97, 97),
        (128, 128),
        (200, 200),
        (256, 256),
        (257, 257),
        (384, 384),
        (512, 512),
        (768, 768),
        # (1024, 1024),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
    g = torch.randn_like(out0)
    if dropout_p == 0 and is_bwd_hdim_supported(d):
        (
            dq0,
            dk0,
            dv0,
        ) = torch.autograd.grad(out0, (q, k, v), g)
        # dQ may be accumulated with atomic adds, so allow a tolerance of twice the rounding
        # error introduced by any arithmetic on dQ (here, adding and subtracting 0.3).
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(250):
        torch.random.manual_seed(42)
        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)

        # Use the same guard as above, otherwise dq0/dk0/dv0 would be undefined here.
        if dropout_p == 0 and is_bwd_hdim_supported(d):
            (
                dq,
                dk,
                dv,
            ) = torch.autograd.grad(out, (q, k, v), g)
            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
            if not dq_equal:
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert dq_equal


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [16, 32, 64])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0.
    """
    # TODO - 1 or 2 might fail, need to check
    if seqlen == 1 or seqlen == 2:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 5
    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
    k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
        for _ in range(2)
    ]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 7 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item() + 1e-3
    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item() + 1e-3
    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item() + 1e-3


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
    """We previously had a bug where we were using the wrong strides of dout, which shows up
    when dout is not contiguous.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 2
    q, k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
        for _ in range(3)
    ]
    out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
    # So g is not contiguous
    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt = rearrange(out_pt, "b s ... -> s b ...")
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref = rearrange(out_ref, "b s ... -> s b ...")
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item()
    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item()
    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item()


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [16, 32, 64])
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0 or varlen.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    nheads = 5
    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
    Mq = 256
    Mk = 3

    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)

    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)

    assert not q.grad.isnan().any()
    assert not k.grad.isnan().any()
    assert not v.grad.isnan().any()


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(
        q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True
    )

    g = torch.randn_like(out)
    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
    for _ in range(50):
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert torch.equal(dq, dq0)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        deterministic=True,
    )

    g = torch.randn_like(out)
    dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
    for _ in range(50):
        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert torch.equal(dq, dq0)

================================================
FILE: tests/test_flash_attn_triton_amd.py
================================================
import math

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
    flash_attn_with_kvcache,
)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb
from aiter.ops.triton._triton_kernels.flash_attn_triton_amd.utils import USE_TRITON_ROCM, is_hip, get_arch


def _get_block_size_n_triton(device, head_dim, is_dropout, is_causal):
    """Get block size for Triton AMD kernel."""
    arch = get_arch()
    if arch.is_rdna:
        return 32
    elif arch.is_cdna:
        return 64
    # Fall back to CUDA kernel block sizes
    return _get_block_size_n(device, head_dim, is_dropout, is_causal)


MAX_HEADDIM_SM8x = 192


is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)

skip_bfloat16 = True if is_sm75 or is_hip() else False


def attn_bias_from_alibi_slopes(
    slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None
):
    batch, nheads = slopes.shape
    device = slopes.device
    slopes = rearrange(slopes, "b h -> b h 1 1")
    if causal:
        return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes
    else:
        row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
        col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
        if key_leftpad is not None:
            key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
            col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
            col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
        sk = (
            seqlen_k
            if key_padding_mask is None
            else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        sq = (
            seqlen_q
            if query_padding_mask is None
            else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
        )
        relative_pos = torch.abs(row_idx + sk - sq - col_idx)
        return -slopes * relative_pos.to(dtype=slopes.dtype)


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask


def generate_qkv(
    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen_q), bool
        key_padding_mask: (batch_size, seqlen_k), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)
        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
    key_leftpad=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            col_idx < row_idx + sk - sq - window_size[0],
        )


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        softcap: if > 0, scores are soft-capped as softcap * tanh(scores / softcap)
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax probabilities before dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if softcap > 0:
        scores = scores / softcap
        scores = scores.tanh()
        scores = scores * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


def attention_kvpacked_ref(
    q,
    kv,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    return attention_ref(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        query_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        reorder_ops=reorder_ops,
        key_leftpad=key_leftpad,
    )


def attention_qkvpacked_ref(
    qkv,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
):
    return attention_ref(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        key_padding_mask,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        upcast=upcast,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        reorder_ops=reorder_ops,
    )


def generate_sparsity_mask(seqlen, sparsity=0.3):
    repeats = seqlen // 16 // 2
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
    #                     torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    # mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
    nrow, ncol = seqlen // 16, seqlen // 256
    mask = torch.rand(nrow, ncol, device="cuda") < sparsity
    return mask


def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        blockmask: (seqlen / 16, seqlen / 256)
        attn_mask: (batch_size, seqlen)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen, seqlen)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
        attention: softmax probabilities before dropout
    """
    q, k, v = qkv.float().unbind(dim=2)
    d = qkv.shape[-1]
    seqlen = qkv.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
    blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    blockmask = blockmask[:seqlen, :seqlen]
    scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
    attention = torch.softmax(scores, dim=-1)
    attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
    attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
    attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p)
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
    output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
    return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)


def convert_flash_attn_S_to_softmax(
    S,
    seqlen_q,
    seqlen_k,
    query_padding_mask,
    key_padding_mask,
    head_dim,
    is_dropout,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """Crop FlashAttention's S matrix (padded to rounded seqlens) and zero out masked entries.
    Arguments:
        S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
    """
    if causal:
        window_size = (window_size[0], 0)
    seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
    S_converted = S
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            S.device,
        )
        local_mask = F.pad(
            local_mask,
            (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
            value=True,
        )
        S_converted = S_converted.masked_fill(local_mask, 0.0)

    # Need to zero out things not in attention_mask in case S was initialized with random values
    # and some of those values aren't overwritten.
    seqlen_q_og = (
        query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
    )
    if query_padding_mask is not None:
        query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
        S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
    if key_padding_mask is not None:
        key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
        S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
    S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
    S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
    return S_converted[:, :, :seqlen_q, :seqlen_k]


def normalize_flash_attn_S(
    attn_unnorm,
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    is_dropout=False,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """
    Arguments:
        attn_unnorm: (batch_size, nheads, seqlen_q, seqlen_k)
        q: (batch_size, seqlen_q, nheads, head_dim)
        k, v: (batch_size, seqlen_k, nheads, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
    Output:
        attn_norm: (batch_size, nheads, seqlen_q, seqlen_k), attn_unnorm rescaled to softmax
            probabilities
    """
    if causal:
        window_size = (window_size[0], 0)
    q, k, v = q.float(), k.float(), v.float()
    _, seqlen_q, _, head_dim = q.shape
    seqlen_k = k.shape[1]
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias.to(dtype=scores.dtype)
    block_size_n = _get_block_size_n_triton(scores.device, head_dim, is_dropout, causal)
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)
    # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
    # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
    lse[lse == float("-inf")] = float("inf")
    scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
    cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
    attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
    attn_norm = torch.cat(
        [
            a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
            for a, m in zip(attn_unnorm_block, cummax_block)
        ],
        dim=-1,
    )
    if query_padding_mask is not None:
        attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    return attn_norm.to(dtype=attn_unnorm.dtype)


def get_dropout_fraction(
    dropout_mask,
    query_padding_mask=None,
    key_padding_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
):
    """
    dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
    query_padding_mask: (batch_size, seqlen_q)
    key_padding_mask: (batch_size, seqlen_k)
    """
    if causal:
        window_size = (window_size[0], 0)
    batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
    dropped = ~dropout_mask
    valid = torch.ones_like(dropout_mask)
    if query_padding_mask is not None:
        dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
        valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
    if key_padding_mask is not None:
        dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
        valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            dropout_mask.device,
        )
        dropped.masked_fill_(local_mask, False)
        valid.masked_fill_(local_mask, False)
    # Fraction of dropped entries among the positions that are not masked out
    return dropped.sum() / valid.sum()


@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [512])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, S_dmask = flash_attn_qkvpacked_func(
        qkv,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            None,
            None,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, None, None, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        None,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    # v = qkv[:, :, 2].float()
    # qk = torch.einsum('bshd,bthd->bhst', qkv[:, :, 0], qkv[:, :, 1]).float()
    # if causal:
    #     causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
    #     qk.masked_fill_(causal_mask, float('-inf'))
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # p_tmp = torch.softmax(qk / math.sqrt(d), -1)
    # p_dropped = p_tmp if dropout_mask is None else p_tmp.masked_fill(~dropout_mask, 0)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # qk_max1 = torch.max(qk[:, :, 128:, 192:], -1, keepdim=True).values
    # qk_max2 = torch.max(qk[:, :, 128:, 128:], -1, keepdim=True).values
    # qk_max3 = torch.max(qk[:, :, 128:, 64:], -1, keepdim=True).values
    # qk_max4 = torch.max(qk[:, :, 128:, :], -1, keepdim=True).values
    # o1 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 192:] - qk_max1) / math.sqrt(d)), v[:, 192:])
    # o2 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 128:] - qk_max2) / math.sqrt(d)), v[:, 128:])
    # o3 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, 64:] - qk_max3) / math.sqrt(d)), v[:, 64:])
    # o4 = torch.einsum('bhst,bthd->bshd', torch.exp((qk[:, :, 128:, :] - qk_max4) / math.sqrt(d)), v[:, :])
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    # do_o = (g.float() * out.float()).sum(-1)
    # dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
    # dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        (dqkv,) = torch.autograd.grad(out, qkv, g)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(
    seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype
):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 6
    window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )

    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None

    qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True
    )

    out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func(
        qkv_unpad,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        causal=causal,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=True,
    )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen,
            seqlen,
            key_padding_mask,
            key_padding_mask,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        attn = normalize_flash_attn_S(
            attn_unnorm,
            qkv[:, :, 0],
            qkv[:, :, 1],
            qkv[:, :, 2],
            key_padding_mask,
            key_padding_mask,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    out_ref, attn_ref = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
    )
    out_pt, attn_pt = attention_qkvpacked_ref(
        qkv,
        key_padding_mask,
        attn_bias,
        dropout_p,
        dropout_mask,
        causal=causal,
        window_size=window_size,
        upcast=False,
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
        (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
        print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}")
        print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [True, False])
# @pytest.mark.parametrize("kvpacked", [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.parametrize("softcap", [0.0])
def test_flash_attn_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if USE_TRITON_ROCM:
        if causal and seqlen_q == 1024 and seqlen_k == 1024 and d == 160:
            pytest.skip("This test with causal=True is flaky")
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Ensure the values of qk are at least within softcap range.
        q = q * softcap
    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None

    if kvpacked:
        out, lse, S_dmask = flash_attn_kvpacked_func(
            q,
            kv,
            dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    else:
        out, lse, S_dmask = flash_attn_func(
            q,
            k,
            v,
            dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=deterministic,
            return_attn_probs=True,
        )
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask,
            seqlen_q,
            seqlen_k,
            None,
            None,
            d,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm,
            q,
            k_rep,
            v_rep,
            None,
            None,
            attn_bias,
            dropout_p > 0.0,
            causal=causal,
            window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, None, None, causal=causal, window_size=window_size
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q,
            kv,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q,
            k,
            v,
            None,
            None,
            attn_bias,
            dropout_p,
            dropout_mask,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        if kvpacked:
            (
                dq,
                dkv,
            ) = torch.autograd.grad(out, (q, kv), g)
            dk, dv = dkv.unbind(2)
            (
                dq_ref,
                dkv_ref,
            ) = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            (
                dq_pt,
                dkv_pt,
            ) = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            (
                dq,
                dk,
                dv,
            ) = torch.autograd.grad(out, (q, k, v), g)
            (
                dq_ref,
                dk_ref,
                dv_ref,
            ) = torch.autograd.grad(out_ref, (q, k, v), g)
            (
                dq_pt,
                dk_pt,
                dv_pt,
            ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("kvpacked", [False])
# @pytest.mark.parametrize('kvpacked', [False])
@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize('mha_type', ["mqa"])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 147),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("softcap", [0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_output(
    seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
    if USE_TRITON_ROCM:
        if seqlen_q == 1 and seqlen_k == 147 and kvpacked and dropout_p != 0.0:
            pytest.skip("This config with dropout is flaky on AMD.")
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if softcap > 0.0 and dropout_p > 0.0:
        pytest.skip("Softcap and dropout not supported together")
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if softcap > 0:
        # Ensure the values of qk are at least within softcap range.
        q = q * softcap

    if kvpacked:
        kv = torch.randn(
            batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
    else:
        k = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True
        )

    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal
        )
    else:
        alibi_slopes, attn_bias = None, None

    if kvpacked:
        (
            q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn,
        ) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func(
            q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
            causal=causal, window_size=window_size, softcap=softcap, alibi_slopes=alibi_slopes,
            deterministic=deterministic, return_attn_probs=True,
        )
    else:
        (
            q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn,
        ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
        out_unpad, sm_lse, S_dmask = flash_attn_varlen_func(
            q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, causal=causal, window_size=window_size, softcap=softcap,
            alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True,
        )
    out = output_pad_fn(out_unpad)
    if dropout_p > 0.0:
        S_dmask_converted = convert_flash_attn_S_to_softmax(
            S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0,
            causal=causal, window_size=window_size,
        )
        dropout_mask = S_dmask_converted >= 0
        attn_unnorm = S_dmask_converted.abs()
        if kvpacked:
            kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
            k_rep, v_rep = kv_rep.unbind(dim=2)
        else:
            k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k)
            v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k)
        attn = normalize_flash_attn_S(
            attn_unnorm, q, k_rep, v_rep, query_padding_mask, key_padding_mask, attn_bias,
            dropout_p > 0.0, causal=causal, window_size=window_size,
        )
        dropout_fraction = get_dropout_fraction(
            dropout_mask, query_padding_mask, key_padding_mask, causal=causal, window_size=window_size,
        ).item()
        print(f"Actual dropout fraction: {dropout_fraction}")
    else:
        dropout_mask = None

    if kvpacked:
        out_ref, attn_ref = attention_kvpacked_ref(
            q, kv, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
        )
        out_pt, attn_pt = attention_kvpacked_ref(
            q, kv, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
            upcast=False, reorder_ops=True,
        )
    else:
        out_ref, attn_ref = attention_ref(
            q, k, v, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q, k, v, query_padding_mask, key_padding_mask, attn_bias, dropout_p, dropout_mask,
            causal=causal, window_size=window_size, softcap=softcap,
            upcast=False, reorder_ops=True,
        )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
    if dropout_p > 0.0:
        print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}")
        print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

    g = torch.randn_like(out)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        if kvpacked:
            dq_unpad, dkv_unpad = torch.autograd.grad(out, (q_unpad, kv_unpad), g)
            dk, dv = dkv_pad_fn(dkv_unpad).unbind(2)
            dq_ref, dkv_ref = torch.autograd.grad(out_ref, (q, kv), g)
            dk_ref, dv_ref = dkv_ref.unbind(2)
            dq_pt, dkv_pt = torch.autograd.grad(out_pt, (q, kv), g)
            dk_pt, dv_pt = dkv_pt.unbind(2)
        else:
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)  # V uses the same key-side padding as K
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        dq = dq_pad_fn(dq_unpad)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's output error is at most twice, and its gradient error at most
    # three times, the numerical error of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    if dropout_p > 0.0:
        # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
        # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
        if not alibi:
            assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)

    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
        assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
        assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()


@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize("swap_sq_sk", [False])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239), (3, 799), (127, 512), (127, 513), (113, 203),
        (128, 217), (113, 211), (108, 256), (256, 512), (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
    if USE_TRITON_ROCM:
        if get_arch().is_rdna:
            if seqlen_q == 1 and seqlen_k == 239 and d == 256:
                pytest.skip("This config does not work on RDNA devices.")
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size)
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size,
        upcast=False, reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
    dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
    print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
    print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
    print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
    print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
    print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
    print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
    print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
    print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
    print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
    print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
    print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
    print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239), (3, 799), (127, 512), (127, 513), (113, 203),
        (128, 217), (113, 211), (108, 256), (256, 512), (1023, 1024),
    ],
)
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
@pytest.mark.parametrize("paged_kv_block_size", [None])
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
def test_flash_attn_varlen_causal(
    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    causal = True
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if paged_kv_block_size is None:
        k = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        v = torch.randn(
            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
        )
        block_table = None
    else:
        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
        )
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out_unpad = flash_attn_varlen_func(
        q_unpad,
        k_unpad if paged_kv_block_size is None else k_cache_paged,
        v_unpad if paged_kv_block_size is None else v_cache_paged,
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0,
        causal=causal, window_size=window_size, block_table=block_table,
    )
    out = output_pad_fn(out_unpad)
    out_ref, attn_ref = attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, None, 0.0, None,
        causal=causal, window_size=window_size,
    )
    out_pt, attn_pt = attention_ref(
        q, k, v, query_padding_mask, key_padding_mask, None, 0.0, None,
        causal=causal, window_size=window_size, upcast=False, reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    test_backward = block_table is None
    if test_backward:
        dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
        dq = dq_pad_fn(dq_unpad)
        dk = dk_pad_fn(dk_unpad)
        dv = dk_pad_fn(dv_unpad)  # V uses the same key-side padding as K
        dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if test_backward:
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("deterministic", [True])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (3, 1024), (1, 339), (64, 800), (3, 799), (64, 2048),
        (16, 20000), (16, 100000), (128, 128), (256, 256),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.skip()
def test_flash_attn_splitkv(
    seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype
):
    if USE_TRITON_ROCM:
        if seqlen_q == 1 and seqlen_k == 339 and swap_sq_sk:
            pytest.skip("This config is flaky on AMD.")
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1
    nheads = 12
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal)
    else:
        alibi_slopes, attn_bias = None, None
    out, lse, _ = flash_attn_func(
        q, k, v, 0.0, causal=causal, window_size=window_size, alibi_slopes=alibi_slopes,
        deterministic=deterministic, return_attn_probs=True,
    )
    out_ref, attn_ref = attention_ref(
        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size
    )
    out_pt, attn_pt = attention_ref(
        q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size,
        upcast=False, reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g)
    dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g)
    print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
    print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
    print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
    print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
    print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
    print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
    print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
    print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
    print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
    print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
    print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
    print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's output error is at most twice the numerical error of a Pytorch
    # implementation; gradients use a looser multiplier `mult` (larger with ALiBi).
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    mult = 2 if not alibi else 8
    assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
    assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
    assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4


# @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [1])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
# @pytest.mark.parametrize("paged_kv_block_size", [None])
@pytest.mark.parametrize("has_leftpad", [False])
# @pytest.mark.parametrize("has_leftpad", [True])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 128), (1, 339), (3, 1024), (64, 800), (64, 256), (3, 799),
        (64, 2048), (16, 20000), (1, 128 * 1024), (16, 128 * 1024), (128, 128),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
    seqlen_q,
    seqlen_k,
    d,
    has_batch_idx,
    has_leftpad,
    paged_kv_block_size,
    rotary_fraction,
    rotary_interleaved,
    seqlen_new_eq_seqlen_q,
    causal,
    local,
    alibi,
    new_kv,
    mha_type,
    num_splits,
    dtype,
):
    if seqlen_q > seqlen_k and new_kv:
        pytest.skip()
    if not new_kv and rotary_fraction > 0.0:
        pytest.skip()
    if has_batch_idx and paged_kv_block_size is not None:
        pytest.skip()
    if has_leftpad and paged_kv_block_size is not None:
        pytest.skip()
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
    nheads = 6
    # rotary_dim must be a multiple of 16, and must be <= d
    rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
    nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
    assert nheads % nheads_k == 0
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item()
    if new_kv:
        k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
        v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype)
    else:
        k, v = None, None
    if paged_kv_block_size is None:
        k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
        block_table = None
    else:
        (
            k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks,
        ) = _generate_block_kvcache(
            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
        )
    cache_seqlens = torch.randint(
        0 if new_kv else 1,
        # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
        (
            (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1)
            if new_kv
            else (seqlen_k + 1)
        ),
        (batch_size,),
        dtype=torch.int32,
        device=device,
    )
    if has_leftpad:
        cache_leftpad = torch.cat(
            [
                torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
                if cache_seqlens[i].item() > 0
                else torch.zeros(1, dtype=torch.int32, device=device)
                for i in range(batch_size)
            ]
        )
    else:
        cache_leftpad = None
    arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
    cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
    key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
    if has_leftpad:
        key_padding_mask = torch.logical_and(
            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
        )
    if has_batch_idx:
        cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
            :batch_size
        ]
    else:
        cache_batch_idx = None
    if alibi:
        alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3
        attn_bias = attn_bias_from_alibi_slopes(
            alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal,
            key_leftpad=cache_leftpad,
        )
    else:
        alibi_slopes, attn_bias = None, None
    # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
    if rotary_dim > 0:
        angle = (
            torch.rand(
                seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size,
                rotary_dim // 2,
                device=device,
            )
            * 2
            * math.pi
        )
        cos = torch.cos(angle).to(dtype=dtype)
        sin = torch.sin(angle).to(dtype=dtype)
        if causal or local:
            q_ro = apply_rotary_emb(
                q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
            )
        else:
            q_ro = rearrange(
                apply_rotary_emb(
                    rearrange(q, "b s h d -> b 1 (s h) d"),
                    cos,
                    sin,
                    seqlen_offsets=cache_seqlens,
                    interleaved=rotary_interleaved,
                ),
                "b 1 (s h) d -> b s h d",
                s=seqlen_q,
            )
        # q_ro = q
        k_ro = apply_rotary_emb(
            k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
        )
    else:
        cos, sin = None, None
        q_ro, k_ro = q, k
    # k_cache[:, 64:] = -1
    k_cache_ref = (
        k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    v_cache_ref = (
        v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
    ).clone()
    if new_kv:
        update_mask = torch.logical_and(
            cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
        )
        k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
        v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
    k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
    out = flash_attn_with_kvcache(
        q,
        k_cache if paged_kv_block_size is None else k_cache_paged,
        v_cache if paged_kv_block_size is None else v_cache_paged,
        k,
        v,
        rotary_cos=cos,
        rotary_sin=sin,
        cache_seqlens=cache_seqlens,
        cache_batch_idx=cache_batch_idx,
        cache_leftpad=cache_leftpad,
        block_table=block_table,
        causal=causal,
        window_size=window_size,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=alibi_slopes,
        num_splits=num_splits,
    )
    # out = flash_attn_with_kvcache(
    #     q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
    # )
    # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
    # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
    # m = qk.amax(-1, keepdim=True)
    # s_tmp = torch.exp((qk - m) / math.sqrt(d))
    # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
    # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
    # probs = torch.softmax(qk, dim=-1)
    out_ref, _ = attention_ref(
        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, attn_bias, 0.0, None,
        causal=causal, window_size=window_size, key_leftpad=cache_leftpad,
    )
    out_pt, _ = attention_ref(
        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, attn_bias, 0.0, None,
        causal=causal, window_size=window_size, upcast=False, reorder_ops=True,
        key_leftpad=cache_leftpad,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # Check that the KV cache was updated in place, and that FlashAttention's numerical error is
    # at most `mult` times (3x, or 5x with ALiBi) the numerical error of a Pytorch implementation.
    if new_kv:
        if paged_kv_block_size is None:
            k_cache_select = (
                k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
            )
            v_cache_select = (
                v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
            )
        else:
            k_cache_select = rearrange(
                k_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
            v_cache_select = rearrange(
                v_cache_paged[block_table.to(dtype=torch.long).flatten()],
                "(b nblocks) block_size ... -> b (nblocks block_size) ...",
                b=batch_size,
            )[:, :seqlen_k]
        assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
        assert torch.equal(v_cache_select, v_cache_ref)
    mult = 3 if not alibi else 5
    assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5


def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
    k_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    v_cache_paged = torch.randn(
        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
    )
    block_table = rearrange(
        torch.randperm(num_blocks, dtype=torch.int32, device=device),
        "(b nblocks) -> b nblocks",
        b=batch_size,
    )
    k_cache = rearrange(
        # pytorch 1.12 doesn't have indexing with int32
        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    v_cache = rearrange(
        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
        b=batch_size,
    )[:, :seqlen_k]
    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks


# @pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239), (239, 1), (3, 799), (799, 3), (1024, 128), (97, 97), (128, 128),
        (200, 200), (256, 256), (257, 257), (384, 384), (512, 512), (768, 768), (1024, 1024),
    ],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
@pytest.mark.skip()
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
    nheads = 4
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    torch.random.manual_seed(42)
    out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
    g = torch.randn_like(out0)
    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
        dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
        # Numerical error if we just do any arithmetic on dq
        dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item()

    for i in range(250):
        torch.random.manual_seed(42)
        out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
        assert torch.equal(out, out0)
        assert torch.equal(lse, lse0)

        if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_equal = torch.allclose(dq, dq0, atol=dq_atol)
            if not dq_equal:
                print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}")
            assert torch.equal(dv, dv0)
            assert torch.equal(dk, dk0)
            assert dq_equal


@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2])
@pytest.mark.skip()
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
    """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
    in the case where seqlen % 128 != 0.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 5
    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
    k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3
        for _ in range(2)
    ]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)
    out = flash_attn_func(q, k, v, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)
    q_pt = q.detach().clone().requires_grad_(True)
    k_pt = k.detach().clone().requires_grad_(True)
    v_pt = v.detach().clone().requires_grad_(True)
    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
    out_pt.backward(g)
    q_ref = q.detach().clone().requires_grad_(True)
    k_ref = k.detach().clone().requires_grad_(True)
    v_ref = v.detach().clone().requires_grad_(True)
    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
    out_ref.backward(g)
    print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}")
    print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}")
    print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}")
    print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}")
    print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}")
    print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}")
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (
        q_pt.grad - q_ref.grad
    ).abs().max().item() + 1e-3
    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (
        k_pt.grad - k_ref.grad
    ).abs().max().item() + 1e-3
    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (
        v_pt.grad - v_ref.grad
    ).abs().max().item() + 1e-3


@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.skip()
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
    """We previously had a bug where we were using the wrong strides of dout, which shows up
    when dout is not contiguous.
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 5
    nheads = 2
    q, k, v = [
        torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True)
        for _ in range(3)
    ]
    out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ...
-> s b ...") # So g is not contiguous g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2] out.backward(g) q_pt = q.detach().clone().requires_grad_(True) k_pt = k.detach().clone().requires_grad_(True) v_pt = v.detach().clone().requires_grad_(True) out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) out_pt = rearrange(out_pt, "b s ... -> s b ...") out_pt.backward(g) q_ref = q.detach().clone().requires_grad_(True) k_ref = k.detach().clone().requires_grad_(True) v_ref = v.detach().clone().requires_grad_(True) out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) out_ref = rearrange(out_ref, "b s ... -> s b ...") out_ref.backward(g) print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}") print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() assert (q.grad - q_ref.grad).abs().max().item() <= 2 * ( q_pt.grad - q_ref.grad ).abs().max().item() assert (k.grad - k_ref.grad).abs().max().item() <= 2 * ( k_pt.grad - k_ref.grad ).abs().max().item() assert (v.grad - v_ref.grad).abs().max().item() <= 2 * ( v_pt.grad - v_ref.grad ).abs().max().item() @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize("d", [16, 32, 64]) # @pytest.mark.parametrize('d', [16]) @pytest.mark.skip() def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): """We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, in the case where seqlen % 128 != 0 or varlen. 
    """
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    nheads = 5
    q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32)
    k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32)
    Mq = 256
    Mk = 3

    q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3
    k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)]
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)

    out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal)
    g = torch.randn_like(out)
    out.backward(g)

    assert not q.grad.isnan().any()
    assert not k.grad.isnan().any()
    assert not v.grad.isnan().any()


@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
@pytest.mark.skip()
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 4
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)

    g = torch.randn_like(out)
    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
    for _ in range(50):
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert torch.equal(dq, dq0)


@pytest.mark.parametrize("dtype", ([torch.float16] if skip_bfloat16 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [True])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64])
@pytest.mark.parametrize("swap_sq_sk", [False])
# @pytest.mark.parametrize("swap_sq_sk", [True])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 239),
        (3, 799),
        (127, 512),
        (127, 513),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (1023, 1024),
    ],
)
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
@pytest.mark.skip()
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    nheads = 9
    window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
    (
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        q,
        k,
        v,
        output_pad_fn,
        dq_pad_fn,
        dk_pad_fn,
    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
    out = flash_attn_varlen_func(
        q_unpad,
        k_unpad,
        v_unpad,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        0.0,
        causal=causal,
        window_size=window_size,
        deterministic=True,
    )

    g = torch.randn_like(out)
    dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
    for _ in range(50):
        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert torch.equal(dq, dq0)


================================================
FILE: tests/test_rotary.py
================================================
import math
import random

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
import triton
from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
from flash_attn.bert_padding import pad_input, unpad_input

is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)


def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    return cos, sin


def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device):
    if seqlen_offsets_type == 0:
        return 0
    elif seqlen_offsets_type is int:
        return torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device)


def index_cos_sin(cos, sin, seqlen_offsets, seqlen):
    if isinstance(seqlen_offsets, torch.Tensor):
        batch_size = seqlen_offsets.shape[0]
        arange = rearrange(torch.arange(seqlen, device=cos.device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    return cos_pt, sin_pt


@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.bfloat16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    x = torch.randn(
        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    x_pt = x.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb(
        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
    )
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")

    if not inplace:
        assert torch.equal(x, x_pt)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)


@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("compiled", [False, True])
# @pytest.mark.parametrize("compiled", [True])
@pytest.mark.parametrize("gqa", [False, True])
# @pytest.mark.parametrize("gqa", [False])
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, compiled, dtype):
    if compiled:  # Don't fall back to eager just bc of recompilation
        torch._dynamo.config.recompile_limit = 2 ** 31
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 512
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    if not gqa:
        qkv = torch.randn(
            batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
        )
    else:
        nheads_k = nheads // 2
        qkv = torch.randn(
            batch_size,
            seqlen,
            nheads + nheads_k * 2,
            headdim,
            dtype=dtype,
            device=device,
            requires_grad=True,
        )
    qkv_pt = qkv.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    fn = apply_rotary_emb_qkv_ if not compiled else torch.compile(apply_rotary_emb_qkv_)
    out = fn(
        qkv,
        cos,
        sin,
        seqlen_offsets=seqlen_offsets,
        interleaved=interleaved,
        num_heads_q=None if not gqa else nheads,
    )
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    if not gqa:
        q_pt, k_pt, v_pt = qkv_pt.unbind(2)
    else:
        q_pt, k_pt, v_pt = qkv_pt.split([nheads, nheads_k, nheads_k], dim=2)
    q_pt = apply_rotary_emb_torch(
        q_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    k_pt = apply_rotary_emb_torch(
        k_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    if not gqa:
        out_pt = torch.stack([q_pt, k_pt, v_pt], dim=2)
    else:
        out_pt = torch.cat([q_pt, k_pt, v_pt], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")

    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
    assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)


@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 781
    headdim = 64
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    kv = torch.randn(
        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    kv_pt = kv.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out = apply_rotary_emb_kv_(kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved)
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    k_pt = apply_rotary_emb_torch(
        kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")

    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
    assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)


@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize("dtype", ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize("rotary_fraction", [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize("interleaved", [True])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize("inplace", [False])
def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    rotary_dim = int(rotary_fraction * headdim)
    torch.manual_seed(42)
    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
    x_pt = x.detach().clone().requires_grad_()
    lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
    padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
    x_unpad, indices, cu_seqlens, max_seqlen, _ = unpad_input(x, padding_mask)
    x_unpad_clone = x_unpad.clone()
    x_unpad = x_unpad.requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
    out_unpad = apply_rotary_emb(
        x_unpad,
        cos,
        sin,
        seqlen_offsets=seqlen_offsets,
        interleaved=interleaved,
        inplace=inplace,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    out = pad_input(out_unpad, indices, batch_size, seqlen)
    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = out_pt.masked_fill(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen)
    print(f"Grad max diff: {(x_grad - x_pt.grad).abs().max().item()}")

    if not inplace:
        assert torch.equal(x_unpad, x_unpad_clone)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)


def test_compilation_count():
    nheads = 4
    headdim = 128
    device = "cuda"
    dtype = torch.float16
    torch.manual_seed(42)

    from triton.runtime.jit import JITFunction

    from flash_attn.ops.triton.rotary import rotary_kernel

    compilation_count = 0

    def count_compilations(*args, **kwargs):
        nonlocal compilation_count
        compilation_count += 1

    old_cache_func = JITFunction.cache_hook

    try:
        if hasattr(rotary_kernel, "cache"):
            rotary_kernel.cache.clear()
        else:  # Triton 3.3 replaces cache with per-device device_caches
            device = triton.runtime.driver.active.get_current_device()
            # device_caches[device] returns a 4-tuple: (kernel_cache, target, backend, binder)
            rotary_kernel.device_caches[device][0].clear()

        JITFunction.cache_hook = count_compilations

        for seqlen in (128, 256):
            for batch_size in (4, 32):
                x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
                x.requires_grad_()
                cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
                out = apply_rotary_emb(x, cos, sin)
                out.backward(torch.randn_like(out))

        # Only two kernels are expected to be compiled:
        # * for the forward pass (conjugate=False)
        # * for the backward pass (conjugate=True)
        assert compilation_count == 2
    finally:
        JITFunction.cache_hook = old_cache_func

================================================
FILE: tests/test_util.py
================================================
import math

import torch
from einops import rearrange, repeat

from flash_attn.bert_padding import pad_input, unpad_input


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)

    if zero_lengths:
        # Generate zero-lengths every 5 batches and the last batch.
        for i in range(batch_size):
            if i % 5 == 0:
                lengths[i] = 0
        lengths[-1] = 0
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask


def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
    add_unused_qkv=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q,
            query_padding_mask,
            query_unused_mask,
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, key_padding_mask, key_unused_mask)
        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
    key_leftpad=None,
):
    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        return col_idx > row_idx + sk - sq + window_size[1]
    else:
        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
        return torch.logical_or(
            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
            col_idx < row_idx + sk - sq - window_size[0],
        )


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax before dropout
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if softcap > 0:
        scores /= softcap
        scores = scores.tanh()
        scores *= softcap
    if key_padding_mask is not None:
        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    if key_padding_mask is not None:
        output.masked_fill_(rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1"), 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


================================================
FILE: tools/sass_diff.py
================================================
#!/usr/bin/env python3
"""Compare two SASS files, ignoring register assignments and addresses.

Normalizes registers per-instruction so that two instructions doing the same
operation with different register allocations compare as equal. E.g.
"UIADD3 UR30, UP1, UR30, 0x70, URZ" and "UIADD3 UR14, UP1, UR38, 0x70, URZ" both normalize to "UIADD3 UR_0, UP_0, UR_1, 0x70, URZ" Usage: python scripts/sass_diff.py file_a.sass file_b.sass python scripts/sass_diff.py file_a.sass file_b.sass --context 5 python scripts/sass_diff.py file_a.sass file_b.sass --all # include metadata python scripts/sass_diff.py file_a.sass file_b.sass --summary-only """ import argparse import re import sys from dataclasses import dataclass, field from difflib import SequenceMatcher # ── Parsing ────────────────────────────────────────────────────────────────── ADDR_LINE_RE = re.compile(r"^\s+/\*([0-9a-f]+)\*/\s+(.*?)\s*;?\s*$") LABEL_RE = re.compile(r"^(\.L_x_\d+):\s*$") METADATA_PREFIXES = (".byte", ".word", ".short", ".dword", ".string", ".align") # Register pattern: match UR before R, UP before P REG_RE = re.compile(r"\b(UP|UR|P|R)(\d+)\b") @dataclass class Line: """One parsed SASS line.""" addr: str # hex address or "" for labels raw: str # original text (no addr prefix) normalized: str # register-normalized for comparison lineno: int # 1-based line number in file is_code: bool # True for instructions/labels def _normalize_instr(text: str) -> str: """Normalize one instruction by replacing registers with positional IDs. Each register class (R, UR, P, UP) gets its own counter, reset per instruction. Constants RZ, URZ, PT, UPT are preserved. 
""" counters: dict[str, int] = {} mapping: dict[str, str] = {} def repl(m: re.Match) -> str: name = m.group(0) if name in ("RZ", "URZ", "PT", "UPT"): return name if name in mapping: return mapping[name] prefix = m.group(1) idx = counters.get(prefix, 0) counters[prefix] = idx + 1 mapping[name] = f"{prefix}_{idx}" return mapping[name] return REG_RE.sub(repl, text) def parse_sass(path: str) -> list[Line]: """Extract instruction, label, and metadata lines from a SASS file.""" lines: list[Line] = [] with open(path) as f: for lineno, raw in enumerate(f, 1): raw = raw.rstrip() m = LABEL_RE.match(raw) if m: label = m.group(1) lines.append(Line("", label, label, lineno, True)) continue m = ADDR_LINE_RE.match(raw) if m: addr, text = m.group(1), m.group(2).strip() is_meta = any(text.startswith(p) for p in METADATA_PREFIXES) normalized = text if is_meta else _normalize_instr(text) lines.append(Line(addr, text, normalized, lineno, not is_meta)) return lines # ── Diffing ────────────────────────────────────────────────────────────────── @dataclass class DiffBlock: tag: str # "equal", "replace", "insert", "delete" a_lines: list[Line] = field(default_factory=list) b_lines: list[Line] = field(default_factory=list) def diff_sass(a_lines: list[Line], b_lines: list[Line]) -> list[DiffBlock]: a_norm = [l.normalized for l in a_lines] b_norm = [l.normalized for l in b_lines] sm = SequenceMatcher(None, a_norm, b_norm, autojunk=False) blocks: list[DiffBlock] = [] for tag, i1, i2, j1, j2 in sm.get_opcodes(): blocks.append(DiffBlock(tag, a_lines[i1:i2], b_lines[j1:j2])) return blocks # ── Display ────────────────────────────────────────────────────────────────── RED = "\033[31m" GREEN = "\033[32m" CYAN = "\033[36m" DIM = "\033[2m" RESET = "\033[0m" def _fmt(line: Line, prefix: str, color: str, use_color: bool, show_norm: bool) -> str: addr = f"[{line.addr}]" if line.addr else " " text = line.normalized if show_norm else line.raw if use_color: return f"{color}{prefix} {addr:>8s} 
{text}{RESET}" return f"{prefix} {addr:>8s} {text}" def print_diff(blocks: list[DiffBlock], context: int = 3, use_color: bool = True, show_norm: bool = False): """Unified-diff-style output with context.""" groups: list[list[str]] = [] cur: list[str] = [] last_changed = False for block in blocks: if block.tag == "equal": lines = block.a_lines if last_changed: for l in lines[:context]: cur.append(_fmt(l, " ", DIM, use_color, show_norm)) if len(lines) > 2 * context: if cur: groups.append(cur) cur = [] for l in lines[-context:]: cur.append(_fmt(l, " ", DIM, use_color, show_norm)) elif len(lines) > context: for l in lines[context:]: cur.append(_fmt(l, " ", DIM, use_color, show_norm)) else: for l in lines[-context:]: cur.append(_fmt(l, " ", DIM, use_color, show_norm)) last_changed = False else: last_changed = True if block.tag in ("replace", "delete"): for l in block.a_lines: cur.append(_fmt(l, "-", RED, use_color, show_norm)) if block.tag in ("replace", "insert"): for l in block.b_lines: cur.append(_fmt(l, "+", GREEN, use_color, show_norm)) if cur: groups.append(cur) sep = f"{CYAN}{'─' * 72}{RESET}" if use_color else "─" * 72 for i, g in enumerate(groups): if i > 0: print(sep) for line in g: print(line) def _get_opcode(raw: str) -> str | None: """Extract opcode from instruction, skipping predicates and labels.""" for p in raw.split(): if p.startswith("@") or p.startswith(".L_"): continue return p return None def print_summary(a_all: list[Line], b_all: list[Line], blocks: list[DiffBlock]): a_code = [l for l in a_all if l.is_code] b_code = [l for l in b_all if l.is_code] n_equal = sum(len(b.a_lines) for b in blocks if b.tag == "equal") n_delete = sum(len(b.a_lines) for b in blocks if b.tag in ("replace", "delete")) n_insert = sum(len(b.b_lines) for b in blocks if b.tag in ("replace", "insert")) n_changed = sum(1 for b in blocks if b.tag != "equal") print(f" File A: {len(a_code)} instructions") print(f" File B: {len(b_code)} instructions") print(f" Identical (normalized): 
{n_equal}") print(f" Changed regions: {n_changed}") print(f" Removed: {n_delete}, Added: {n_insert}") def opcode_counts(lines): counts: dict[str, int] = {} for l in lines: op = _get_opcode(l.raw) if op: counts[op] = counts.get(op, 0) + 1 return counts a_ops, b_ops = opcode_counts(a_code), opcode_counts(b_code) all_ops = sorted(set(a_ops) | set(b_ops)) diffs = {op: b_ops.get(op, 0) - a_ops.get(op, 0) for op in all_ops} diffs = {op: d for op, d in diffs.items() if d != 0} if diffs: print("\n Opcode count changes (B - A):") for op, d in sorted(diffs.items(), key=lambda x: -abs(x[1])): sign = "+" if d > 0 else "" print(f" {op:30s} {sign}{d}") else: print("\n Opcode counts: identical") # ── Main ───────────────────────────────────────────────────────────────────── def main(): p = argparse.ArgumentParser(description="Compare SASS files ignoring register assignments") p.add_argument("file_a", help="First SASS file") p.add_argument("file_b", help="Second SASS file") p.add_argument("-C", "--context", type=int, default=3, help="Context lines (default: 3)") p.add_argument("--no-color", action="store_true", help="Disable color output") p.add_argument("--summary-only", action="store_true", help="Only print summary") p.add_argument("--all", action="store_true", help="Include metadata in diff") p.add_argument("--show-normalized", action="store_true", help="Show normalized form instead of raw instructions") args = p.parse_args() a_all = parse_sass(args.file_a) b_all = parse_sass(args.file_b) if args.all: a_lines, b_lines = a_all, b_all else: a_lines = [l for l in a_all if l.is_code] b_lines = [l for l in b_all if l.is_code] blocks = diff_sass(a_lines, b_lines) use_color = not args.no_color and sys.stdout.isatty() print("=== Summary ===") print_summary(a_all, b_all, blocks) if not args.summary_only: print("\n=== Diff (registers normalized) ===\n") print_diff(blocks, args.context, use_color, args.show_normalized) if __name__ == "__main__": main() 

================================================
FILE: training/Dockerfile
================================================
# Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile
# ARG COMPAT=0
ARG PERSONAL=0
# FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0
FROM nvcr.io/nvidia/pytorch:22.12-py3 as base

ENV HOST docker
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
# https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes
ENV TZ America/Los_Angeles
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

# git for installing dependencies
# tzdata to set time zone
# wget and unzip to download data
# [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment.
# [2021-12-07] TD: openmpi-bin for MPI (multi-node training)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    curl \
    ca-certificates \
    sudo \
    less \
    htop \
    git \
    tzdata \
    wget \
    tmux \
    zip \
    unzip \
    zsh stow subversion fasd \
    && rm -rf /var/lib/apt/lists/*
    # openmpi-bin \

# Allow running runmpi as root
# ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1

# # Create a non-root user and switch to it
# RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
#     && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
# USER user

# All users can use /home/user as their home directory
ENV HOME=/home/user
RUN mkdir -p /home/user && chmod 777 /home/user
WORKDIR /home/user

# Set up personal environment
# FROM base-${COMPAT} as env-0
FROM base as env-0
FROM env-0 as env-1
# Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image
# https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile
ONBUILD COPY dotfiles ./dotfiles
ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami)
# nvcr pytorch image sets SHELL=/bin/bash
ONBUILD ENV SHELL=/bin/zsh

FROM env-${PERSONAL} as packages

# Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
ENV PIP_NO_CACHE_DIR=1

# # apex and pytorch-fast-transformers take a while to compile so we install them first
# TD [2022-04-28] apex is already installed. In case we need a newer commit:
# RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex

# xgboost conflicts with deepspeed
RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.7

# General packages that we don't care about the version
# zstandard to extract the_pile dataset
# psutil to get the number of cpu physical cores
# twine to upload package to PyPI
RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine gdown \
    && python -m spacy download en_core_web_sm
# hydra
RUN pip install hydra-core==1.3.1 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
# Core packages
RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 triton==2.0.0.dev20221202 wandb==0.13.7 timm==0.6.12 torchmetrics==0.10.3
# torchmetrics 0.11.0 broke hydra's instantiate

# For MLPerf
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0

# Install FlashAttention
RUN pip install flash-attn==2.6.3

# Install CUDA extensions for fused dense
RUN pip install git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib

================================================
FILE: training/README.md
================================================
# Optimized Transformer implementation

This repo contains examples of how FlashAttention can be integrated
into a model (e.g., GPT, ViT) and trained end-to-end. We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x compared to the baseline implementation from Huggingface, reaching up to 189 TFLOPs/sec per A100, equivalent to 60.6% model FLOPs utilization (we don't need any activation checkpointing). All of this is done without changing the model architecture (i.e., no approximation).

Goals:
- Performance: we optimize for model speed and memory, especially on a single node (e.g., with 8 A100s).
- Flexibility: we provide optimized building blocks (MLP, attention, LayerNorm), and the model code illustrates how these components can be put together. The training code also aims to be model- & task-agnostic.

Non-goals (and other resources):
- Support as many models as possible: Huggingface's [transformers](https://github.com/huggingface/transformers) and [timm](https://github.com/rwightman/pytorch-image-models/) are great for this.
- Large-scale distributed training: our codebase has been used for multi-GPU and multi-node training for models up to 2.7B parameters. However, if you're looking for large-scale distributed training techniques (e.g., pipeline parallelism, tensor parallelism), check out [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/) and [DeepSpeed](https://github.com/microsoft/deepspeed).
- Inference: we currently focus on training (this might change in the future). If you want fast inference, take a look at [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
- Production: this codebase was written during several research projects to validate ideas on speeding up ML models.

## Model Components

The GPT model is implemented [here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py). And here's an example to construct the GPT3-1.3B model with rotary embedding:
```python
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

seqlen = 2048
hidden_dim = 2048
nheads = 16
n_layer = 24
rotary_emb_fraction = 0.5
config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
                    n_layer=n_layer, n_head=nheads,
                    scale_attn_by_inverse_layer_idx=True,
                    rotary_emb_fraction=rotary_emb_fraction,
                    use_flash_attn=True, fused_mlp=True,
                    fused_bias_fc=True, fused_dropout_add_ln=True,
                    pad_vocab_size_multiple=8)
model = GPTLMHeadModel(config)
```

We provide the following optimized components:

1. FlashAttention: fast and memory-efficient exact attention. This makes attention much faster and saves a lot of activation memory. As a result we don't need to use any activation checkpointing.
```sh
pip install flash-attn
```

2. Fused matmul + bias (forward and backward), and fused matmul + bias + gelu (forward and backward), adapted from Apex's [FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We make it work for bfloat16. For best performance, you should use CUDA >= 11.8. cuBLAS versions before this don't have the best matmul + bias + gelu performance for bfloat16.
```sh
cd ../csrc/fused_dense_lib && pip install .
```

3. Optimized cross-entropy loss, adapted from Apex's [Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.
```sh
cd ../csrc/xentropy && pip install .
```

4. Fused rotary embedding:
```sh
cd ../csrc/rotary && pip install .
```

5. Fused dropout + residual + LayerNorm, adapted from Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architectures. This supports dimensions divisible by 8, up to 6144.
```sh
cd ../csrc/layer_norm && pip install .
```

## Training

We also provide training scripts to train GPT2 on Openwebtext and GPT3 on The Pile as examples. Feel free to use the model in your own training setup as well.

We use [Hydra](https://hydra.cc/) for configuration, [PyTorch Lightning](https://github.com/Lightning-AI/lightning) for training, and [Wandb](https://wandb.ai/) for logging.

We use the template from `https://github.com/ashleve/lightning-hydra-template`. Please read the instructions there to understand the repo structure.

### Requirements

Python 3.8+, PyTorch 1.12+, torchvision, einops, timm, hydra-core, hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn. We recommend CUDA 11.8 (e.g., using Nvidia's PyTorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).

We provide a Dockerfile that lists all the required packages.

### Dataset preparation

Running the training command automatically downloads the datasets (Openwebtext, Pile), tokenizes them with the GPT2 tokenizer, concatenates all the tokens, then saves this cache to disk. Alternatively, you can prepare the datasets as a separate step.

The cached datasets are saved to `${DATA_DIR}/openwebtext` and `${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to `./data/{openwebtext,the_pile}`.

- Openwebtext:
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "openwebtext"
```
This takes around 1h on a 64-core CPU. The processed dataset has size 17GB.

- The Pile:
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
```
This takes around 20h on a 64-core CPU. The processed dataset has size 699GB.
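
The preprocessing described above (tokenize, concatenate all tokens, cache) can be pictured with a small sketch. It assumes each document has already been turned into token IDs, appends an end-of-sequence ID after every document (mirroring `add_eos: True` in the datamodule configs), and then cuts the concatenated stream into training sequences of `max_length` tokens. The helper `concat_and_chunk` is hypothetical and not part of this repo; the real `LMDataModule` may handle details such as the final partial chunk differently.

```python
from typing import List


def concat_and_chunk(docs: List[List[int]], eos_id: int, max_length: int) -> List[List[int]]:
    """Concatenate tokenized documents, each followed by EOS, and split the
    resulting stream into consecutive sequences of max_length tokens.

    Hypothetical helper for illustration only; the last chunk may be shorter.
    """
    stream: List[int] = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)  # assumption: one EOS after every document
    return [stream[i : i + max_length] for i in range(0, len(stream), max_length)]


# Toy example: three "documents" already converted to token IDs.
docs = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
chunks = concat_and_chunk(docs, eos_id=0, max_length=4)
print(chunks)  # [[11, 12, 13, 0], [21, 22, 0, 31], [32, 33, 34, 0]]
```

Because documents are packed back to back, a training sequence can span a document boundary (the second chunk above), which keeps every position of every sequence filled with real tokens.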
### GPT2 training on Openwebtext

To train GPT2 on Openwebtext with 8 GPUs:
```sh
python run.py experiment=owt/gpt2s-flash trainer.devices=8  # 125M
python run.py experiment=owt/gpt2m-flash trainer.devices=8  # 355M
python run.py experiment=owt/gpt2l-flash trainer.devices=8  # 760M
python run.py experiment=owt/gpt2xl-flash trainer.devices=8  # 1.6B
```
The default parameters are set for 8 x A100 80GB.

To train with bf16 instead of fp16, add `trainer.precision=bf16`.

### GPT3 training on The Pile

To train GPT3 on The Pile with 8 GPUs:
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8  # 125M
python run.py experiment=pile/gpt3m-flash trainer.devices=8  # 355M
python run.py experiment=pile/gpt3l-flash trainer.devices=8  # 760M
python run.py experiment=pile/gpt3xl-flash trainer.devices=8  # 1.3B
python run.py experiment=pile/gpt3-2.7B-flash-hdim128 trainer.devices=8  # 2.7B
```
The default parameters are set for 8 x A100 80GB. We train with bf16 by default.

To train with rotary embedding, run the experiments `pile/gpt3{s,m,l,xl}-flash-rotary`.

### Training options

**Gradient accumulation**: to adjust the per-device batch size to fit into GPU memory (the global batch size stays the same, and gradient accumulation is calculated automatically), set `datamodule.batch_size=blah`.

**Multi-node**: to train on multiple nodes, add `trainer.num_nodes=blah`.

**Speed benchmarking**: to print out iteration time, add `+callbacks.speed_monitor.verbose=True`.

**Resumable training**: give the run a name, and then set `resume=True` when you resume. Training will restart at exactly the same batch.
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8 name=pile-gpt3s-flash resume=True
```

## Training speed

We measure the wall-clock training speed on one node with 8 x A100 SXM4 80GB (400W) with NVLink.

FLOPs are calculated using the formula from the [Megatron-LM paper](https://arxiv.org/abs/2104.04473) (Section 5.1), except we scale by 3/4 to get the model FLOPs (instead of hardware FLOPs with activation checkpointing).

### GPT2 (sequence length 1024)

![GPT2 speedup](../assets/gpt2_training_efficiency.jpg)

The implementation in this repo (FlashAttention) is 3-4x faster than the baseline implementation from Huggingface.

### GPT3 (sequence length 2048)

![GPT3 speedup](../assets/gpt3_training_efficiency.jpg)

The implementation in this repo (FlashAttention) is 3-5x faster than the baseline implementation from Huggingface.

For the GPT3-2.7B model, we set head dimension to 128 (instead of 80) for better efficiency.

We include here more details on the training speed with FlashAttention on 8 x A100 80GB.

| Model     | Batch size (tokens) | Throughput (tokens/sec) | Hours / 1B tokens |
| --------- | ------------------- | ----------------------- | ----------------- |
| GPT3-125M | 0.5M                | 1310k                   | 0.21              |
| GPT3-355M | 0.5M                | 503k                    | 0.55              |
| GPT3-760M | 0.5M                | 245k                    | 1.13              |
| GPT3-1.3B | 1M                  | 169k                    | 1.64              |
| GPT3-2.7B | 1M                  | 85k                     | 3.27              |

As an example, this means that one can train a GPT3-1.3B model on 26B tokens (compute-optimal according to Chinchilla scaling) in about 43 hours on 8 x A100.

## Training quality

We include here the loss curve for GPT2 on Openwebtext, trained for 200B tokens. For GPT2, the runs with FlashAttention yield the same loss curve as the runs with the baseline implementation from Huggingface for 125M and 355M models. For larger models the baseline implementation just takes too long.

![GPT2 training curve](../assets/gpt2_training_curve.jpg)

We include here the loss curve for GPT3 on The Pile, trained for 400B tokens. The 125M, 355M, 760M models have batch size 512k tokens so this translates to 800k training steps, while the 1.3B and 2.7B models have batch size 1M tokens, which translates to 400k training steps.
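
The Hours / 1B tokens column, the 43-hour estimate, and the step counts in the paragraph above all follow from simple arithmetic. The sketch below reproduces them, taking the per-step batch sizes as 0.5M and 1M tokens as listed in the training speed table. The function names are illustrative only and do not exist in this repo.

```python
def hours_per_billion_tokens(tokens_per_sec: float) -> float:
    """Wall-clock hours needed to process 1B tokens at a given throughput."""
    return 1e9 / tokens_per_sec / 3600


def num_training_steps(total_tokens: float, tokens_per_step: float) -> int:
    """Optimizer steps needed to consume total_tokens at a fixed global batch size."""
    return round(total_tokens / tokens_per_step)


# Throughput column of the training speed table (tokens/sec on 8 x A100 80GB).
throughput = {
    "GPT3-125M": 1310e3,
    "GPT3-355M": 503e3,
    "GPT3-760M": 245e3,
    "GPT3-1.3B": 169e3,
    "GPT3-2.7B": 85e3,
}
for model, tps in throughput.items():
    print(f"{model}: {hours_per_billion_tokens(tps):.2f} h / 1B tokens")

# GPT3-1.3B on 26B tokens: 26 * 1.64 h, i.e. roughly 43 hours.
print(f"{26 * hours_per_billion_tokens(throughput['GPT3-1.3B']):.1f} h")

# 400B tokens with 0.5M-token and 1M-token global batches.
print(num_training_steps(400e9, 0.5e6), num_training_steps(400e9, 1e6))  # 800000 400000
```

Each printed value rounds to the corresponding table entry, which is a quick way to sanity-check new rows if the table is extended to other model sizes.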
![GPT3 training curve](../assets/gpt3_training_curve.jpg)

================================================
FILE: training/configs/callbacks/causality-monitor.yaml
================================================
causality-monitor:
  _target_: src.callbacks.causality_monitor.CausalityMonitor

================================================
FILE: training/configs/callbacks/default.yaml
================================================
# rich_progress_bar:
#   _target_: pytorch_lightning.callbacks.RichProgressBar

rich_model_summary:
  _target_: pytorch_lightning.callbacks.RichModelSummary

model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: "val/acc" # name of the logged metric which determines when model is improving
  mode: "max" # can be "max" or "min"
  save_top_k: 1 # save k best models (determined by above metric)
  save_last: True # additionally always save model from last epoch
  verbose: False
  dirpath: ${oc.env:CHECKPOINT_DIR,checkpoints}/${oc.select:name,''}
  filename: "epoch_{epoch:03d}"
  auto_insert_metric_name: False

early_stopping:
  _target_: pytorch_lightning.callbacks.EarlyStopping
  monitor: "val/acc" # name of the logged metric which determines when model is improving
  mode: "max" # can be "max" or "min"
  patience: 100 # how many epochs of not improving until training stops
  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement

learning_rate_monitor:
  _target_: pytorch_lightning.callbacks.LearningRateMonitor
  logging_interval: step

speed_monitor:
  _target_: src.callbacks.speed_monitor.SpeedMonitor
  intra_step_time: True
  inter_step_time: True
  epoch_time: True

loss_scale_monitor:
  _target_: src.callbacks.loss_scale_monitor.LossScaleMonitor

params_log:
  _target_: src.callbacks.params_log.ParamsLog
  total_params_log: True
  trainable_params_log: True
  non_trainable_params_log: True

gpu_affinity:
  _target_: src.callbacks.gpu_affinity.GpuAffinity

================================================
FILE: training/configs/callbacks/ema.yaml
================================================
ema:
  _target_: src.callbacks.ema.EMACallback
  decay: ???
  use_num_updates: False

================================================
FILE: training/configs/callbacks/flop-count.yaml
================================================
flop_count:
  _target_: src.callbacks.flop_count.FlopCount
  profilers: ['fvcore']
  input_size: [3, 224, 224]
  device: null

================================================
FILE: training/configs/callbacks/gpu-monitor.yaml
================================================
defaults:
  - default.yaml

gpu_stats_monitor:
  _target_: pytorch_lightning.callbacks.GPUStatsMonitor
  # [2021-08-13] TD: I just want the intra_step_size but it'll error if I
  # don't have memory_utilization and gpu_utilization.
  # Maybe I should write a callback with just the intra_step_size.
  memory_utilization: True
  gpu_utilization: True
  intra_step_time: True

================================================
FILE: training/configs/callbacks/model-summary.yaml
================================================
model_summary:
  _target_: pytorch_lightning.callbacks.RichModelSummary

================================================
FILE: training/configs/callbacks/none.yaml
================================================

================================================
FILE: training/configs/callbacks/norm-monitor.yaml
================================================
norm_monitor:
  _target_: src.callbacks.norm_monitor.NormMonitor

================================================
FILE: training/configs/callbacks/params-log.yaml
================================================
params_log:
  _target_: src.callbacks.params_log.ParamsLog
  total_params_log: True
  trainable_params_log: True
  non_trainable_params_log: True

================================================
FILE: training/configs/callbacks/wandb.yaml
================================================
defaults:
  - default.yaml

watch_model:
  _target_: src.callbacks.wandb_callbacks.WatchModel
  log: "all"
  log_freq: 100

upload_code_as_artifact:
  _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
  code_dir: ${work_dir}/src

upload_ckpts_as_artifact:
  _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
  ckpt_dir: "checkpoints/"
  upload_best_only: True

log_f1_precision_recall_heatmap:
  _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap

log_confusion_matrix:
  _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix

log_image_predictions:
  _target_: src.callbacks.wandb_callbacks.LogImagePredictions
  num_samples: 8

================================================
FILE: training/configs/config.yaml
================================================
# @package _global_

# specify here default training configuration
defaults:
  - _self_
  - trainer: default
  - optimizer: adamw
  - scheduler: null
  - task: sequence-model
  - model: null
  - datamodule: null
  - callbacks: default # set this to null if you don't want to use callbacks
  - metrics: null
  - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`)

  - mode: default

  - experiment: null
  - hparams_search: null

  # enable color logging
  - override hydra/hydra_logging: colorlog
  - override hydra/job_logging: colorlog

# path to original working directory
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have this path as a special variable
# https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
work_dir: ${hydra:runtime.cwd}

# path to folder with data
data_dir: ${work_dir}/data/

# pretty print config at the start of the run using Rich library
print_config: True

# disable python warnings if they annoy you
ignore_warnings: True

# check performance on test set, using the best model achieved during training
# lightning chooses best model based on metric specified in checkpoint callback
test_after_training: True

resume: False

# seed for random number generators in pytorch, numpy and python.random
seed: null

# name of the run, accessed by loggers
name: null

================================================
FILE: training/configs/datamodule/openwebtext.yaml
================================================
_target_: src.datamodules.language_modeling_hf.LMDataModule
dataset_name: openwebtext
dataset_config_name: null
tokenizer_name: gpt2
cache_dir: ${oc.env:DATA_DIR,${data_dir}}/openwebtext/cache
max_length: 1024
val_ratio: 0.0005
val_split_seed: 2357
add_eos: True
batch_size: 8 # per GPU
batch_size_eval: ${eval:${.batch_size} * 2}
num_workers: 32 # For preprocessing only
shuffle: True
pin_memory: True
__train_len: ${div_up:9035582198, ${.max_length}}

================================================
FILE: training/configs/datamodule/thepile.yaml
================================================
_target_: src.datamodules.language_modeling_hf.LMDataModule
dataset_name: the_pile
dataset_config_name: null
tokenizer_name: gpt2
cache_dir: ${oc.env:DATA_DIR,${data_dir}}/the_pile/cache
max_length: 2048
add_eos: True
batch_size: 4 # per GPU
batch_size_eval: ${eval:${.batch_size} * 2}
num_workers: 64 # For preprocessing only
use_shmem: False
shuffle: True
pin_memory: True
__train_len: ${div_up:374337375694, ${.max_length}}

================================================
FILE: training/configs/experiment/owt/base.yaml
================================================
# @package _global_
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: null
  - override /datamodule: openwebtext
  # FusedAdam from apex speeds up the optimizer step a bit: for GPT2-small, time
  # per global step (i.e. batch size 512) on 8 A100s goes from 376ms to 368ms.
  # For GPT2-medium, time per global step goes from 997ms to 972ms.
  - override /optimizer: adamw-apex
  - override /scheduler: linear-warmup
  - override /callbacks: [default, norm-monitor]
  - override /metrics: [perplexity, num-tokens]
  - override /logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

task:
  _target_: src.tasks.seq.SequenceLMModel

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
  max_steps: 400000
  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}
  check_val_every_n_epoch: null # We don't care about epoch boundary
  precision: 16
  gradient_clip_val: 1.0
  strategy: null

datamodule:
  batch_size: 16 # Per GPU
  batch_size_eval: ${.batch_size} # Fused dense only supports batch size at most 64k
  max_length: 1024
  fault_tolerant: True
  ddp: ${eval:"${trainer.devices} > 1"}

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  global_batch_size: 512
  optimizer:
    lr: 6e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
  scheduler:
    num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}}
    num_training_steps: ${trainer.max_steps}
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable if we're using DeepSpeed 16 bits.
    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: True # to save memory

eval:
  log_on_step: True # 1 training epoch takes too long, we want to see metrics per train step

callbacks:
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: True
    every_n_train_steps: 1000
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
    filename: step_{step}
    auto_insert_metric_name: False
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    fault_tolerant: True
    every_n_train_steps: 50000
    save_last: False
    save_top_k: -1 # Save all the checkpoints
    dirpath: ${..model_checkpoint.dirpath}
    filename: progress_step_{step}
    auto_insert_metric_name: False
  early_stopping: null

================================================
FILE: training/configs/experiment/owt/gpt2l-flash.yaml
================================================
# @package _global_
defaults:
  - /experiment/owt/gpt2m-flash.yaml
  - override /model/gpt2model: gpt2-large
  # TD [2022-08-03] Surprisingly it's faster to use the ZeRO optimizer than just AdamW.
  # Still, fairscale is even faster and uses less memory.
  # I think it's because Pytorch is using ZeRO stage 1 and fairscale is using ZeRO stage 2?
  # However, fairscale has issues with saving checkpoint (either OOM or very
  # slow since it goes through the CPU?). Fairscale says Pytorch ZeRO is the
  # upstream version of OSS
  # https://github.com/facebookresearch/fairscale/issues/937
  # Pytorch ZeRO is also very slow for saving checkpoints due to
  # consolidate_state_dict(), but I've fixed it to save separate checkpoint per GPU.
  - override /optimizer: adamw-zero
  # FusedAdam doesn't seem to speed things up here, time per global step
  # (i.e. batch size 512) on 8 A100s is around 2056ms for both AdamW and FusedAdam.
  # This could be because each GPU is only doing the optimizer step for 1 /
  # world_size of the parameters.
  # Maybe the bottleneck here is the NCCL call to exchange parameters (ZeRO).
  # - override /optimizer: adamw-apex-zero

# Can enable mlp_checkpoint_lvl to fit batch_size 16 on A100 40GB
# model:
#   config:
#     # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"}
#     mlp_checkpoint_lvl: 1

datamodule:
  # batch_size: 16
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"}

trainer:
  # strategy: null
  # strategy: ${eval:"None if ${trainer.devices} == 1 else 'ddp_sharded'"}
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: False
    gradient_as_bucket_view: True
  # TD [2022-08-03] Deepspeed makes the ppl curve go wild
  # strategy: deepspeed_stage_1

================================================
FILE: training/configs/experiment/owt/gpt2l-hf.yaml
================================================
# @package _global_
defaults:
  - /experiment/owt/gpt2m-hf.yaml
  - override /model/gpt2model: gpt2-large
  - override /optimizer: adamw-zero

datamodule:
  batch_size: 2

trainer:
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: False
    gradient_as_bucket_view: True

================================================
FILE: training/configs/experiment/owt/gpt2l.yaml
================================================
# @package _global_
defaults:
  - /experiment/owt/gpt2m.yaml
  - override /model/gpt2model: gpt2-large
  - override /optimizer: adamw-zero

datamodule:
  batch_size: 4 # Per GPU

trainer:
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: False
    gradient_as_bucket_view: True

================================================
FILE: training/configs/experiment/owt/gpt2m-flash.yaml
================================================
# @package _global_
defaults:
  - /experiment/owt/gpt2s-flash.yaml
  - override /model/gpt2model: gpt2-medium

# Can enable mlp_checkpoint_lvl to fit batch_size 32 on A100 40GB
# model:
#   config:
#     mlp_checkpoint_lvl: 1

datamodule:
  # batch_size: 32
  batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else (32 if ${train.gpu_mem} < 80 else 64))"}

train:
  optimizer:
    lr: 1.5e-4

================================================
FILE: training/configs/experiment/owt/gpt2m-hf.yaml
================================================
# @package _global_
defaults:
  - /experiment/owt/gpt2s-hf.yaml
  - override /model/gpt2model: gpt2-medium

datamodule:
  batch_size: 4

train:
  optimizer:
    lr: 1.5e-4

================================================
FILE: training/configs/experiment/owt/gpt2m.yaml
================================================
# @package _global_
defaults:
  - /experiment/owt/gpt2s.yaml
  - override /model/gpt2model: gpt2-medium

datamodule:
  batch_size: 8 # Per GPU

train:
  optimizer:
    lr: 1.5e-4

================================================
FILE: training/configs/experiment/owt/gpt2s-flash.yaml
================================================
# @package _global_
defaults:
  - /experiment/owt/base.yaml
  - override /model: gpt2
  - override /model/gpt2model: gpt2-small

model:
  config:
    # n_positions is already set to ${datamodule.max_length}
    residual_in_fp32: True
    use_flash_attn: True
    fused_bias_fc: True
    fused_mlp: True
    fused_dropout_add_ln: True
    pad_vocab_size_multiple: 8

datamodule:
  # batch_size: 64
  batch_size: ${eval:"16 if ${train.gpu_mem} < 24 else (32 if ${train.gpu_mem} < 40 else 64)"}

================================================
FILE: training/configs/experiment/owt/gpt2s-hf.yaml
================================================
# @package _global_
defaults:
  - /experiment/owt/base.yaml
  - override /model: gpt2-hf
  - override /model/gpt2model: gpt2-small
  - override /callbacks: [default, norm-monitor, flop-count]

datamodule:
  batch_size: 8

train:
  # Use the standard torch.nn.CrossEntropyLoss
  loss_fn: null

callbacks:
  flop_count:
    input_size:
      - ${datamodule.max_length}
    input_dtype:
      # It's surprisingly hard to get hydra to return torch.long since it's not a callable
      _target_: torch.__getattribute__
      _args_:
        - long

================================================
FILE: training/configs/experiment/owt/gpt2s.yaml
================================================
# @package _global_
defaults:
  - /experiment/owt/base.yaml
  - override /model: gpt2
  - override /model/gpt2model: gpt2-small

datamodule:
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}

================================================
FILE: training/configs/experiment/owt/gpt2xl-flash.yaml
================================================
# @package _global_
defaults:
  - /experiment/owt/gpt2l-flash.yaml
  - override /model/gpt2model: gpt2-xlarge

# Can enable mlp_checkpoint_lvl to fit on A100 40GB
# model:
#   config:
#     # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"}
#     mlp_checkpoint_lvl: 1

datamodule:
  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"}

# With adamw-zero optimizer, on A100 40GB:
# checkpoint_lvl=1, batch size = 4: mem 37GB, 4650ms / batch of 512 (285ms * 15 + 375ms * 1)
# checkpoint_lvl=1, batch size = 8: mem 46GB, 4330ms / batch of 512 (530ms * 7 + 620ms * 1)
# checkpoint_lvl=2, batch size = 8: mem 41GB, 4570ms / batch of 512 (560ms * 7 + 650ms * 1)
# With adamw-apex-distributed optimizer:
# checkpoint_lvl=1, batch size = 8: mem 41.5GB, 4500ms / batch of 512 (550ms * 7 + 650ms * 1)
# checkpoint_lvl=1 for 24 layers and checkpoint_lvl=2 for 24 layers,
# batch size = 8: mem 39GB, 4640ms / batch of 512 (565ms * 7 + 675ms * 1)

================================================
FILE: training/configs/experiment/owt/gpt2xl-hf.yaml
================================================
================================================ # @package _global_ defaults: - /experiment/owt/gpt2l-hf.yaml - override /model/gpt2model: gpt2-xlarge datamodule: batch_size: 1 ================================================ FILE: training/configs/experiment/owt/gpt2xl.yaml ================================================ # @package _global_ defaults: - /experiment/owt/gpt2m.yaml - override /model/gpt2model: gpt2-xlarge - override /optimizer: adamw-zero datamodule: batch_size: 2 # Per GPU trainer: strategy: _target_: src.utils.ddp_zero1.DDPStrategyZero1 find_unused_parameters: False gradient_as_bucket_view: True ================================================ FILE: training/configs/experiment/pile/base.yaml ================================================ # @package _global_ defaults: - override /trainer: default # choose trainer from 'configs/trainer/' - override /model: null - override /datamodule: thepile - override /optimizer: adamw-apex # slight speedup (1-2%) over Pytorch AdamW - override /scheduler: cosine-warmup-timm - override /callbacks: [default, norm-monitor] - override /metrics: [perplexity, num-tokens] - override /logger: wandb # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters task: _target_: src.tasks.seq.SequenceLMModel seed: 1111 trainer: accelerator: gpu devices: 8 num_nodes: 1 accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}} max_steps: 800000 val_check_interval: ${eval:2000 * ${.accumulate_grad_batches}} check_val_every_n_epoch: null # We don't care about epoch boundary precision: bf16 gradient_clip_val: 1.0 strategy: null datamodule: batch_size: 16 # Per GPU batch_size_eval: ${.batch_size} # Fused dense only supports batch size at most 64k max_length: 2048 fault_tolerant: True ddp: ${eval:"${trainer.devices} > 1"} train: gpu_mem:
${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"} global_batch_size: 256 optimizer: lr: 6e-4 weight_decay: 0.1 optimizer_param_grouping: bias_weight_decay: False normalization_weight_decay: False scheduler: t_in_epochs: False t_initial: 600000 warmup_lr_init: 1e-6 warmup_t: ${eval:0.01 * ${trainer.max_steps}} lr_min: ${eval:0.1 * ${train.optimizer.lr}} loss_fn: # This is faster and uses less memory than torch.nn.CrossEntropyLoss. # It's also more numerically stable if we're using DeepSpeed 16 bits. _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss inplace_backward: True # to save memory eval: log_on_step: True # 1 training epoch takes too long, we want to see metrics per train step callbacks: model_checkpoint: monitor: val/loss mode: min save_top_k: 3 save_last: True every_n_train_steps: 1000 dirpath: ${work_dir}/checkpoints/${oc.select:name,''} filename: step_{step} auto_insert_metric_name: False model_checkpoint_progress: _target_: src.callbacks.model_checkpoint.ModelCheckpointMine # fault_tolerant: True # The .pl_auto_save.ckpt doesn't get saved by all workers every_n_train_steps: 50000 save_last: False save_top_k: -1 # Save all the checkpoints dirpath: ${..model_checkpoint.dirpath} filename: progress_step_{step} auto_insert_metric_name: False early_stopping: null ================================================ FILE: training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash-8k.yaml model: config: n_embd: 2560 n_head: 32 n_layer: 32 initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} mlp_checkpoint_lvl: 0 datamodule: batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} train: optimizer: lr: 1.6e-4 ================================================ FILE: 
training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash-rotary-8k.yaml model: config: n_embd: 2560 n_head: 20 n_layer: 32 initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} mlp_checkpoint_lvl: 0 datamodule: batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"} train: optimizer: lr: 1.6e-4 ================================================ FILE: training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash-rotary.yaml model: config: n_embd: 2560 n_head: 20 n_layer: 32 initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} mlp_checkpoint_lvl: 0 datamodule: batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"} train: optimizer: lr: 1.6e-4 ================================================ FILE: training/configs/experiment/pile/gpt3-2.7B-flash-hdim128.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash.yaml model: config: n_embd: 2560 n_head: 20 # Headdim 128 is faster than headdim 80 n_layer: 32 initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} mlp_checkpoint_lvl: 0 datamodule: batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} train: optimizer: lr: 1.6e-4 ================================================ FILE: training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash-rotary-8k.yaml model: config: n_embd: 2560 n_head: 32 n_layer: 32 initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} mlp_checkpoint_lvl: 0 datamodule: batch_size: ${eval:"1 if 
${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"} train: optimizer: lr: 1.6e-4 ================================================ FILE: training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash-rotary.yaml model: config: n_embd: 2560 n_head: 32 n_layer: 32 initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} mlp_checkpoint_lvl: 0 datamodule: batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"} train: optimizer: lr: 1.6e-4 ================================================ FILE: training/configs/experiment/pile/gpt3-2.7B-flash.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash.yaml model: config: n_embd: 2560 n_head: 32 n_layer: 32 initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"} mlp_checkpoint_lvl: 0 datamodule: batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} train: optimizer: lr: 1.6e-4 ================================================ FILE: training/configs/experiment/pile/gpt3-2.7B-hf-hdim128.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-hf.yaml model: config: n_embd: 2560 n_head: 128 n_layer: 32 # OOM on A100 80GB even with batch_size = 1 datamodule: batch_size: 1 train: optimizer: lr: 1.6e-4 ================================================ FILE: training/configs/experiment/pile/gpt3-2.7B-hf.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-hf.yaml model: config: n_embd: 2560 n_head: 32 n_layer: 32 datamodule: batch_size: 1 train: optimizer: lr: 1.6e-4 ================================================ FILE: training/configs/experiment/pile/gpt3l-flash-8k.yaml 
================================================ # @package _global_ defaults: - /experiment/pile/gpt3l-flash.yaml datamodule: max_length: 8192 batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} train: global_batch_size: 64 ================================================ FILE: training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3l-flash-rotary.yaml trainer: max_steps: 60000 train: scheduler: t_initial: ${trainer.max_steps} ================================================ FILE: training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3l-flash-8k.yaml model: config: max_position_embeddings: 0 # Disable absolute position embedding rotary_emb_fraction: 0.5 ================================================ FILE: training/configs/experiment/pile/gpt3l-flash-rotary.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3l-flash.yaml model: config: max_position_embeddings: 0 # Disable absolute position embedding rotary_emb_fraction: 0.5 ================================================ FILE: training/configs/experiment/pile/gpt3l-flash.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3s-flash.yaml - override /optimizer: adamw-zero model: config: n_embd: 1536 n_head: 16 n_layer: 24 # mlp_checkpoint_lvl: 1 # To fit batch_size 8 datamodule: batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"} train: optimizer: lr: 2.5e-4 trainer: strategy: _target_: src.utils.ddp_zero1.DDPStrategyZero1 find_unused_parameters: False gradient_as_bucket_view: True ================================================ FILE: training/configs/experiment/pile/gpt3l-hf.yaml 
================================================ # @package _global_ defaults: - /experiment/pile/gpt3s-hf.yaml model: config: n_embd: 1536 n_head: 16 n_layer: 24 datamodule: batch_size: 2 train: optimizer: lr: 2.5e-4 ================================================ FILE: training/configs/experiment/pile/gpt3m-flash-8k.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3m-flash.yaml datamodule: max_length: 8192 batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"} train: global_batch_size: 64 ================================================ FILE: training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3m-flash-rotary.yaml trainer: max_steps: 60000 train: scheduler: t_initial: ${trainer.max_steps} ================================================ FILE: training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3m-flash-8k.yaml model: config: max_position_embeddings: 0 # Disable absolute position embedding rotary_emb_fraction: 0.5 ================================================ FILE: training/configs/experiment/pile/gpt3m-flash-rotary.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3m-flash.yaml model: config: max_position_embeddings: 0 # Disable absolute position embedding rotary_emb_fraction: 0.5 ================================================ FILE: training/configs/experiment/pile/gpt3m-flash.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3s-flash.yaml - override /model/gpt2model: gpt2-medium # Can enable mlp_checkpoint_lvl to fit batch_size 16 to A100 40GB # model: # config: # mlp_checkpoint_lvl: 1 datamodule: batch_size: ${eval:"4 if ${train.gpu_mem} < 
24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"} train: optimizer: lr: 3.0e-4 ================================================ FILE: training/configs/experiment/pile/gpt3m-hf.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3s-hf.yaml - override /model/gpt2model: gpt2-medium datamodule: batch_size: 4 train: optimizer: lr: 3.0e-4 ================================================ FILE: training/configs/experiment/pile/gpt3s-flash-8k.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3s-flash.yaml datamodule: max_length: 8192 batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"} train: global_batch_size: 64 ================================================ FILE: training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3s-flash-rotary.yaml trainer: max_steps: 60000 train: scheduler: t_initial: ${trainer.max_steps} ================================================ FILE: training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3s-flash-8k.yaml model: config: max_position_embeddings: 0 # Disable absolute position embedding rotary_emb_fraction: 0.5 ================================================ FILE: training/configs/experiment/pile/gpt3s-flash-rotary.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3s-flash.yaml model: config: max_position_embeddings: 0 # Disable absolute position embedding rotary_emb_fraction: 0.5 ================================================ FILE: training/configs/experiment/pile/gpt3s-flash.yaml ================================================ # @package _global_ defaults: - /experiment/pile/base.yaml - override 
/model: gpt2 - override /model/gpt2model: gpt2-small model: config: # n_positions is already set to ${datamodule.max_length} residual_in_fp32: True use_flash_attn: True fused_dropout_add_ln: True fused_mlp: True fused_bias_fc: True pad_vocab_size_multiple: 8 datamodule: batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"} ================================================ FILE: training/configs/experiment/pile/gpt3s-hf.yaml ================================================ # @package _global_ defaults: - /experiment/pile/base.yaml - override /model: gpt2-hf - override /model/gpt2model: gpt2-small datamodule: batch_size: 8 train: # Use the standard torch.nn.CrossEntropyLoss loss_fn: null ================================================ FILE: training/configs/experiment/pile/gpt3xl-flash-8k.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash.yaml datamodule: max_length: 8192 batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"} train: global_batch_size: 128 ================================================ FILE: training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash-rotary.yaml trainer: max_steps: 60000 train: scheduler: t_initial: ${trainer.max_steps} ================================================ FILE: training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash-8k.yaml model: config: max_position_embeddings: 0 # Disable absolute position embedding rotary_emb_fraction: 0.5 ================================================ FILE: training/configs/experiment/pile/gpt3xl-flash-rotary.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3xl-flash.yaml model: 
config: max_position_embeddings: 0 # Disable absolute position embedding rotary_emb_fraction: 0.5 ================================================ FILE: training/configs/experiment/pile/gpt3xl-flash.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3s-flash.yaml - override /optimizer: adamw-zero model: config: n_embd: 2048 n_head: 16 n_layer: 24 datamodule: batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"} train: global_batch_size: 512 optimizer: lr: 2.0e-4 scheduler: t_initial: 300000 trainer: strategy: _target_: src.utils.ddp_zero1.DDPStrategyZero1 find_unused_parameters: False gradient_as_bucket_view: True max_steps: 400000 val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}} callbacks: model_checkpoint: every_n_train_steps: 1000 model_checkpoint_progress: every_n_train_steps: 12500 fault_tolerant: False # Saving takes too long ================================================ FILE: training/configs/experiment/pile/gpt3xl-hf.yaml ================================================ # @package _global_ defaults: - /experiment/pile/gpt3s-hf.yaml - override /optimizer: adamw-zero model: config: n_embd: 2048 n_head: 16 n_layer: 24 datamodule: batch_size: 2 train: global_batch_size: 512 optimizer: lr: 2.0e-4 scheduler: t_initial: 300000 trainer: strategy: _target_: src.utils.ddp_zero1.DDPStrategyZero1 find_unused_parameters: False gradient_as_bucket_view: True max_steps: 400000 val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}} callbacks: model_checkpoint: every_n_train_steps: 1000 model_checkpoint_progress: every_n_train_steps: 12500 fault_tolerant: False # Saving takes too long ================================================ FILE: training/configs/logger/comet.yaml ================================================ # https://www.comet.ml comet: _target_: pytorch_lightning.loggers.comet.CometLogger api_key: 
${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable project_name: "template-tests" experiment_name: ${name} ================================================ FILE: training/configs/logger/csv.yaml ================================================ # csv logger built in lightning csv: _target_: pytorch_lightning.loggers.csv_logs.CSVLogger save_dir: "." name: "csv/" version: ${name} prefix: "" ================================================ FILE: training/configs/logger/many_loggers.yaml ================================================ # train with many loggers at once defaults: # - comet.yaml - csv.yaml # - mlflow.yaml # - neptune.yaml # - tensorboard.yaml - wandb.yaml ================================================ FILE: training/configs/logger/mlflow.yaml ================================================ # https://mlflow.org mlflow: _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger experiment_name: ${name} tracking_uri: null tags: null save_dir: ./mlruns prefix: "" artifact_location: null ================================================ FILE: training/configs/logger/neptune.yaml ================================================ # https://neptune.ai neptune: _target_: pytorch_lightning.loggers.neptune.NeptuneLogger api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable project_name: your_name/template-tests close_after_fit: True offline_mode: False experiment_name: ${name} experiment_id: null prefix: "" ================================================ FILE: training/configs/logger/tensorboard.yaml ================================================ # https://www.tensorflow.org/tensorboard/ tensorboard: _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger save_dir: "tensorboard/" name: "default" version: ${name} log_graph: False default_hp_metric: True prefix: "" ================================================ FILE: training/configs/logger/wandb.yaml ================================================ # 
https://wandb.ai wandb: _target_: pytorch_lightning.loggers.wandb.WandbLogger project: attention name: ${name} save_dir: "." mode: online # set offline to store all logs only locally id: ${oc.select:name} # pass correct id to resume experiment! # entity: "" # set to name of your wandb team or just remove it log_model: False prefix: "" job_type: "train" group: "" tags: [] ================================================ FILE: training/configs/metrics/acc.yaml ================================================ # @package eval.metrics acc: _target_: src.metrics.accuracy.AccuracyMine ================================================ FILE: training/configs/metrics/acc_ignore_index.yaml ================================================ # @package eval.metrics acc: _target_: torchmetrics.Accuracy ignore_index: -100 ================================================ FILE: training/configs/metrics/acctop5.yaml ================================================ # @package eval.metrics acctop5: _target_: src.metrics.accuracy.AccuracyMine top_k: 5 ================================================ FILE: training/configs/metrics/mse.yaml ================================================ # @package eval.metrics mse: _target_: torchmetrics.MeanSquaredError ================================================ FILE: training/configs/metrics/num-tokens.yaml ================================================ # @package eval.metrics num-tokens: _target_: src.metrics.num_tokens.NumTokens ================================================ FILE: training/configs/metrics/perplexity.yaml ================================================ # @package eval.metrics ppl: _target_: src.metrics.perplexity.Perplexity ================================================ FILE: training/configs/mode/debug.yaml ================================================ # @package _global_ # run in debug mode with: # `python run.py mode=debug` defaults: - override /trainer: debug.yaml debug_mode: True hydra: # sets level of all command 
line loggers to 'DEBUG' verbose: True # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ # sets level of only chosen command line loggers to 'DEBUG' # verbose: [src.train, src.utils.utils] # sets output paths for all file logs to 'logs/debug/' run: dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S} sweep: dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S} subdir: ${hydra.job.num} # disable rich config printing, since it will be already printed by hydra when `verbose: True` print_config: False ================================================ FILE: training/configs/mode/default.yaml ================================================ # @package _global_ # default running mode default_mode: True hydra: # default output paths for all file logs run: dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} sweep: dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/multiruns/${now:%Y-%m-%d_%H-%M-%S} subdir: ${hydra.job.num} ================================================ FILE: training/configs/mode/exp.yaml ================================================ # @package _global_ # run in experiment mode with: # `python run.py mode=exp name=experiment_name` experiment_mode: True # allows for custom naming of the experiment name: ??? 
hydra: # sets output paths for all file logs to `logs/experiment/name' run: dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name} sweep: dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name} subdir: ${hydra.job.num} ================================================ FILE: training/configs/mode/profile.yaml ================================================ # @package _global_ # Run the Pytorch profiler trainer: profiler: _target_: pytorch_lightning.profilers.PyTorchProfiler dirpath: ${hydra.run.dir} schedule: _target_: torch.profiler.schedule wait: 5 warmup: 5 active: 5 use_cuda: True max_steps: 20 logger: wandb: mode: disabled callbacks: model_checkpoint: null model_checkpoint_progress: null early_stopping: null hydra: # sets output paths for all file logs to 'logs/profile/' run: dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/${now:%Y-%m-%d}/${now:%H-%M-%S} sweep: dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/multirun_${now:%Y-%m-%d_%H-%M-%S} subdir: ${hydra.job.num} ================================================ FILE: training/configs/mode/smoke.yaml ================================================ # @package _global_ # Smoke test: disable logging and model checkpointing logger: wandb: mode: disabled callbacks: model_checkpoint: null model_checkpoint_progress: null hydra: # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ # sets level of only chosen command line loggers to 'DEBUG' # verbose: [src.train, src.utils.utils] # sets output paths for all file logs to 'logs/debug/' run: dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S} sweep: dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S} subdir: ${hydra.job.num} ================================================ FILE: training/configs/model/gpt2-hf.yaml ================================================ defaults: - _self_ - gpt2model: gpt2-small _target_: transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel 
_recursive_: True config: _target_: transformers.GPT2Config # Mistral's config: https://github.com/stanford-crfm/mistral/blob/main/conf/models/gpt2-small.yaml # However, reorder_and_upcast_attn slows things down reorder_and_upcast_attn: false scale_attn_by_inverse_layer_idx: true n_positions: ${datamodule.max_length} ================================================ FILE: training/configs/model/gpt2.yaml ================================================ defaults: - _self_ - gpt2model: gpt2-small _target_: flash_attn.models.gpt.GPTLMHeadModel _recursive_: True config: _target_: transformers.GPT2Config # Mistral's config: # https://github.com/stanford-crfm/mistral/blob/main/conf/models/mistral-small.yaml # However, reorder_and_upcast_attn slows things down reorder_and_upcast_attn: false scale_attn_by_inverse_layer_idx: true n_positions: ${datamodule.max_length} ================================================ FILE: training/configs/model/gpt2model/gpt2-large.yaml ================================================ # @package _global_ model: config: n_embd: 1280 n_head: 20 n_layer: 36 ================================================ FILE: training/configs/model/gpt2model/gpt2-medium.yaml ================================================ # @package _global_ model: config: n_embd: 1024 n_head: 16 n_layer: 24 ================================================ FILE: training/configs/model/gpt2model/gpt2-small.yaml ================================================ # @package _global_ model: config: n_embd: 768 n_head: 12 n_layer: 12 ================================================ FILE: training/configs/model/gpt2model/gpt2-xlarge.yaml ================================================ # @package _global_ model: config: n_embd: 1600 n_head: 25 n_layer: 48 ================================================ FILE: training/configs/optimizer/adam.yaml ================================================ # @package train.optimizer _target_: torch.optim.Adam 
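The `experiment/owt` and `experiment/pile` configs above size each run from the GPU. In `pile/base.yaml`, `train.gpu_mem` is GPU 0's `memory.total` from `nvidia-smi` (MiB, no units) divided by 1000 and rounded; each experiment picks the per-GPU `datamodule.batch_size` with a nested conditional on that value; and `trainer.accumulate_grad_batches` is `div_up(global_batch_size, devices * batch_size * num_nodes)`. The stdlib-only Python sketch below reproduces that arithmetic outside Hydra so the resulting numbers can be checked by hand. It assumes the custom `div_up` resolver is ceiling division (its definition is not in this part of the dump), and the helper names are hypothetical, not functions from the repository.

```python
def gpu_mem_gb(memory_total_mib: float) -> int:
    """Mirror train.gpu_mem: nvidia-smi memory.total (MiB) / 1000, rounded."""
    return round(memory_total_mib / 1000)


def pick_batch_size(gpu_mem: int, tiers=(1, 2, 4, 8)) -> int:
    """Mirror the nested conditional used for datamodule.batch_size.

    The default tiers match gpt3xl-flash.yaml:
    "1 if gpu_mem < 24 else (2 if gpu_mem < 40 else (4 if gpu_mem < 80 else 8))"
    """
    below_24, below_40, below_80, otherwise = tiers
    if gpu_mem < 24:
        return below_24
    if gpu_mem < 40:
        return below_40
    if gpu_mem < 80:
        return below_80
    return otherwise


def div_up(a: int, b: int) -> int:
    """Assumed semantics of the div_up resolver: ceiling integer division."""
    return -(-a // b)


def accumulate_grad_batches(global_batch_size: int, devices: int,
                            batch_size: int, num_nodes: int = 1) -> int:
    """Mirror trainer.accumulate_grad_batches from pile/base.yaml."""
    return div_up(global_batch_size, devices * batch_size * num_nodes)


if __name__ == "__main__":
    # gpt3xl-flash.yaml: 8 devices (from pile/base.yaml), global_batch_size 512.
    for mib in (40960, 81920):
        mem = gpu_mem_gb(mib)
        bs = pick_batch_size(mem)
        accum = accumulate_grad_batches(512, devices=8, batch_size=bs)
        print(mib, mem, bs, accum)  # "40960 41 4 16", then "81920 82 8 8"
```

One consequence of the rounding: a GPU reporting 40960 MiB gets `gpu_mem = 41`, which fails the `< 40` test, so it lands in the third tier (batch size 4 for `gpt3xl-flash.yaml`, 16 for `owt/gpt2l-flash.yaml`). With `gpt3xl-flash.yaml`'s `val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}`, that run would validate every 16000 micro-batches.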
================================================ FILE: training/configs/optimizer/adamw-apex-distributed.yaml ================================================ # @package train.optimizer _target_: apex.contrib.optimizers.distributed_fused_adam.DistributedFusedAdam adam_w_mode: True ================================================ FILE: training/configs/optimizer/adamw-apex-zero.yaml ================================================ # @package train.optimizer _target_: torch.distributed.optim.ZeroRedundancyOptimizer _recursive_: True optimizer_class: _target_: apex.optimizers.FusedAdam _partial_: True adam_w_mode: True ================================================ FILE: training/configs/optimizer/adamw-apex.yaml ================================================ # @package train.optimizer _target_: apex.optimizers.FusedAdam adam_w_mode: True ================================================ FILE: training/configs/optimizer/adamw-zero.yaml ================================================ # @package train.optimizer _target_: torch.distributed.optim.ZeroRedundancyOptimizer _recursive_: True optimizer_class: _target_: torch.optim.__getattribute__ _args_: - "AdamW" ================================================ FILE: training/configs/optimizer/adamw.yaml ================================================ # @package train.optimizer _target_: torch.optim.AdamW ================================================ FILE: training/configs/optimizer/fusedlamb-ds.yaml ================================================ # @package train.optimizer _target_: deepspeed.ops.lamb.FusedLamb ================================================ FILE: training/configs/optimizer/fusedlamb.yaml ================================================ # @package train.optimizer _target_: apex.optimizers.FusedLAMB ================================================ FILE: training/configs/optimizer/sgd.yaml ================================================ # @package train.optimizer _target_: torch.optim.SGD 
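`adamw-zero.yaml` builds the `optimizer_class` argument of `torch.distributed.optim.ZeroRedundancyOptimizer` as `_target_: torch.optim.__getattribute__` with `_args_: ["AdamW"]`. This is the same workaround `gpt2s-hf.yaml` uses to obtain `torch.long`: Hydra calls the target with the positional `_args_`, and a module's bound `__getattribute__` just returns the named attribute, so the nested node resolves to the class itself rather than an instance. `adamw-apex-zero.yaml` instead sets `_partial_: True`, which makes Hydra return a `functools.partial` with `adam_w_mode=True` pre-bound. The sketch below mimics both behaviours with a hypothetical, heavily simplified `toy_instantiate` (flat dicts only, no recursion or interpolation; it is not Hydra's API) and uses stdlib targets in place of the optimizers so it runs without PyTorch or Apex.

```python
import fractions
import functools
import importlib


def locate(path: str):
    """Resolve a dotted path: import the longest module prefix, then getattr the rest."""
    parts = path.split(".")
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ImportError:
            continue
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"cannot locate {path!r}")


def toy_instantiate(cfg: dict):
    """Hypothetical, simplified stand-in for hydra.utils.instantiate (flat configs only)."""
    cfg = dict(cfg)  # don't mutate the caller's dict
    target = locate(cfg.pop("_target_"))
    args = cfg.pop("_args_", [])
    if cfg.pop("_partial_", False):
        # _partial_: True -> return a callable with the remaining keys pre-bound.
        return functools.partial(target, *args, **cfg)
    return target(*args, **cfg)


# Like adamw-zero.yaml's optimizer_class: yields the class, not an instance.
optimizer_class = toy_instantiate(
    {"_target_": "fractions.__getattribute__", "_args_": ["Fraction"]}
)

# Like adamw-apex-zero.yaml's optimizer_class: a partial with a keyword pre-bound.
parse_hex = toy_instantiate({"_target_": "builtins.int", "_partial_": True, "base": 16})

print(optimizer_class is fractions.Fraction, parse_hex("ff"))  # True 255
```

In the real configs the outer node sets `_recursive_: True`, so Hydra instantiates the nested `optimizer_class` node first and hands the resulting class (or partial) to `ZeroRedundancyOptimizer`.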
================================================ FILE: training/configs/scheduler/cosine-warmup-timm.yaml ================================================ # @package train.scheduler _target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler ================================================ FILE: training/configs/scheduler/cosine-warmup.yaml ================================================ # @package train.scheduler _target_: transformers.get_cosine_schedule_with_warmup ================================================ FILE: training/configs/scheduler/invsqrt.yaml ================================================ # @package train.scheduler _target_: src.optim.lr_scheduler.InvSqrt num_warmup_steps: ??? ================================================ FILE: training/configs/scheduler/linear-warmup.yaml ================================================ # @package train.scheduler _target_: transformers.get_linear_schedule_with_warmup ================================================ FILE: training/configs/scheduler/multi-step.yaml ================================================ # @package train.scheduler _target_: torch.optim.lr_scheduler.MultiStepLR ================================================ FILE: training/configs/scheduler/plateau.yaml ================================================ # @package _global_ train: scheduler_interval: epoch scheduler_monitor: ??? 
scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau factor: 0.2 # Decay factor when ReduceLROnPlateau is used patience: 20 min_lr: 0.0 # Minimum learning rate during annealing ================================================ FILE: training/configs/scheduler/poly-warmup.yaml ================================================ # @package train.scheduler _target_: transformers.get_polynomial_decay_schedule_with_warmup ================================================ FILE: training/configs/scheduler/step.yaml ================================================ # @package train.scheduler _target_: torch.optim.lr_scheduler.StepLR step_size: ??? ================================================ FILE: training/configs/task/sequence-model.yaml ================================================ _target_: src.tasks.seq.SequenceModel ================================================ FILE: training/configs/trainer/all_params.yaml ================================================ _target_: pytorch_lightning.Trainer # default values for all trainer parameters checkpoint_callback: True default_root_dir: null gradient_clip_val: 0.0 process_position: 0 num_nodes: 1 num_processes: 1 gpus: null auto_select_gpus: False tpu_cores: null log_gpu_memory: null overfit_batches: 0.0 track_grad_norm: -1 check_val_every_n_epoch: 1 fast_dev_run: False accumulate_grad_batches: 1 max_epochs: 1 min_epochs: 1 max_steps: null min_steps: null limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 val_check_interval: 1.0 flush_logs_every_n_steps: 100 log_every_n_steps: 50 accelerator: null sync_batchnorm: False precision: 32 weights_summary: "top" weights_save_path: null num_sanity_val_steps: 2 truncated_bptt_steps: null resume_from_checkpoint: null profiler: null benchmark: False deterministic: False reload_dataloaders_every_epoch: False auto_lr_find: False replace_sampler_ddp: True terminate_on_nan: False auto_scale_batch_size: False prepare_data_per_node: True plugins: null 
amp_backend: "native" amp_level: "O2" move_metrics_to_cpu: False ================================================ FILE: training/configs/trainer/ddp.yaml ================================================ defaults: - default.yaml accelerator: gpu devices: 4 strategy: ddp ================================================ FILE: training/configs/trainer/debug.yaml ================================================ defaults: - default.yaml gpus: 0 min_epochs: 1 max_epochs: 2 # prints weights_summary: "full" profiler: null # debugs fast_dev_run: true num_sanity_val_steps: 2 overfit_batches: 0 limit_train_batches: 1.0 limit_val_batches: 1.0 limit_test_batches: 1.0 track_grad_norm: -1 terminate_on_nan: true ================================================ FILE: training/configs/trainer/default.yaml ================================================ _target_: pytorch_lightning.Trainer # set `gpu` to train on GPU, null to train on CPU only accelerator: null min_epochs: 1 max_epochs: 1000 ================================================ FILE: training/run.py ================================================ from typing import Callable import dotenv import hydra from omegaconf import OmegaConf, DictConfig # load environment variables from `.env` file if it exists # recursively searches for `.env` in all folders starting from work dir dotenv.load_dotenv(override=True) OmegaConf.register_new_resolver('eval', eval) OmegaConf.register_new_resolver('div_up', lambda x, y: (x + y - 1) // y) # Delay the evaluation until we have the datamodule # So we want the resolver to yield the same string. OmegaConf.register_new_resolver('datamodule', lambda attr: '${datamodule:' + str(attr) + '}') # Turn on TensorFloat32 import torch.backends torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True def dictconfig_filter_key(d: DictConfig, fn: Callable) -> DictConfig: """Only keep keys where fn(key) is True. Support nested DictConfig. 
""" # Using d.items_ex(resolve=False) instead of d.items() since we want to keep the # ${datamodule:foo} unresolved for now. return DictConfig({k: dictconfig_filter_key(v, fn) if isinstance(v, DictConfig) else v # for k, v in d.items_ex(resolve=False) if fn(k)}) for k, v in d.items() if fn(k)}) @hydra.main(config_path="configs/", config_name="config.yaml") def main(config: DictConfig): # Remove config keys that start with '__'. These are meant to be used only in computing # other entries in the config. config = dictconfig_filter_key(config, lambda k: not k.startswith('__')) # Imports should be nested inside @hydra.main to optimize tab completion # Read more here: https://github.com/facebookresearch/hydra/issues/934 from src.train import train from src.eval import evaluate from src.utils import utils # A couple of optional utilities: # - disabling python warnings # - forcing debug-friendly configuration # - verifying experiment name is set when running in experiment mode # You can safely get rid of this line if you don't want those utils.extras(config) # Pretty print config using Rich library if config.get("print_config"): utils.print_config(config, resolve=True) # Train model mode = config.get('mode', 'train') if mode not in ['train', 'eval']: raise NotImplementedError(f'mode {mode} not supported') if mode == 'train': return train(config) elif mode == 'eval': return evaluate(config) if __name__ == "__main__": main() ================================================ FILE: training/src/callbacks/__init__.py ================================================ ================================================ FILE: training/src/callbacks/causality_monitor.py ================================================ import pytorch_lightning as pl from pytorch_lightning import Callback from pytorch_lightning.utilities import rank_zero_only import torch from torch.autograd import grad class CausalityMonitor(Callback): r"""Monitor causality of a model by tracking gradient leakage 
forward in time. In a fully causal model, dy[k]/du[s] ~= 0 for all k < s. Args: seq_len (int): Length of the sequence to monitor. input_dim (int): Dimension of the input to monitor. If 0, the callback assumes a language-modeling task and skips the embedding layer. If > 0, input_dim is interpreted as the input channel dimension, i.e. D with dummy input of dimension [B, L, D]. Notes: This callback assumes that `pl_module.model` has a `net` or `s4seq` attribute, indicating the primary model to monitor. For LMs, `net` or `s4seq` should be after the embedding layer. """ def __init__(self, seq_len: int = 10, input_dim: int = 0): super().__init__() self.seq_len = seq_len self.input_dim = input_dim @rank_zero_only def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: model = pl_module.model with torch.enable_grad(): if self.input_dim == 0: # [MP] LongTensors cannot have gradients - we start from post # embedding in the LM case input_dim = model.d_model x = torch.randn((2, self.seq_len, input_dim), \ requires_grad=True).to(pl_module.device) # [DF] HACK: we need to get the layer that comes after the embedding if hasattr(model, 'net'): y = model.net(x) else: y = model.s4seq(x) else: x = torch.randn(1, self.seq_len, self.input_dim, \ requires_grad=True).to(pl_module.device) y = model(x) stats = {} for i in range(self.seq_len): # total gradients flowing from y_i to x g = grad(y[0,0,i].mean(), x, retain_graph=True, allow_unused=True)[0] g = g[0,i+1:,:].abs().mean() stats[f'stats/causality_{i}'] = g.item() if trainer.loggers is not None: for logger in trainer.loggers: logger.log_metrics(stats, step=trainer.global_step) ================================================ FILE: training/src/callbacks/ema.py ================================================ # Inspired by https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/stochastic_weight_avg.py #
https://github.com/PyTorchLightning/Lightning-Bolts/blob/master/pl_bolts/callbacks/byol_updates.py # https://forums.pytorchlightning.ai/t/adopting-exponential-moving-average-ema-for-pl-pipeline/488/2 # https://github.com/PyTorchLightning/pytorch-lightning/issues/8100 from typing import Dict, Any from pytorch_lightning import Callback, Trainer from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.types import STEP_OUTPUT from src.utils.ema import ExponentialMovingAverage class EMACallback(Callback): """TD [2021-08-31]: saving and loading from checkpoint should work. """ def __init__(self, decay: float, use_num_updates: bool = True): """ decay: The exponential decay. use_num_updates: Whether to use number of updates when computing averages. """ super().__init__() self.decay = decay self.use_num_updates = use_num_updates self.ema = None def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): # It's possible that we already loaded EMA from the checkpoint if self.ema is None: self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad], decay=self.decay, use_num_updates=self.use_num_updates) # Ideally we want on_after_optimizer_step but pytorch-lightning doesn't have it # We only want to update when parameters are changing. # Because of gradient accumulation, this doesn't happen every training step. 
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11688 def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int, ) -> None: if (batch_idx + 1) % trainer.accumulate_grad_batches == 0: self.ema.update() def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: # During the initial validation we don't have self.ema yet if self.ema is not None: self.ema.store() self.ema.copy_to() def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self.ema is not None: self.ema.restore() def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self.ema is not None: self.ema.store() self.ema.copy_to() def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self.ema is not None: self.ema.restore() def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> Dict[str, Any]: return self.ema.state_dict() def on_load_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> None: if self.ema is None: self.ema = ExponentialMovingAverage([p for p in pl_module.parameters() if p.requires_grad], decay=self.decay, use_num_updates=self.use_num_updates) self.ema.load_state_dict(checkpoint) ================================================ FILE: training/src/callbacks/flop_count.py ================================================ # Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py from typing import Any, List, Sequence import torch from pytorch_lightning import Callback, Trainer, LightningModule from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.parsing import AttributeDict from src.utils.flops import has_deepspeed_profiling, has_fvcore_profiling from src.utils.flops import 
profile_deepspeed, profile_fvcore class FlopCount(Callback): """Count the number of FLOPs used by the model. """ def __init__(self, profilers: List[str] = ['fvcore', 'deepspeed'], input_size: tuple = (3, 224, 224), input_dtype=torch.float32, device=None): if isinstance(profilers, str) or not isinstance(profilers, Sequence): profilers = [profilers] if any(p not in ['fvcore', 'deepspeed'] for p in profilers): raise NotImplementedError('Only the fvcore and deepspeed profilers are supported') if 'fvcore' in profilers and not has_fvcore_profiling: raise ImportError('fvcore is not installed. Install it by running `pip install fvcore`') elif 'deepspeed' in profilers and not has_deepspeed_profiling: raise ImportError('deepspeed is not installed') super().__init__() self.profilers = profilers self.input_size = tuple(input_size) self.input_dtype = input_dtype self.device = device @rank_zero_only def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: if 'fvcore' in self.profilers: _, macs, _, acts = profile_fvcore(pl_module.to(self.device), input_size=self.input_size, input_dtype=self.input_dtype, detailed=True) trainer.logger.log_hyperparams({'GMACs': macs * 1e-9, 'MActs': acts * 1e-6}) if 'deepspeed' in self.profilers: macs, _ = profile_deepspeed(pl_module.to(self.device), input_size=self.input_size, input_dtype=self.input_dtype, detailed=True) if 'fvcore' not in self.profilers: # fvcore's MACs seem more accurate trainer.logger.log_hyperparams({'GMACs': macs * 1e-9}) ================================================ FILE: training/src/callbacks/gpu_affinity.py ================================================ import torch from pytorch_lightning import Callback, Trainer, LightningModule import logging log = logging.getLogger(__name__) # We want a logger for each process, not just rank 0 def l2_promote(): import ctypes _libcudart = ctypes.CDLL('libcudart.so') # Set device limit on the current device # cudaLimitMaxL2FetchGranularity = 0x05 pValue = ctypes.cast((ctypes.c_int*1)(),
ctypes.POINTER(ctypes.c_int)) _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128)) _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05)) assert pValue.contents.value == 128 def set_affinity(trainer): try: from src.utils.gpu_affinity import set_affinity nproc_per_node = torch.cuda.device_count() affinity = set_affinity(trainer.local_rank, nproc_per_node, 'socket_unique_continuous') log.info(f'{trainer.local_rank}: thread affinity: {affinity}') # TD [2022-05-07] Somehow calling this causes GPU 0 to allocate extra ~800MB of memory per # number of GPUs (e.g., 6.4GB of extra memory in an 8-GPU setup). H/t Dan. # l2_promote() except Exception: pass class GpuAffinity(Callback): """Set GPU affinity and increase the L2 fetch granularity. Adapted from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/Transformer-XL """ def setup(self, trainer: Trainer, pl_module: LightningModule, stage=None) -> None: set_affinity(trainer) ================================================ FILE: training/src/callbacks/loss_scale_monitor.py ================================================ # Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/lr_monitor.py. from typing import Any from pytorch_lightning import Callback, Trainer from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.strategies import DeepSpeedStrategy class LossScaleMonitor(Callback): """Monitor the loss scale for AMP (fp16). """ # Use on_before_optimizer_step instead of on_train_batch_start since there might be # gradient accumulation and we only care about the loss scale when it could change (i.e., # optimizer.step).
@rank_zero_only def on_before_optimizer_step(self, trainer: Trainer, *args: Any, **kwargs: Any) -> None: if not trainer._logger_connector.should_update_logs: return stats = {} if isinstance(trainer.strategy, DeepSpeedStrategy): stats = {'scalar/scale': trainer.model.optimizer.loss_scale} if hasattr(trainer, 'precision_plugin') and hasattr(trainer.precision_plugin, 'scaler'): scaler = trainer.precision_plugin.scaler if scaler is not None: stats = { 'scaler/scale': scaler.get_scale(), 'scaler/growth_tracker': scaler._get_growth_tracker(), } if stats and trainer.loggers is not None: for logger in trainer.loggers: logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) ================================================ FILE: training/src/callbacks/model_checkpoint.py ================================================ # Adapted from https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/callbacks/fault_tolerance.py from typing import Any from pathlib import Path import pytorch_lightning as pl class ModelCheckpointMine(pl.callbacks.model_checkpoint.ModelCheckpoint): def __init__(self, *args, fault_tolerant=False, **kwargs): super().__init__(*args, **kwargs) self.fault_tolerant = fault_tolerant def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: if self.fault_tolerant: # overwrite if necessary trainer.save_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt')) # def teardown(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: # if self.fault_tolerant: # trainer.strategy.remove_checkpoint(str(Path(self.dirpath) / '.pl_auto_save.ckpt')) # TD [2022-07-17] I was trying to make resuming from standard checkpoint fault-tolerant. # However, when it resumes it's off by 1 iteration. My attempt to fix it in seq.py (below) didn't work. # So I decided to just copy _FaultToleranceCheckpoint and just save on_exception. 
# def on_save_checkpoint(self, checkpoint): # # TD [2022-07-12] The "completed" counter is off by 1 so when it resumes # # it's off by 1 iteration. However, the data is still off by 1 iteration, probably # # because the dataloader_state_dict['counter'] is off by @batch_size, and idk how # # to fix it cleanly. # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] += 1 # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] += 1 # checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] += 1 # checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['dataloader_state_dict'][0]['state'][0]['num_batches_fetched'] += 1 ================================================ FILE: training/src/callbacks/norm_monitor.py ================================================ # Inspired by https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/utilities/grads.py # However, they compute grad at every iteration (I think), and the .item() calls incur a lot of overhead # (6-7% slow down on GPT-2 small). Instead we only compute for iterations where we need to log, and don't # call .item() explicitly. from typing import Any from collections import OrderedDict from pytorch_lightning import Callback, Trainer from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.strategies import DeepSpeedStrategy import torch import torch.nn as nn try: from apex.contrib.layer_norm import FastLayerNorm except ImportError: FastLayerNorm = None class NormMonitor(Callback): """Monitor the scales of weights and gradients. """ def __init__(self, layer_norm_only: bool = False): super().__init__() self.layer_norm_only = layer_norm_only # Use on_before_optimizer_step instead of on_train_batch_start since there might be # gradient accumulation and we only care about scale when it could change (i.e., optimizer.step). 
@rank_zero_only def on_before_optimizer_step(self, trainer: Trainer, pl_module, *args: Any, **kwargs: Any) -> None: if not trainer._logger_connector.should_update_logs: return model = pl_module.model named_parameters = {} if self.layer_norm_only: ln_modules = (nn.LayerNorm, nn.Embedding) if FastLayerNorm is not None: ln_modules += (FastLayerNorm,) for mn, m in model.named_modules(): if isinstance(m, ln_modules): for pn, p in m.named_parameters(): fpn = '%s.%s' % (mn, pn) if mn else pn # full param name named_parameters[fpn] = p else: named_parameters = dict(model.named_parameters()) if isinstance(trainer.strategy, DeepSpeedStrategy): loss_scale = trainer.model.optimizer.loss_scale else: loss_scale = 1.0 stats = {} param_l1_norm, grad_l1_norm = [], [] for param_name, param in named_parameters.items(): param_abs = param.abs() param_abs_mean = param_abs.mean(dtype=torch.float32) stats[f'stats/{param_name}_max'] = param_abs.max() stats[f'stats/{param_name}_mean'] = param_abs_mean param_l1_norm.append(param_abs_mean * param.numel()) if param.grad is not None: # If using AMP, gradient is already unscaled by the AMP loss scaler at this point # https://github.com/Lightning-AI/lightning/pull/9606 # However, if using DeepSpeed, we need to scale it ourselves param_grad_abs = param.grad.abs() param_grad_abs_mean = param_grad_abs.mean(dtype=torch.float32) / loss_scale stats[f'stats/{param_name}_grad_max'] = param_grad_abs.max() / loss_scale stats[f'stats/{param_name}_grad_mean'] = param_grad_abs_mean grad_l1_norm.append(param_grad_abs_mean * param.grad.numel()) stats['total_param_l1_norm'] = torch.stack(param_l1_norm).sum() if grad_l1_norm: stats['total_grad_l1_norm'] = torch.stack(grad_l1_norm).sum() # Sort by params name stats = OrderedDict(sorted(stats.items())) if trainer.loggers is not None: for logger in trainer.loggers: logger.log_metrics(stats, step=trainer.fit_loop.epoch_loop._batches_that_stepped) ================================================ FILE: 
training/src/callbacks/params_log.py ================================================ from typing import Any from pytorch_lightning import Callback, Trainer, LightningModule from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.parsing import AttributeDict class ParamsLog(Callback): """Log the number of parameters of the model """ def __init__(self, total_params_log: bool = True, trainable_params_log: bool = True, non_trainable_params_log: bool = True): super().__init__() self._log_stats = AttributeDict( { 'total_params_log': total_params_log, 'trainable_params_log': trainable_params_log, 'non_trainable_params_log': non_trainable_params_log, } ) @rank_zero_only def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None: logs = {} if self._log_stats.total_params_log: logs["model/params_total"] = sum(p.numel() for p in pl_module.parameters()) if self._log_stats.trainable_params_log: logs["model/params_trainable"] = sum(p.numel() for p in pl_module.parameters() if p.requires_grad) if self._log_stats.non_trainable_params_log: logs["model/params_not_trainable"] = sum(p.numel() for p in pl_module.parameters() if not p.requires_grad) if trainer.logger is not None: trainer.logger.log_hyperparams(logs) ================================================ FILE: training/src/callbacks/speed_monitor.py ================================================ # Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor # We only need the speed monitoring, not the GPU monitoring import time from typing import Any from pytorch_lightning import Callback, Trainer from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.types import STEP_OUTPUT class SpeedMonitor(Callback): """Monitor the speed of each step and each epoch. 
""" def __init__(self, intra_step_time: bool = True, inter_step_time: bool = True, epoch_time: bool = True, verbose=False): super().__init__() self._log_stats = AttributeDict( { 'intra_step_time': intra_step_time, 'inter_step_time': inter_step_time, 'epoch_time': epoch_time, } ) self.verbose = verbose def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._snap_epoch_time = None def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._snap_intra_step_time = None self._snap_inter_step_time = None self._snap_epoch_time = time.time() def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._snap_inter_step_time = None def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._snap_inter_step_time = None @rank_zero_only def on_train_batch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, ) -> None: if self._log_stats.intra_step_time: self._snap_intra_step_time = time.time() if not trainer._logger_connector.should_update_logs: return logs = {} if self._log_stats.inter_step_time and self._snap_inter_step_time: # First log at beginning of second step logs["time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000 if trainer.logger is not None: trainer.logger.log_metrics(logs, step=trainer.global_step) @rank_zero_only def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int, ) -> None: if self._log_stats.inter_step_time: self._snap_inter_step_time = time.time() if self.verbose and self._log_stats.intra_step_time and self._snap_intra_step_time: pl_module.print(f"time/intra_step (ms): {(time.time() - self._snap_intra_step_time) * 1000}") if not trainer._logger_connector.should_update_logs: return logs = {} if self._log_stats.intra_step_time and 
self._snap_intra_step_time: logs["time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000 if trainer.logger is not None: trainer.logger.log_metrics(logs, step=trainer.global_step) @rank_zero_only def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",) -> None: logs = {} if self._log_stats.epoch_time and self._snap_epoch_time: logs["time/epoch (s)"] = time.time() - self._snap_epoch_time if trainer.logger is not None: trainer.logger.log_metrics(logs, step=trainer.global_step) ================================================ FILE: training/src/callbacks/wandb_callbacks.py ================================================ import subprocess from pathlib import Path from typing import List import matplotlib.pyplot as plt import seaborn as sn import torch import wandb from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers import LoggerCollection, WandbLogger from pytorch_lightning.utilities import rank_zero_only from sklearn import metrics from sklearn.metrics import f1_score, precision_score, recall_score def get_wandb_logger(trainer: Trainer) -> WandbLogger: """Safely get Weights&Biases logger from Trainer.""" if trainer.fast_dev_run: raise Exception( "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode." ) if isinstance(trainer.logger, WandbLogger): return trainer.logger if isinstance(trainer.logger, LoggerCollection): for logger in trainer.logger: if isinstance(logger, WandbLogger): return logger raise Exception( "You are using wandb related callback, but WandbLogger was not found for some reason..." 
) class WatchModel(Callback): """Make wandb watch model at the beginning of the run.""" def __init__(self, log: str = "gradients", log_freq: int = 100): self.log = log self.log_freq = log_freq @rank_zero_only def on_train_start(self, trainer, pl_module): logger = get_wandb_logger(trainer=trainer) logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) class UploadCodeAsArtifact(Callback): """Upload all code files to wandb as an artifact, at the beginning of the run.""" def __init__(self, code_dir: str, use_git: bool = True): """ Args: code_dir: the code directory use_git: if using git, then upload all files that are not ignored by git. if not using git, then upload all '*.py' files """ self.code_dir = code_dir self.use_git = use_git @rank_zero_only def on_train_start(self, trainer, pl_module): logger = get_wandb_logger(trainer=trainer) experiment = logger.experiment code = wandb.Artifact("project-source", type="code") if self.use_git: # get .git folder # https://alexwlchan.net/2020/11/a-python-function-to-ignore-a-path-with-git-info-exclude/ git_dir_path = Path( subprocess.check_output(["git", "rev-parse", "--git-dir"]).strip().decode("utf8") ).resolve() for path in Path(self.code_dir).resolve().rglob("*"): if ( path.is_file() # ignore files in .git and not str(path).startswith(str(git_dir_path)) # noqa: W503 # ignore files ignored by git and ( # noqa: W503 subprocess.run(["git", "check-ignore", "-q", str(path)]).returncode == 1 ) ): code.add_file(str(path), name=str(path.relative_to(Path(self.code_dir).resolve()))) else: for path in Path(self.code_dir).resolve().rglob("*.py"): code.add_file(str(path), name=str(path.relative_to(Path(self.code_dir).resolve()))) experiment.log_artifact(code) class UploadCheckpointsAsArtifact(Callback): """Upload checkpoints to wandb as an artifact, at the end of the run.""" def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): self.ckpt_dir = ckpt_dir self.upload_best_only = upload_best_only @rank_zero_only def
on_keyboard_interrupt(self, trainer, pl_module): self.on_train_end(trainer, pl_module) @rank_zero_only def on_train_end(self, trainer, pl_module): logger = get_wandb_logger(trainer=trainer) experiment = logger.experiment ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") if self.upload_best_only: ckpts.add_file(trainer.checkpoint_callback.best_model_path) else: for path in Path(self.ckpt_dir).rglob("*.ckpt"): ckpts.add_file(str(path)) experiment.log_artifact(ckpts) class LogConfusionMatrix(Callback): """Generate confusion matrix every epoch and send it to wandb. Expects validation step to return predictions and targets. """ def __init__(self): self.preds = [] self.targets = [] self.ready = True def on_sanity_check_start(self, trainer, pl_module) -> None: self.ready = False def on_sanity_check_end(self, trainer, pl_module): """Start executing this callback only after all validation sanity checks end.""" self.ready = True def on_validation_batch_end( self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx ): """Gather data from a single batch.""" if self.ready: self.preds.append(outputs["preds"]) self.targets.append(outputs["targets"]) def on_validation_epoch_end(self, trainer, pl_module): """Generate confusion matrix.""" if self.ready: logger = get_wandb_logger(trainer) experiment = logger.experiment preds = torch.cat(self.preds).cpu().numpy() targets = torch.cat(self.targets).cpu().numpy() confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds) # set figure size plt.figure(figsize=(14, 8)) # set labels size sn.set(font_scale=1.4) # set font size sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g") # names should be unique or else charts from different experiments in wandb will overlap experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False) # according to wandb docs this should also work but it crashes # experiment.log({f"confusion_matrix/{experiment.name}": plt}) # reset plot
plt.clf() self.preds.clear() self.targets.clear() class LogF1PrecRecHeatmap(Callback): """Generate f1, precision, recall heatmap every epoch and send it to wandb. Expects validation step to return predictions and targets. """ def __init__(self, class_names: List[str] = None): self.preds = [] self.targets = [] self.ready = True def on_sanity_check_start(self, trainer, pl_module): self.ready = False def on_sanity_check_end(self, trainer, pl_module): """Start executing this callback only after all validation sanity checks end.""" self.ready = True def on_validation_batch_end( self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx ): """Gather data from a single batch.""" if self.ready: self.preds.append(outputs["preds"]) self.targets.append(outputs["targets"]) def on_validation_epoch_end(self, trainer, pl_module): """Generate f1, precision and recall heatmap.""" if self.ready: logger = get_wandb_logger(trainer=trainer) experiment = logger.experiment preds = torch.cat(self.preds).cpu().numpy() targets = torch.cat(self.targets).cpu().numpy() f1 = f1_score(targets, preds, average=None) r = recall_score(targets, preds, average=None) p = precision_score(targets, preds, average=None) data = [f1, p, r] # set figure size plt.figure(figsize=(14, 3)) # set labels size sn.set(font_scale=1.2) # set font size sn.heatmap( data, annot=True, annot_kws={"size": 10}, fmt=".3f", yticklabels=["F1", "Precision", "Recall"], ) # names should be unique or else charts from different experiments in wandb will overlap experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False) # reset plot plt.clf() self.preds.clear() self.targets.clear() class LogImagePredictions(Callback): """Logs a validation batch and its predictions to wandb.
Example adapted from: https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY """ def __init__(self, num_samples: int = 8): super().__init__() self.num_samples = num_samples self.ready = True def on_sanity_check_start(self, trainer, pl_module): self.ready = False def on_sanity_check_end(self, trainer, pl_module): """Start executing this callback only after all validation sanity checks end.""" self.ready = True def on_validation_epoch_end(self, trainer, pl_module): if self.ready: logger = get_wandb_logger(trainer=trainer) experiment = logger.experiment # get a validation batch from the validation data loader val_samples = next(iter(trainer.datamodule.val_dataloader())) val_imgs, val_labels = val_samples # run the batch through the network val_imgs = val_imgs.to(device=pl_module.device) logits = pl_module(val_imgs) preds = torch.argmax(logits, dim=-1) # log the images as wandb Images experiment.log( { f"Images/{experiment.name}": [ wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") for x, pred, y in zip( val_imgs[: self.num_samples], preds[: self.num_samples], val_labels[: self.num_samples], ) ] } ) ================================================ FILE: training/src/datamodules/datasets/detokenizer.py ================================================ # Copied from https://github.com/stanford-crfm/mistral/blob/main/src/corpora/detokenization.py # Which was originally from https://github.com/NVIDIA/Megatron-LM/blob/aed2f75e209e525c842aec7c044af7acae2a4614/tasks/zeroshot_gpt/detokenizer.py """ Handle detokenization for different datasets for zero-shot LM evaluation. """ import re def wikitext_detokenize(string: str) -> str: """ Wikitext is whitespace-tokenized; this function removes the extra whitespace.
Taken from https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt2/detokenizer.py """ # Contractions string = string.replace("s '", "s'") string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) # Number Separators string = string.replace(" @-@ ", "-") string = string.replace(" @,@ ", ",") string = string.replace(" @.@ ", ".") # Punctuation string = string.replace(" : ", ": ") string = string.replace(" ; ", "; ") string = string.replace(" . ", ". ") string = string.replace(" ! ", "! ") string = string.replace(" ? ", "? ") string = string.replace(" , ", ", ") # Double Brackets string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) # Miscellaneous string = string.replace("= = = =", "====") string = string.replace("= = =", "===") string = string.replace("= =", "==") string = string.replace(" " + chr(176) + " ", chr(176)) string = string.replace(" \n", "\n") string = string.replace("\n ", "\n") string = string.replace(" N ", " 1 ") string = string.replace(" 's", "'s") return string # Set Registry for Various Datasets DATASET_TOKENIZATION_REGISTRY = {"wikitext": wikitext_detokenize} ================================================ FILE: training/src/datamodules/datasets/lm_dataset.py ================================================ # Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py # Except we don't pad the last block and don't use overlapping eval # And we return both the input and the target import math import numpy as np import torch class LMDataset(torch.utils.data.Dataset): def __init__(self, tokens, seq_len, drop_last=True): """tokens should be a numpy array """ self.seq_len = seq_len ntokens = len(tokens) if drop_last: ntokens = ((ntokens - 1) // seq_len) * seq_len + 1 self.ntokens = ntokens # 
We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset, # and slicing would load it into memory. self.tokens = tokens self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len) def __len__(self): return self.total_sequences def __getitem__(self, idx): start_idx = idx * self.seq_len seq_len = min(self.seq_len, self.ntokens - 1 - start_idx) data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64)) return data[:-1], data[1:].clone() ================================================ FILE: training/src/datamodules/fault_tolerant_sampler.py ================================================ # Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397 from typing import Iterator import math import torch from torch.utils.data import RandomSampler, DistributedSampler class RandomFaultTolerantSampler(RandomSampler): def __init__(self, *args, generator=None, **kwargs): # generator = torch.Generator().manual_seed(seed) # super().__init__(*args, generator=generator, **kwargs) # TD [2022-07-17]: We don't force the seed to be zero. We generate a random seed, # which should be reproducible if pl.seed_everything was called beforehand. # This means that changing the seed of the experiment will also change the # sampling order.
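The windowing in `LMDataset` is easiest to see on a toy token stream. The sketch below restates its `drop_last` truncation and shifted input/target slicing in plain Python (lists instead of tensors, and a hypothetical helper name `lm_windows`) so it runs without torch; it illustrates the arithmetic and is not a replacement for the class.

```python
import math
from typing import List, Tuple


def lm_windows(tokens: List[int], seq_len: int,
               drop_last: bool = True) -> List[Tuple[List[int], List[int]]]:
    """Split a token stream into (input, target) pairs like LMDataset (sketch)."""
    ntokens = len(tokens)
    if drop_last:
        # Keep only full windows: each needs seq_len inputs plus one extra target token.
        ntokens = ((ntokens - 1) // seq_len) * seq_len + 1
    total_sequences = math.ceil((ntokens - 1) / seq_len)
    pairs = []
    for idx in range(total_sequences):
        start = idx * seq_len
        # The last window can be shorter when drop_last=False.
        length = min(seq_len, ntokens - 1 - start)
        data = tokens[start:start + length + 1]
        # Targets are the inputs shifted left by one position.
        pairs.append((data[:-1], data[1:]))
    return pairs


print(lm_windows(list(range(11)), seq_len=4))
# [([0, 1, 2, 3], [1, 2, 3, 4]), ([4, 5, 6, 7], [5, 6, 7, 8])]
print(lm_windows(list(range(11)), seq_len=4, drop_last=False)[-1])
# ([8, 9], [9, 10])
```

Consecutive windows share one boundary token: token 4 is the last target of the first window and the first input of the second, so no token is skipped as a prediction target.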
if generator is None: seed = int(torch.empty((), dtype=torch.int64).random_().item()) generator = torch.Generator().manual_seed(seed) super().__init__(*args, generator=generator, **kwargs) self.counter = 0 # self.start_counter = 0 self.restarting = False def state_dict(self): return {"random_state": self.state, "counter": self.counter} def load_state_dict(self, state_dict): self.generator.set_state(state_dict.get("random_state")) self.counter = state_dict["counter"] # self.start_counter = self.counter self.restarting = True # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per # epoch, and subsequent epochs will have very few batches. # def __len__(self): # # We need a separate self.start_counter because PL seems to call len repeatedly. # # If we use len(self.data_source) - self.counter then PL will think the epoch ends # # when we're only halfway through. # return len(self.data_source) - self.start_counter def __iter__(self) -> Iterator[int]: n = len(self.data_source) self.state = self.generator.get_state() indices = torch.randperm(n, generator=self.generator).tolist() if not self.restarting: self.counter = 0 else: indices = indices[self.counter:] self.restarting = False # self.start_counter = self.counter for index in indices: self.counter += 1 yield index self.counter = 0 # self.start_counter = self.counter class FaultTolerantDistributedSampler(DistributedSampler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.counter = 0 # self.start_counter = 0 self.restarting = False def state_dict(self): return {"epoch": self.epoch, "counter": self.counter} def load_state_dict(self, state_dict): self.epoch = state_dict["epoch"] self.counter = state_dict["counter"] # self.start_counter = self.counter self.restarting = True # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per # epoch, and subsequent epochs will have very few batches.
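The resume logic of `RandomFaultTolerantSampler` rests on two facts: restoring the generator state reproduces the same permutation, and `counter` records how many indices were already yielded. The sketch below is a pure-Python analogue (hypothetical class `ResumableShuffle`, with `random.Random` standing in for `torch.Generator` and `shuffle` for `torch.randperm`) that checks a resumed epoch yields exactly the remaining indices.

```python
import random
from typing import Iterator, List


class ResumableShuffle:
    """Pure-Python analogue of RandomFaultTolerantSampler (hypothetical, for illustration)."""

    def __init__(self, n: int, seed: int = 0):
        self.n = n
        self.rng = random.Random(seed)  # stands in for torch.Generator
        self.counter = 0
        self.restarting = False
        # Captured here too so state_dict() also works before the first epoch starts.
        self.state = self.rng.getstate()

    def state_dict(self) -> dict:
        # RNG state from *before* this epoch's shuffle, plus how many indices were consumed.
        return {"random_state": self.state, "counter": self.counter}

    def load_state_dict(self, sd: dict) -> None:
        self.rng.setstate(sd["random_state"])
        self.counter = sd["counter"]
        self.restarting = True

    def __iter__(self) -> Iterator[int]:
        self.state = self.rng.getstate()
        indices: List[int] = list(range(self.n))
        self.rng.shuffle(indices)  # same RNG state -> same permutation
        if not self.restarting:
            self.counter = 0
        else:
            indices = indices[self.counter:]  # skip what was already yielded
            self.restarting = False
        for index in indices:
            self.counter += 1
            yield index
        self.counter = 0


sampler = ResumableShuffle(10, seed=1)
it = iter(sampler)
head = [next(it) for _ in range(4)]
checkpoint = sampler.state_dict()  # taken mid-epoch

resumed = ResumableShuffle(10, seed=999)  # seed is irrelevant: state comes from the checkpoint
resumed.load_state_dict(checkpoint)
print(head + list(resumed) == list(ResumableShuffle(10, seed=1)))  # True
```

Capturing the RNG state in `__init__` is a choice made for this illustration; the sampler in this file only sets `self.state` inside `__iter__`.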
# def __len__(self) -> int: # return self.num_samples - self.start_counter def __iter__(self): if self.shuffle: # deterministically shuffle based on epoch and seed g = torch.Generator() g.manual_seed(self.seed + self.epoch) indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] else: indices = list(range(len(self.dataset))) # type: ignore[arg-type] if not self.drop_last: # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) if padding_size <= len(indices): indices += indices[:padding_size] else: indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. indices = indices[:self.total_size] assert len(indices) == self.total_size # subsample indices = indices[self.rank:self.total_size:self.num_replicas] assert len(indices) == self.num_samples if not self.restarting: self.counter = 0 else: indices = indices[self.counter:] self.restarting = False # self.start_counter = self.counter for index in indices: self.counter += 1 yield index self.counter = 0 # self.start_counter = self.counter ================================================ FILE: training/src/datamodules/imagenet.py ================================================ # Adapted from https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/datamodules/imagenet_datamodule.py import os from pathlib import Path from typing import Any, List, Union, Callable, Optional import torch from torch.utils.data import Dataset, DataLoader, SequentialSampler from torch.utils.data.dataloader import default_collate from torch.utils.data.distributed import DistributedSampler from pytorch_lightning import LightningDataModule from torchvision import transforms from torchvision.datasets import ImageFolder class DictDataset(Dataset): def __init__(self, dataset_dict, length=None): """dataset_dict: dictionary mapping from index to batch length is used in the case of 
DistributedSampler: e.g. the dataset could have size 1k, but with 8 GPUs the dataset_dict would only have 125 items. """ super().__init__() self.dataset_dict = dataset_dict self.length = length or len(self.dataset_dict) def __getitem__(self, index): return self.dataset_dict[index] def __len__(self): return self.length # From https://github.com/PyTorchLightning/lightning-bolts/blob/2415b49a2b405693cd499e09162c89f807abbdc4/pl_bolts/transforms/dataset_normalizations.py#L10 def imagenet_normalization(): return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) class ImagenetDataModule(LightningDataModule): """ .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png :width: 400 :alt: Imagenet Specs: - 1000 classes - Each image is (3 x varies x varies) (here we default to 3 x 224 x 224) Imagenet train, val and test dataloaders. The train set is the imagenet train split (the `train` subfolder). Both the val set and the test set are read from the official imagenet validation split (the `val` subfolder).
Example:: from pl_bolts.datamodules import ImagenetDataModule dm = ImagenetDataModule(IMAGENET_PATH) model = LitModel() Trainer().fit(model, datamodule=dm) """ name = "imagenet" def __init__( self, data_dir: str, image_size: int = 224, train_transforms=None, val_transforms=None, test_transforms=None, img_dtype='float32', # Using str since OmegaConf doesn't support non-primitive type cache_val_dataset=False, mixup: Optional[Callable] = None, num_aug_repeats: int = 0, num_workers: int = 0, batch_size: int = 32, batch_size_eval: Optional[int] = None, shuffle: bool = True, pin_memory: bool = True, drop_last: bool = False, *args: Any, **kwargs: Any, ) -> None: """ Args: data_dir: path to the imagenet dataset directory, which must contain `train` and `val` subfolders image_size: final image size num_workers: how many data workers batch_size: batch size for training batch_size_eval: batch size for validation and testing (defaults to batch_size) shuffle: If true, shuffles the data every epoch pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before returning them drop_last: If true, drops the last incomplete batch """ super().__init__(*args, **kwargs) self.image_size = image_size self.train_transforms = train_transforms self.val_transforms = val_transforms self.test_transforms = test_transforms assert img_dtype in ['float32', 'float16', 'bfloat16'] self.img_dtype = torch.__getattribute__(img_dtype) self.cache_val_dataset = cache_val_dataset self.mixup = mixup self.num_aug_repeats = num_aug_repeats self.dims = (3, self.image_size, self.image_size) self.data_dir = Path(data_dir).expanduser() self.num_workers = num_workers self.batch_size = batch_size self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size self.shuffle = shuffle self.pin_memory = pin_memory self.drop_last = drop_last @property def num_classes(self) -> int: """ Return: 1000 """ return 1000 def _verify_splits(self, data_dir: str, split: 
str) -> None: dirs = os.listdir(data_dir) if split not in dirs: raise
FileNotFoundError( f"A {split} Imagenet split was not found in {data_dir};" f" make sure the folder contains a subfolder named {split}" ) def prepare_data(self) -> None: """This method assumes you have already downloaded imagenet2012. It only checks that the `train` and `val` subfolders exist. .. warning:: Please download imagenet on your own first. """ self._verify_splits(self.data_dir, "train") self._verify_splits(self.data_dir, "val") def setup(self, stage: Optional[str] = None) -> None: """Creates train, val, and test datasets.""" if stage == "fit" or stage is None: train_transforms = (self.train_transform() if self.train_transforms is None else self.train_transforms) val_transforms = (self.val_transform() if self.val_transforms is None else self.val_transforms) if self.img_dtype is not torch.float32: assert isinstance(train_transforms, transforms.Compose) assert isinstance(val_transforms, transforms.Compose) convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype)) train_transforms.transforms.append(convert_dtype) val_transforms.transforms.append(convert_dtype) self.dataset_train = ImageFolder(self.data_dir / 'train', transform=train_transforms) self.dataset_val = ImageFolder(self.data_dir / 'val', transform=val_transforms) if stage == "test" or stage is None: test_transforms = (self.val_transform() if self.test_transforms is None else self.test_transforms) if self.img_dtype is not torch.float32: assert isinstance(test_transforms, transforms.Compose) convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype)) test_transforms.transforms.append(convert_dtype) self.dataset_test = ImageFolder(self.data_dir / 'val', transform=test_transforms) def train_transform(self) -> Callable: """The standard imagenet transforms. ..
code-block:: python transforms.Compose([ transforms.RandomResizedCrop(self.image_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ]) """ preprocessing = transforms.Compose( [ transforms.RandomResizedCrop(self.image_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), imagenet_normalization(), ] ) return preprocessing def val_transform(self) -> Callable: """The standard imagenet transforms for validation. .. code-block:: python transforms.Compose([ transforms.Resize(self.image_size + 32), transforms.CenterCrop(self.image_size), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ]) """ preprocessing = transforms.Compose( [ transforms.Resize(self.image_size + 32), transforms.CenterCrop(self.image_size), transforms.ToTensor(), imagenet_normalization(), ] ) return preprocessing def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: """ The train dataloader """ if self.num_aug_repeats == 0: shuffle = self.shuffle sampler = None else: shuffle = False from timm.data.distributed_sampler import RepeatAugSampler sampler = RepeatAugSampler(self.dataset_train, num_repeats=self.num_aug_repeats) return self._data_loader(self.dataset_train, batch_size=self.batch_size, shuffle=shuffle, mixup=self.mixup, sampler=sampler) def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: """ The val dataloader """ # If using RepeatAugment, we set trainer.replace_sampler_ddp=False, so we have to # construct the DistributedSampler ourselves. 
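When the module builds a `DistributedSampler` itself, each rank sees a padded or trimmed, rank-strided slice of the index list. The sketch below restates that logic (the same padding and `indices[rank:total_size:num_replicas]` steps that `FaultTolerantDistributedSampler.__iter__` performs) in plain Python with shuffling omitted. `shard_indices` is a hypothetical helper, and the `num_samples` rounding is our reading of `torch.utils.data.DistributedSampler`, worth checking against the installed version.

```python
import math
from typing import Iterable, List


def shard_indices(indices: Iterable[int], num_replicas: int, rank: int,
                  drop_last: bool = False) -> List[int]:
    """Return the indices one rank would see (hypothetical helper; no shuffling)."""
    indices = list(indices)
    # Per-rank sample count: round down when dropping the tail, round up when padding.
    if drop_last:
        num_samples = len(indices) // num_replicas
    else:
        num_samples = math.ceil(len(indices) / num_replicas)
    total_size = num_samples * num_replicas
    if not drop_last:
        # Pad by repeating indices from the front until the length divides evenly.
        padding_size = total_size - len(indices)
        if padding_size <= len(indices):
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
    else:
        # Trim the tail so every rank gets the same number of samples.
        indices = indices[:total_size]
    assert len(indices) == total_size
    # Rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(range(10), num_replicas=4, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

With `drop_last=False`, ranks 2 and 3 receive indices 0 and 1 a second time, so a few samples are evaluated twice per epoch; with `drop_last=True`, indices 8 and 9 are skipped instead.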
if not self.cache_val_dataset: sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last) if self.num_aug_repeats != 0 else None) return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval, sampler=sampler) else: print('Caching val dataset') sampler = (SequentialSampler(self.dataset_val) if self.trainer.world_size <= 1 else DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last)) indices = list(iter(sampler)) loader = DataLoader(self.dataset_val, batch_size=None, shuffle=False, sampler=sampler, num_workers=self.num_workers, drop_last=self.drop_last) batches = list(loader) assert len(batches) == len(indices) self.dataset_val = DictDataset(dict(zip(indices, batches)), length=len(self.dataset_val)) sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last) if self.num_aug_repeats != 0 else None) return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval, sampler=sampler) def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: """ The test dataloader """ sampler = (DistributedSampler(self.dataset_test, shuffle=False, drop_last=self.drop_last) if self.num_aug_repeats != 0 else None) return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, sampler=sampler) def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False, mixup: Optional[Callable] = None, sampler=None) -> DataLoader: collate_fn = ((lambda batch: mixup(*default_collate(batch))) if mixup is not None else default_collate) return DataLoader( dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=shuffle, sampler=sampler, num_workers=self.num_workers, drop_last=self.drop_last, pin_memory=self.pin_memory, persistent_workers=True ) class Imagenet21kPDataModule(ImagenetDataModule): """ImageNet-21k (winter 21) processed with https://github.com/Alibaba-MIIL/ImageNet21K """ @property def num_classes(self) -> int: """ Return: 
10450 """ return 10450 ================================================ FILE: training/src/datamodules/language_modeling_hf.py ================================================ # Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py from itertools import chain from pathlib import Path import pickle from typing import Any, List, Union import subprocess import mmap from multiprocessing.shared_memory import SharedMemory import numpy as np import torch from torch.utils.data.dataloader import DataLoader, Dataset from transformers import AutoTokenizer from datasets import load_dataset from pytorch_lightning import LightningDataModule from src.datamodules.datasets.lm_dataset import LMDataset from src.datamodules.fault_tolerant_sampler import RandomFaultTolerantSampler from src.datamodules.fault_tolerant_sampler import FaultTolerantDistributedSampler from src.datamodules.datasets.detokenizer import DATASET_TOKENIZATION_REGISTRY from src.utils.utils import get_logger logger = get_logger() # https://github.com/numpy/numpy/issues/18294 class SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array def __new__(cls, input_array, shm=None): obj = np.asarray(input_array).view(cls) obj.shm = shm return obj def __array_finalize__(self, obj): if obj is None: return self.shm = getattr(obj, 'shm', None) class LMDataModule(LightningDataModule): def __init__(self, dataset_name, tokenizer_name, dataset_config_name=None, max_length=1024, cache_dir=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True, detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1, shuffle=False, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False, fast_forward_epochs=None, fast_forward_batches=None, use_shmem=True): super().__init__() self.dataset_name = dataset_name self.dataset_config_name = 
dataset_config_name self.tokenizer_name = tokenizer_name self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser() self.max_length = max_length self.val_ratio = val_ratio self.val_split_seed = val_split_seed self.val_only = val_only self.add_eos = add_eos self.detokenize = detokenize self.batch_size = batch_size self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size self.num_workers = num_workers self.shuffle = shuffle self.pin_memory = pin_memory self.drop_last = drop_last if fault_tolerant: assert self.shuffle self.fault_tolerant = fault_tolerant if ddp: assert fault_tolerant self.ddp = ddp self.fast_forward_epochs = fast_forward_epochs self.fast_forward_batches = fast_forward_batches if self.fast_forward_epochs is not None or self.fast_forward_batches is not None: assert ddp and fault_tolerant self.use_shmem = use_shmem if self.use_shmem: assert cache_dir is not None def prepare_data(self): if self.cache_dir is None: # Just download the dataset load_dataset(self.dataset_name, self.dataset_config_name) else: # Process the dataset and save it self.process_dataset() def setup(self, stage=None): if stage == 'test' and hasattr(self, 'dataset_test'): return concat_ids, self.tokenizer = self.process_dataset() self.vocab_size = len(self.tokenizer) # Create all splits self.dataset_train, self.dataset_val, self.dataset_test = [ LMDataset(concat_ids[split], seq_len=self.max_length) for split in ['train', 'validation', 'test'] ] def process_dataset(self): cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name if cache_dir is not None: if cache_dir.is_dir(): return self._load_from_cache(cache_dir) raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name) # https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py if 'validation' not in raw_datasets: assert "train" in raw_datasets, "You must have train in raw_datasets to make a validation raw_datasets" 
raw_datasets = raw_datasets["train"].train_test_split( test_size=self.val_ratio, seed=self.val_split_seed, shuffle=True # Otherwise test will be at the end of the dataset ) raw_datasets['validation'] = raw_datasets['test'] if self.val_only: # Should only be used for evaluation, not for training raw_datasets['train'] = raw_datasets['validation'] # [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse # (GPT2-small val ppl after 10 epochs ~22 -> ~25) # However, it's useful for zero-shot transfer from Openwebtext, # as after detokenization it's closer to Openwebtext's format. # https://github.com/stanford-crfm/mistral/issues/12 if self.detokenize: if self.dataset_name in DATASET_TOKENIZATION_REGISTRY: detokenizer = DATASET_TOKENIZATION_REGISTRY[self.dataset_name] raw_datasets = raw_datasets.map( lambda example: {'text': detokenizer(example['text'])}, num_proc=max(self.num_workers, 1), desc='Running detokenizer on dataset' ) tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True) # Preprocessing the datasets. # First we tokenize all the texts. column_names = raw_datasets["train"].column_names text_column_name = "text" if "text" in column_names else column_names[0] # [2021-12-25] TD: For wikitext, don't need to add the EOS since each example already ends # with '\n', and there are no other '\n' in the examples. 
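The tokenization step that follows appends EOS only to non-empty texts, concatenates the ids, and later writes each chunk into one flat array at `start_idx = len_offset - len(input_ids)`, where `len_offset` is the cumulative sum of chunk lengths. Below is a pure-Python sketch of that bookkeeping; the whitespace tokenizer, the `[EOS]` token, and the helper names are hypothetical stand-ins for the Hugging Face tokenizer and its `eos_token`, and in the module the unit being offset is a batch of examples rather than a single example.

```python
from itertools import accumulate, chain
from typing import Dict, List, Tuple

EOS = "[EOS]"  # hypothetical EOS token for this toy example


def toy_tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    # Hypothetical whitespace tokenizer standing in for AutoTokenizer.
    return [vocab.setdefault(tok, len(vocab)) for tok in text.split()]


def concat_with_offsets(texts: List[str],
                        add_eos: bool = True) -> Tuple[List[int], List[int], Dict[str, int]]:
    vocab: Dict[str, int] = {}
    per_chunk = []
    for text in texts:
        # Only non-empty texts get an EOS. The space is needed only because the
        # toy tokenizer splits on whitespace.
        if add_eos and text:
            text = text + " " + EOS
        per_chunk.append(toy_tokenize(text, vocab))
    # len_offset[i] is where chunk i ends in the flat array.
    len_offset = list(accumulate(len(ids) for ids in per_chunk))
    concat = [0] * (len_offset[-1] if len_offset else 0)
    for ids, end in zip(per_chunk, len_offset):
        start = end - len(ids)
        concat[start:end] = ids  # each chunk can be written independently
    assert concat == list(chain.from_iterable(per_chunk))
    return concat, len_offset, vocab


print(concat_with_offsets(["a b", "", "b c"])[:2])
# ([0, 1, 2, 1, 3, 2], [3, 3, 6])
```

Because every chunk knows its own `[start, end)` slice, the writes need no coordination, which is what lets the module fill the shared-memory or on-disk array from several worker processes.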
# assert all([t.count('\n') == 1 for t in raw_datasets['train']['text'] if t]) # Add EOS token to the end of the text if the text is not empty # https://github.com/stanford-crfm/mistral/issues/91 # https://github.com/stanford-crfm/mistral/pull/98 if self.add_eos: add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name])) else: tokenize = lambda example: tokenizer(example[text_column_name]) # tokenized_datasets = raw_datasets.map( # tokenize, # batched=True, # num_proc=max(self.num_workers, 1), # remove_columns=column_names, # desc="Running tokenizer on dataset", # ) dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32 def tokenize_concat(examples): # We just need 'input_ids', not 'attention_mask' (since it's all 1) input_ids = np.fromiter(chain(*tokenize(examples)['input_ids']), dtype=dtype) # Need to return a list since we're doing batched processing return {'input_ids': [input_ids], 'len': [len(input_ids)]} tokenized_datasets = raw_datasets.map( tokenize_concat, batched=True, num_proc=max(self.num_workers, 1), remove_columns=column_names, desc="Running tokenizer on dataset", ) if self.use_shmem: # Concatenate all input_ids into an array in shared memory def write_ids_to_shm(example, shm_name, array_len): shm = SharedMemory(name=shm_name) shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf) start_idx = example['len_offset'] - len(example['input_ids']) shm_arr[start_idx:example['len_offset']] = example['input_ids'] shm.close() concat_ids = {} for name, ds in tokenized_datasets.items(): tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len'])) array_len = tokenized_datasets[name][-1]['len_offset'] shm = SharedMemory(create=True, size=array_len * np.dtype(dtype).itemsize) shm_name = shm.name tokenized_datasets[name].map( write_ids_to_shm, fn_kwargs={'shm_name': shm_name, 
'array_len': array_len}, batched=False, num_proc=max(self.num_workers, 1), desc="Concatenating examples", ) shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf) # We need to keep a reference to the shared memory, otherwise it gets garbage-collected # when it goes out of scope, and that memory is gone. # https://github.com/numpy/numpy/issues/18294 concat_ids[name] = SHMArray(shm_arr, shm=shm) else: # Use disk concat_ids = {} assert cache_dir is not None cache_dir.mkdir(parents=True, exist_ok=True) def write_ids_to_disk(example, filename): with open(filename, 'r+b') as f: mm = mmap.mmap(f.fileno(), 0) start_idx = example['len_offset'] - len(example['input_ids']) array_len = len(example['input_ids']) arr = np.ndarray((array_len,), dtype=dtype, buffer=mm, offset=np.dtype(dtype).itemsize * start_idx) arr[:] = example['input_ids'] mm.flush() for name, ds in tokenized_datasets.items(): tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len'])) array_len = tokenized_datasets[name][-1]['len_offset'] filename = cache_dir / f'{name}.bin' # Need to create the file with this specific size first # https://ostechnix.com/create-files-certain-size-linux/ subprocess.run(['truncate', '-s', str(array_len * np.dtype(dtype).itemsize), str(filename)], check=True) tokenized_datasets[name].map( write_ids_to_disk, fn_kwargs={'filename': filename}, batched=False, num_proc=max(self.num_workers, 1), desc="Concatenating examples", ) concat_ids[name] = np.memmap(filename, dtype=dtype, mode='r', shape=(array_len,)) if cache_dir is not None: self._save_to_cache(concat_ids, tokenizer, cache_dir) if not self.use_shmem: for name in concat_ids: Path(cache_dir / f'{name}.bin').unlink() return concat_ids, tokenizer def _save_to_cache(self, concat_ids, tokenizer, cache_dir): cache_dir.mkdir(parents=True, exist_ok=True) logger.info(f'Saving to cache at {str(cache_dir)}') for k, v in concat_ids.items(): np.save(cache_dir / f'{k}.npy', v) with open(cache_dir / 'tokenizer.pkl', 
'wb') as f: pickle.dump(tokenizer, f) def _load_from_cache(self, cache_dir): assert cache_dir.is_dir() logger.info(f'Load from cache at {str(cache_dir)}') concat_ids = {split: np.load(cache_dir / f'{split}.npy', mmap_mode='r') for split in ['train', 'validation', 'test']} with open(cache_dir / 'tokenizer.pkl', 'rb') as f: tokenizer = pickle.load(f) return concat_ids, tokenizer @property def _cache_dir_name(self): return f'tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-{self.detokenize}' def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: """ The train dataloader """ if self.shuffle and self.fault_tolerant: shuffle = False sampler = (FaultTolerantDistributedSampler(self.dataset_train) if self.ddp else RandomFaultTolerantSampler(self.dataset_train)) # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now # We assume that it's being resumed with the same number of GPUs if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None: sampler.load_state_dict({ 'epoch': self.fast_forward_epochs, 'counter': self.fast_forward_batches * self.batch_size }) else: shuffle = self.shuffle sampler = None return self._data_loader(self.dataset_train, batch_size=self.batch_size, shuffle=shuffle, sampler=sampler) def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: """ The val dataloader """ return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval) def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: """ The test dataloader """ return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval) def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False, sampler=None) -> DataLoader: return DataLoader( dataset, batch_size=batch_size, num_workers=1, # Data is already in memory, we don't need many workers 
shuffle=shuffle, sampler=sampler, drop_last=self.drop_last, pin_memory=self.pin_memory, # persistent_workers=True ) def load_state_dict(self, checkpoint): if self.fault_tolerant: self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed'] # TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration # behind, so we're using the optimizer's progress. This is set correctly in seq.py. self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] # At this point the train loader hasn't been constructed yet ================================================ FILE: training/src/datamodules/timm_mixup.py ================================================ import torch from timm.data import Mixup from timm.data.mixup import mixup_target class TimmMixup(Mixup): """ Wrap timm.data.Mixup that avoids the assert that batch size must be even. """ def __call__(self, x, target): if self.mode == 'elem': lam = self._mix_elem(x) elif self.mode == 'pair': # We move the assert from the beginning of the function to here assert len(x) % 2 == 0, 'Batch size should be even when using this' lam = self._mix_pair(x) else: lam = self._mix_batch(x) target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) return x, target ================================================ FILE: training/src/distributed/ddp_comm_hooks.py ================================================ # Adapted from https://pytorch.org/docs/stable/_modules/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.html # We divide by world_size first before converting to fp16, so it's safer. 
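The header comment says dividing by `world_size` before the fp16 conversion is safer. A pure-Python simulation makes the reason concrete: the largest finite fp16 value is 65504, so one large fp32 gradient element becomes inf if it is cast first, whereas scaling it down first keeps it representable. `to_fp16` and `fp16_allreduce_mean` are hypothetical helpers; the `struct` `'e'` (half-precision) format stands in for `torch.float16`, and the sequential fp16 summation only approximates how the real all-reduce accumulates.

```python
import math
import struct
from typing import List


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE half-precision value (inf if out of range)."""
    try:
        return struct.unpack("e", struct.pack("e", x))[0]
    except OverflowError:
        # struct refuses values that round past the largest finite fp16.
        return math.copysign(math.inf, x)


def fp16_allreduce_mean(grads: List[float], divide_first: bool) -> float:
    """Average one gradient element over ranks, compressing each rank's value to fp16."""
    world_size = len(grads)
    if divide_first:
        # This module's hook: divide in full precision, then cast.
        compressed = [to_fp16(g / world_size) for g in grads]
    else:
        # Cast first, then divide: a large value is already inf before the division.
        compressed = [to_fp16(to_fp16(g) / world_size) for g in grads]
    total = 0.0
    for value in compressed:
        total = to_fp16(total + value)  # approximate an fp16 sum across ranks
    return total


grads = [200000.0] + [1.0] * 7  # one rank holds a large gradient element
print(fp16_allreduce_mean(grads, divide_first=False))  # inf
print(fp16_allreduce_mean(grads, divide_first=True))   # finite, near the true mean 25000.875
```

Dividing first does not help when the averaged value itself exceeds the fp16 range; that case is what the TODO about an inf/NaN backoff in the hook refers to.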
from typing import Any, Callable import torch import torch.distributed as dist def fp16_compress_hook( process_group: dist.ProcessGroup, bucket: dist.GradBucket ) -> torch.futures.Future[torch.Tensor]: """ This DDP communication hook implements a simple gradient compression approach that divides the ``GradBucket`` tensor by the process group size and then casts it to half-precision floating-point format (``torch.float16``). It allreduces those ``float16`` gradient tensors. Once the compressed gradient tensors are allreduced, the chained callback ``decompress`` casts them back to the input data type (such as ``float32``). Example:: >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook) """ group_to_use = process_group if process_group is not None else dist.group.WORLD world_size = group_to_use.size() # Divide before converting to fp16 # Use out argument to fuse the division and the conversion. compressed_tensor = torch.div(bucket.buffer(), world_size, out=torch.empty_like(bucket.buffer(), dtype=torch.float16)) fut = dist.all_reduce( compressed_tensor, group=group_to_use, async_op=True ).get_future() def decompress(fut): decompressed_tensor = bucket.buffer() # Decompress in place to reduce the peak memory. # See: https://github.com/pytorch/pytorch/issues/45968 decompressed_tensor.copy_(fut.value()[0]) return decompressed_tensor # TODO: maybe have a backoff strategy: check if the buffer has inf / NaN, in that case # resend with fp32?
    return fut.then(decompress)
================================================
FILE: training/src/eval.py
================================================
from typing import List, Optional
from pathlib import Path

import torch

import hydra
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase

from src.utils import utils

log = utils.get_logger(__name__)


def remove_prefix(text: str, prefix: str):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text  # or whatever


def load_checkpoint(path, device='cpu'):
    path = Path(path).expanduser()
    if path.is_dir():
        path /= 'last.ckpt'
    # dst = f'cuda:{torch.cuda.current_device()}'
    log.info(f'Loading checkpoint from {str(path)}')
    state_dict = torch.load(path, map_location=device)
    # T2T-ViT checkpoint is nested in the key 'state_dict_ema'
    if state_dict.keys() == {'state_dict_ema'}:
        state_dict = state_dict['state_dict_ema']
    # Swin checkpoint is nested in the key 'model'
    if state_dict.keys() == {'model'}:
        state_dict = state_dict['model']
    # Lightning checkpoint contains extra stuff, we only want the model state dict
    if 'pytorch-lightning_version' in state_dict:
        state_dict = {remove_prefix(k, 'model.'): v for k, v in state_dict['state_dict'].items()}
    return state_dict


def evaluate(config: DictConfig) -> None:
    """Evaluate a trained model.
    It loads the trained model from a checkpoint, then runs validation and/or testing
    on the configured datamodule.
    """

    # load model from checkpoint
    # model __init__ parameters will be loaded from ckpt automatically
    # you can also pass some parameter explicitly to override it

    # We want to add fields to config so need to call OmegaConf.set_struct
    OmegaConf.set_struct(config, False)

    # load model
    checkpoint_type = config.eval.get('checkpoint_type', 'pytorch')
    if checkpoint_type not in ['lightning', 'pytorch']:
        raise NotImplementedError(f'checkpoint_type {checkpoint_type} not supported')

    if checkpoint_type == 'lightning':
        cls = hydra.utils.get_class(config.task._target_)
        model = cls.load_from_checkpoint(checkpoint_path=config.eval.ckpt)
    elif checkpoint_type == 'pytorch':
        model_cfg = config.model_pretrained if 'model_pretrained' in config else None
        trained_model: LightningModule = hydra.utils.instantiate(config.task, cfg=config,
                                                                 model_cfg=model_cfg,
                                                                 _recursive_=False)
        if 'ckpt' in config.eval:
            load_return = trained_model.model.load_state_dict(
                load_checkpoint(config.eval.ckpt, device=trained_model.device), strict=False
            )
            log.info(load_return)
        if 'model_pretrained' in config:
            ...
        else:
            model = trained_model

    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
    # datamodule: LightningDataModule = model._datamodule
    datamodule.prepare_data()
    datamodule.setup()

    # print model hyperparameters
    log.info(f'Model hyperparameters: {model.hparams}')

    # Init Lightning callbacks
    callbacks: List[Callback] = []
    if "callbacks" in config:
        for _, cb_conf in config["callbacks"].items():
            if cb_conf is not None and "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))

    # Init Lightning loggers
    logger: List[LightningLoggerBase] = []
    if "logger" in config:
        for _, lg_conf in config["logger"].items():
            if lg_conf is not None and "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))

    # Init Lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )

    # Evaluate the model
    log.info("Starting evaluation!")
    if config.eval.get('run_val', True):
        trainer.validate(model=model, datamodule=datamodule)
    if config.eval.get('run_test', True):
        trainer.test(model=model, datamodule=datamodule)

    # Make sure everything closed properly
    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )
================================================
FILE: training/src/metrics/accuracy.py
================================================
import torch
from torch import Tensor

from torchmetrics import Metric, Accuracy


class AccuracyMine(Accuracy):
    """Wrap torchmetrics.Accuracy to take argmax of y in case of Mixup.
    """
    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        super().update(preds, target.argmax(dim=-1) if target.is_floating_point() else target)
================================================
FILE: training/src/metrics/num_tokens.py
================================================
from typing import Any, Dict, Optional

import torch
from torch import Tensor

from torchmetrics import Metric


class NumTokens(Metric):
    """Keep track of how many tokens we've seen.
    """
    # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch
    # of the next epoch.
    # Right now the hack is that we override reset(), which would mess up the forward method.
    # We then override forward to do the right thing.

    is_differentiable = False
    higher_is_better = False
    full_state_update = False
    count: Tensor

    def __init__(self, **kwargs: Dict[str, Any]):
        super().__init__(**kwargs)
        self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum",
                       persistent=True)  # We want the count to be saved to state-dict

    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore
        self.count += target.numel()

    def compute(self) -> Tensor:
        return self.count

    def reset(self):
        count = self.count
        super().reset()
        self.count = count

    # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py
    def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
        """forward computation using single call to `update` to calculate the metric value on the current batch and
        accumulate global state.
        This can be done when the global metric state is a simple reduction of batch states.
        """
        self.update(*args, **kwargs)
        return self.compute()
================================================
FILE: training/src/metrics/perplexity.py
================================================
# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py
# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll))
# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py
# But we pass in the loss to avoid recomputation

from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torchmetrics import Metric

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = torch.nn.CrossEntropyLoss

__all__ = ['Perplexity']


class Perplexity(Metric):
    r"""
    Perplexity measures how well a language model predicts a text sample. It is computed as the
    exponential of the average per-token negative log-likelihood (cross-entropy), i.e.
    exp(average(nll)).

    Args:
        kwargs:
            Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Examples:
        >>> import torch
        >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
        >>> target[0, 6:] = -100
        >>> metric = Perplexity(ignore_index=-100)
        >>> metric(preds, target)
        tensor(5.2545)
    """
    is_differentiable = True
    higher_is_better = False
    full_state_update = False
    total_log_probs: Tensor
    count: Tensor

    def __init__(self, **kwargs: Dict[str, Any]):
        super().__init__(**kwargs)
        self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64),
                       dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")

        self.loss_fn = CrossEntropyLoss()

    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore
        """Compute and store intermediate statistics for Perplexity.

        Args:
            preds: Logits (unnormalized scores) for each token, as expected by the cross-entropy
                loss, e.g. with shape [batch_size * seq_len, vocab_size]. Only used when ``loss``
                is None.
            target: Ground truth token indices, e.g. with shape [batch_size * seq_len].
            loss: Optional precomputed mean cross-entropy loss for this batch, passed in to avoid
                recomputation.
        """
        count = target.numel()
        if loss is None:
            loss = self.loss_fn(preds, target)
        self.total_log_probs += loss.double() * count
        self.count += count

    def compute(self) -> Tensor:
        """Compute the Perplexity.

        Returns:
           Perplexity
        """
        return torch.exp(self.total_log_probs / self.count)
================================================
FILE: training/src/models/modules/seq_common.py
================================================
import math
from functools import partial
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

import hydra

from einops import reduce, rearrange


def pooling(x, pooling_mode='CLS', key_padding_mask=None, batch_first=True):
    if pooling_mode not in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN']:
        raise NotImplementedError(f'pooling_mode must be MEAN, SUM, CLS, LAST, FLATTEN')
    if pooling_mode in ['MEAN', 'SUM']:
        if key_padding_mask is not None:
            mask = rearrange(~key_padding_mask.bool_matrix,
                             'b s -> b s 1' if batch_first else 'b s -> s b 1')
            x = x.masked_fill(mask, 0)
        s = reduce(x, 'b s ... -> b ...' if batch_first else 's b ... -> b ...', 'sum')
        if pooling_mode == 'SUM':
            return s
        else:
            if key_padding_mask is None:
                return s / x.shape[1 if batch_first else 0]
            else:
                lengths = rearrange(key_padding_mask._lengths, 'b -> b 1')
                return s / lengths
    elif pooling_mode == 'CLS':
        return x[:, 0] if batch_first else x[0]
    elif pooling_mode == 'LAST':
        if key_padding_mask is None:
            return x[:, -1] if batch_first else x[-1]
        else:
            lengths = key_padding_mask._lengths
            if batch_first:
                batch_size = x.shape[0]
                return x[torch.arange(batch_size, device=x.device), lengths - 1]
            else:
                batch_size = x.shape[1]
                return x[lengths - 1, torch.arange(batch_size, device=x.device)]
    elif pooling_mode == 'FLATTEN':
        return rearrange(x, 'b ... -> b (...)' if batch_first else 's b ... -> b (s ...)')


class ClassificationHeadLinear(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, d_model, num_classes, pooling_mode='MEAN',
                 batch_first=False, **kwargs):
        super().__init__()
        assert pooling_mode in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'], 'pooling_mode not supported'
        self.pooling_mode = pooling_mode
        self.batch_first = batch_first
        self.out_proj = nn.Linear(d_model, num_classes)

    def forward(self, hidden_states, key_padding_mask=None, **kwargs):
        """
        hidden_states: (B, S, D) if batch_first else (S, B, D)
        """
        hidden_states = pooling(hidden_states, pooling_mode=self.pooling_mode,
                                key_padding_mask=key_padding_mask, batch_first=self.batch_first)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/models/reformer/modeling_reformer.py
class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN',
                 batch_first=False):
        super().__init__()
        assert pooling_mode in ['MEAN', 'SUM', 'CLS', 'LAST', 'FLATTEN'], 'pooling_mode not supported'
        self.pooling_mode = pooling_mode
        self.batch_first = batch_first
        self.dense = nn.Linear(d_model, d_inner)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_inner, num_classes)

    def forward(self, hidden_states, key_padding_mask=None, **kwargs):
        """
        hidden_states: (B, S, D) if batch_first else (S, B, D)
        """
        hidden_states = pooling(hidden_states, pooling_mode=self.pooling_mode,
                                key_padding_mask=key_padding_mask, batch_first=self.batch_first)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        # Huggingface uses tanh instead of relu
        hidden_states = torch.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


class ClassificationHeadDual(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling_mode='MEAN',
                 batch_first=False, interaction='NLI'):
        super().__init__()
        assert pooling_mode in ['MEAN', 'SUM', 'CLS'], 'pooling_mode not supported'
        assert interaction in [None, 'NLI'], 'interaction not supported'
        self.pooling_mode = pooling_mode
        self.batch_first = batch_first
        self.interaction = interaction
        self.dense = nn.Linear(d_model * (4 if self.interaction == 'NLI' else 2), d_inner)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_inner, num_classes)

    def forward(self, hidden_states1, hidden_states2,
                key_padding_mask1=None, key_padding_mask2=None, **kwargs):
        """
        hidden_states: (B, S, D) if batch_first else (S, B, D)
        """
        x1 = pooling(hidden_states1, pooling_mode=self.pooling_mode,
                     key_padding_mask=key_padding_mask1, batch_first=self.batch_first)
        x2 = pooling(hidden_states2, pooling_mode=self.pooling_mode,
                     key_padding_mask=key_padding_mask2, batch_first=self.batch_first)
        hidden_states = (torch.cat([x1, x2, x1 * x2, x1 - x2], dim=-1) if self.interaction == 'NLI'
                         else torch.cat([x1, x2], dim=-1))
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        # Huggingface uses tanh instead of relu
        hidden_states = torch.relu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


class LMHead(nn.Module):

    def __init__(self, d_model, num_classes, batch_first=True, bias=True):
        super().__init__()
        self.lm_head = nn.Linear(d_model, num_classes, bias=bias)

    def forward(self, hidden_states, **kwargs):
        """
        hidden_states: (B, S, D) if batch_first else (S, B, D)
        """
        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
        return CausalLMOutput(self.lm_head(hidden_states))


def sinusoidal_init_(tensor):
    """
        tensor: (max_len, d_model)
    """
    max_len, d_model = tensor.shape
    position = rearrange(torch.arange(0.0, max_len), 's -> s 1')
    div_term = torch.exp(-math.log(10000.0) * torch.arange(0.0, d_model, 2.0) / d_model)
    tensor[:, 0::2] = torch.sin(position * div_term)
    tensor[:, 1::2] = torch.cos(position * div_term)
    return tensor


# Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
        in the sequence. The positional encodings have the same dimension as
        the embeddings, so that the two can be summed. Here, we use sine and cosine
        functions of different frequencies.
    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})
        \text{where } pos \text{ is the word position and } i \text{ is the embed idx}
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False, initializer=None):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.empty(max_len, d_model)
        if initializer is None:
            sinusoidal_init_(pe)
            pe = rearrange(pe, 's d -> 1 s d' if self.batch_first else 's d -> s 1 d')
            self.register_buffer('pe', pe)
        else:
            hydra.utils.call(initializer, pe)
            pe = rearrange(pe, 's d -> 1 s d' if self.batch_first else 's d -> s 1 d')
            self.pe = nn.Parameter(pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim] if not batch_first else [B, S, D]
            output: [sequence length, batch size, embed dim] if not batch_first else [B, S, D]
        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + (self.pe[:, :x.size(1)] if self.batch_first else self.pe[:x.size(0)])
        return self.dropout(x)


# Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 act_fn=None, drop=0., device=None, dtype=None):
        """TD [2021-10-27] act_fn takes precedence over act_layer if set.
        This is to support Pytorch 1.10 Transformer interface that constructs the activation
        *function*, not the activation *layer*.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = _pair(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
        self.act = act_layer() if act_fn is None else act_fn
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class MlpBig(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 act_fn=None, drop=0., device=None, dtype=None):
        """Copied from Mlp above, but stacks 4 hidden Linear layers (each followed by the
        activation and dropout), doubling the hidden width each time, then projects to
        out_features.
        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        cur_hidden_features = hidden_features
        layers = []
        for _ in range(4):
            layers.append(nn.Linear(in_features, cur_hidden_features, **factory_kwargs))
            layers.append(act_layer())
            layers.append(nn.Dropout(drop))
            in_features = cur_hidden_features
            cur_hidden_features *= 2
        layers.append(nn.Linear(in_features, out_features, **factory_kwargs))
        layers.append(nn.Dropout(drop))
        self.fwd = nn.Sequential(*layers)

    def forward(self, x):
        return self.fwd(x)


class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert hidden_features % 2 == 0
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features // 2, out_features)
        self.drop = nn.Dropout(drop)

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        fc1_mid = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[fc1_mid:])
        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x, gates = x.chunk(2, dim=-1)
        x = x * self.act(gates)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class GatedMlp(nn.Module):
    """ MLP as used in gMLP
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 gate_layer=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
        else:
            self.gate = nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.gate(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims
    """
    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
            norm_layer=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
================================================
FILE: training/src/optim/param_grouping.py
================================================
import inspect

import torch.nn as nn

import hydra

try:
    from apex.contrib.layer_norm import FastLayerNorm
except ImportError:
    FastLayerNorm = None

from src.models.modules.seq_common import PositionalEncoding


def group_parameters_for_optimizer(model, optimizer_cfg, bias_weight_decay=False,
                                   normalization_weight_decay=False):
    """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with
    attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for
    normalization parameters if normalization_weight_decay==False
    """
    # Get the weight decay from the config, or from the default value of the optimizer constructor
    # if it's not specified in the config.
    if 'weight_decay' in optimizer_cfg:
        weight_decay = optimizer_cfg.weight_decay
    else:
        # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value
        signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_))
        if 'weight_decay' in signature.parameters:
            weight_decay = signature.parameters['weight_decay'].default
            if weight_decay is inspect.Parameter.empty:
                weight_decay = 0.0
        else:
            weight_decay = 0.0

    # If none of the parameters have weight decay anyway, and there are no parameters with special
    # optimization params
    if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()):
        return model.parameters()

    skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set()
    skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords')
                     else set())

    # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134
    """
    This long function is unfortunately doing something very simple and is being very defensive:
    We are separating out all parameters of the model into two buckets: those that will experience
    weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
    We then return the parameter groups to pass to the PyTorch optimizer.
    """

    # separate out all parameters to those that will and won't experience regularizing weight decay
    decay = set()
    no_decay = set()
    special = set()
    whitelist_weight_modules = (nn.Linear, )
    blacklist_weight_modules = (nn.Embedding, PositionalEncoding)
    if not normalization_weight_decay:
        blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                                     nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d,
                                     nn.GroupNorm, nn.SyncBatchNorm,
                                     nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
                                     nn.LayerNorm, nn.LocalResponseNorm)
        if FastLayerNorm is not None:
            blacklist_weight_modules += (FastLayerNorm,)

    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            # In case of parameter sharing, some parameters show up here but are not in
            # param_dict.keys()
            if not p.requires_grad or fpn not in param_dict:
                continue  # frozen weights
            if hasattr(p, '_optim'):
                special.add(fpn)
            elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords):
                no_decay.add(fpn)
            elif getattr(p, '_no_weight_decay', False):
                no_decay.add(fpn)
            elif not bias_weight_decay and pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)

    decay |= (param_dict.keys() - no_decay - special)
    # validate that we considered every parameter
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both decay/no_decay sets!"
    assert len(param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - special - union_params)} were not separated into either decay/no_decay set!"
    if weight_decay == 0.0 or not no_decay:
        param_groups = [{"params": [param_dict[pn] for pn in sorted(list(no_decay | decay))],
                         "weight_decay": weight_decay}]
    else:
        # We need sorted(list()) so that the order is deterministic. Otherwise when we resume
        # the order could change and resume will fail. [H/t Albert]
        param_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
    # Add parameters with special hyperparameters
    # Unique dicts
    hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)]
    for hp in hps:
        params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp]
        param_groups.append({"params": params, **hp})

    return param_groups
================================================
FILE: training/src/optim/timm_lr_scheduler.py
================================================
import torch
from torch.optim import Optimizer

from timm.scheduler import CosineLRScheduler


# We need to subclass torch.optim.lr_scheduler._LRScheduler, or Pytorch-lightning will complain
class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler):
    """ Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch.
    It supports resuming as well.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        if epoch is None:
            self._last_epoch += 1
        else:
            self._last_epoch = epoch
        # We call either step or step_update, depending on whether we're using the scheduler every
        # epoch or every step.
        # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set
        # scheduler interval to "step", then the learning rate update will be wrong.
        if self.t_in_epochs:
            super().step(epoch=self._last_epoch)
        else:
            super().step_update(num_updates=self._last_epoch)
================================================
FILE: training/src/tasks/seq.py
================================================
from typing import Any, List
import inspect

import torch
import hydra
from pytorch_lightning import LightningModule, LightningDataModule
from torchmetrics import MetricCollection

from einops import rearrange

from omegaconf import OmegaConf

from src.utils.utils import get_logger
from src.optim.param_grouping import group_parameters_for_optimizer
from src.utils.checkpoint import load_checkpoint

logger = get_logger(__name__)


class SequenceModel(LightningModule):

    def __init__(self, cfg, model_cfg=None):
        """If model_cfg is passed, it will take precedence over cfg.model
        """
        super().__init__()
        # this line ensures params passed to LightningModule will be saved to ckpt
        # it also allows to access params with 'self.hparams' attribute
        self.save_hyperparameters(cfg)
        self.cfg = cfg
        self.model_cfg = model_cfg or self.cfg.model

        self.instantiate_datamodule()
        self.instantiate_model()
        self.warmstart()
        self.instantiate_loss()
        self.instantiate_metrics()

    def instantiate_datamodule(self):
        logger.info(f"Instantiating datamodule <{self.cfg.datamodule._target_}>")
        # Calling this self.datamodule will mess with PL since it also assigns self.datamodule
        self._datamodule: LightningDataModule = hydra.utils.instantiate(self.cfg.datamodule)
        self._datamodule.prepare_data()
        self._datamodule.setup()
        OmegaConf.clear_resolver('datamodule')
        OmegaConf.register_new_resolver('datamodule', lambda attr: getattr(self._datamodule, attr))

    def instantiate_model(self):
        # if hasattr(self._datamodule, 'num_classes'):
        #     self.model_cfg.num_classes = self._datamodule.num_classes
        # if (hasattr(self._datamodule, 'vocab_size')
        #     and self.model_cfg.get('embedding_cfg', None) is not None
        #     and self.model_cfg.embedding_cfg._target_ == "torch.nn.Embedding"):
        #     self.model_cfg.embedding_cfg.num_embeddings = self._datamodule.vocab_size
        logger.info(f"Instantiating model <{self.model_cfg._target_}>")
        recursive = getattr(self.model_cfg, '_recursive_', False)
        self.model = hydra.utils.instantiate(self.model_cfg, _recursive_=recursive)

    def instantiate_loss(self):
        loss_fn_cfg = self.cfg.train.get('loss_fn')
        if loss_fn_cfg is None:
            loss_fn_cfg = {'_target_': 'torch.nn.CrossEntropyLoss'}
        self.loss_fn = hydra.utils.instantiate(loss_fn_cfg)
        loss_fn_val_cfg = self.cfg.train.get('loss_fn_val', loss_fn_cfg)
        self.loss_fn_val = hydra.utils.instantiate(loss_fn_val_cfg)

    def instantiate_metrics(self):
        # use separate metric instance for train, val and test step
        # to ensure a proper reduction over the epoch
        if 'eval' in self.cfg and 'metrics' in self.cfg.eval:
            metrics_cfg = self.cfg.eval.metrics
        else:
            metrics_cfg = {'acc': {'_target_': 'torchmetrics.Accuracy'}}
        metrics = MetricCollection({name: hydra.utils.instantiate(cfg)
                                    for name, cfg in metrics_cfg.items()})
        self.train_metrics = metrics.clone(prefix='train/')
        self.val_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')

    def warmstart(self):
        if self.cfg.train.get('warmstart', None) is not None:
            logger.info(f"Warm-starting with weights from {self.cfg.train.warmstart.path}")
            strict = self.cfg.train.warmstart.get('strict', True)
            state_dict = load_checkpoint(self.cfg.train.warmstart.path)
            if self.cfg.train.warmstart.get('post_process', None) is not None:
                state_dict = hydra.utils.instantiate(self.cfg.train.warmstart.post_process,
                                                     state_dict)
            load_return = self.model.load_state_dict(state_dict, strict=False)
            logger.info(load_return)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def step(self, batch: Any, is_train=True):
        try:
            x, y, lengths = batch
        except ValueError:
            x, y = batch
            lengths = None
        output = self.forward(x) if lengths is None else self.forward(x, lengths=lengths)
        loss = self.loss_fn(output, y) if is_train else self.loss_fn_val(output, y)
        return loss, output, y

    def shared_step(self, batch: Any, batch_idx: int, phase='train'):
        loss, output, targets = self.step(batch, is_train=(phase == 'train'))
        metrics = getattr(self, f'{phase}_metrics')
        metrics(output, targets)
        log_on_step = 'eval' in self.cfg and self.cfg.eval.get('log_on_step', False) and phase == 'train'
        self.log(f"{phase}/loss", loss, on_step=log_on_step, on_epoch=True,
                 prog_bar=False, sync_dist=True)
        # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training
        # We need to log the Metrics object, not the metric result, since otherwise
        # pytorch-lightning will use torch.mean to reduce it.
        # This would be wrong for perplexity, for example.
        self.log_dict(metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, sync_dist=True)
        return {"loss": loss, "output": output, "targets": targets}

    def training_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='train')

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='val')

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase='test')

    def configure_optimizers(self):
        if 'optimizer_param_grouping' in self.cfg.train:  # Set zero weight decay for some params
            parameters = group_parameters_for_optimizer(self.model, self.cfg.train.optimizer,
                                                        **self.cfg.train.optimizer_param_grouping)
        else:
            # parameters = self.model.parameters()
            parameters = self.parameters()  # [21-09-08] AG: this will train task specific parameters such as Retrieval head for AAN
        optimizer = hydra.utils.instantiate(self.cfg.train.optimizer, parameters)

        # Log optimizer info
        for i, g in enumerate(optimizer.param_groups):
            ntensors = len(g['params'])
            nparams = sum(p.numel() for p in g['params'])
            hparams = {k: v for k, v in g.items() if k != 'params'}
            logger.info(f'Optimizer group {i}: {ntensors} tensors, {nparams} parameters, {hparams}')

        if 'scheduler' not in self.cfg.train:
            return optimizer
        else:
            # lr_scheduler should be called either every step (default) or every epoch
            lr_scheduler = hydra.utils.instantiate(self.cfg.train.scheduler, optimizer)
            return [optimizer], {'scheduler': lr_scheduler,
                                 'interval': self.cfg.train.get('scheduler_interval', 'step'),
                                 'monitor': self.cfg.train.get('scheduler_monitor', 'val/loss')}

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        # https://pytorch-lightning.readthedocs.io/en/latest/guides/speed.html#set-grads-to-none
        # TD [2022-04-30]: DeepSpeed optimizer uses the kwarg set_grad_to_none instead of set_to_none
        if 'set_to_none' in inspect.signature(optimizer.zero_grad).parameters:
            optimizer.zero_grad(set_to_none=True)
        else:
            optimizer.zero_grad()

    def on_save_checkpoint(self, checkpoint):
        # TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
        # behind, so we're using the optimizer's progress.
        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['total']['completed'] * self.trainer.accumulate_grad_batches
        checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['current']['completed'] * self.trainer.accumulate_grad_batches
        # _batches_that_stepped tracks the number of global steps, not the number
        # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'] = checkpoint['loops']['fit_loop']['epoch_loop.batch_loop.optimizer_loop.optim_progress']['optimizer']['step']['total']['completed'] class SequenceLMModel(SequenceModel): def step(self, batch: Any, is_train=True): x, y = batch output = self.forward(x).logits output = rearrange(output, '... C -> (...) C') y = rearrange(y, '... -> (...)') loss = self.loss_fn(output, y) if is_train else self.loss_fn_val(output, y) return loss, output, y def shared_step(self, batch: Any, batch_idx: int, phase='train'): loss, output, targets = self.step(batch, is_train=(phase == 'train')) # Passing the loss to the perplexity metrics to avoid recomputation metrics = getattr(self, f'{phase}_metrics') metrics(output, targets, loss=loss) log_on_step = 'eval' in self.cfg and self.cfg.eval.get('log_on_step', False) and phase == 'train' self.log(f"{phase}/loss", loss, on_step=log_on_step, on_epoch=True, prog_bar=False, sync_dist=True) # https://pytorch-lightning.readthedocs.io/en/stable/visualize/logging_advanced.html#enable-metrics-for-distributed-training # We need to log the Metrics object, not the metric result, since otherwise # pytorch-lightning will use torch.mean to reduce it. # This would be wrong for perplexity, for example. 
self.log_dict(metrics, on_step=log_on_step, on_epoch=True, prog_bar=True, sync_dist=True) return {"loss": loss, "output": output, "targets": targets} ================================================ FILE: training/src/train.py ================================================ from typing import List, Optional, Sequence from pathlib import Path import hydra from omegaconf import OmegaConf, DictConfig from pytorch_lightning import ( Callback, LightningDataModule, LightningModule, Trainer, seed_everything, ) from pytorch_lightning.loggers import LightningLoggerBase from src.utils import utils log = utils.get_logger(__name__) def last_modification_time(path): """Return the latest modification time of path: the file's own mtime, or, for a directory, the latest mtime among its direct children (one level below). Returns None if path does not exist. """ path = Path(path) if path.is_file(): return path.stat().st_mtime elif path.is_dir(): return max(child.stat().st_mtime for child in path.iterdir()) else: return None def train(config: DictConfig) -> Optional[float]: """Contains training pipeline. Instantiates all PyTorch Lightning objects from config. Args: config (DictConfig): Configuration composed by Hydra. Returns: Optional[float]: Metric score for hyperparameter optimization.
""" # Set seed for random number generators in pytorch, numpy and python.random if config.get("seed"): seed_everything(config.seed, workers=True) # We want to add fields to config, so we need to call OmegaConf.set_struct OmegaConf.set_struct(config, False) # Init lightning model model: LightningModule = hydra.utils.instantiate(config.task, cfg=config, _recursive_=False) datamodule: LightningDataModule = model._datamodule # Init lightning callbacks callbacks: List[Callback] = [] if "callbacks" in config: for _, cb_conf in config.callbacks.items(): if cb_conf is not None and "_target_" in cb_conf: log.info(f"Instantiating callback <{cb_conf._target_}>") callbacks.append(hydra.utils.instantiate(cb_conf)) # Init lightning loggers logger: List[LightningLoggerBase] = [] if "logger" in config: for _, lg_conf in config.logger.items(): if lg_conf is not None and "_target_" in lg_conf: log.info(f"Instantiating logger <{lg_conf._target_}>") logger.append(hydra.utils.instantiate(lg_conf)) ckpt_cfg = {} if config.get('resume'): try: checkpoint_path = Path(config.callbacks.model_checkpoint.dirpath) if checkpoint_path.is_dir(): last_ckpt = checkpoint_path / 'last.ckpt' autosave_ckpt = checkpoint_path / '.pl_auto_save.ckpt' if not (last_ckpt.exists() or autosave_ckpt.exists()): raise FileNotFoundError("Resume requires either last.ckpt or .pl_auto_save.ckpt") if ((not last_ckpt.exists()) or (autosave_ckpt.exists() and last_modification_time(autosave_ckpt) > last_modification_time(last_ckpt))): # autosave_ckpt = autosave_ckpt.replace(autosave_ckpt.with_name('.pl_auto_save_loaded.ckpt')) checkpoint_path = autosave_ckpt else: checkpoint_path = last_ckpt # DeepSpeed's checkpoint is a directory, not a file if checkpoint_path.is_file() or checkpoint_path.is_dir(): ckpt_cfg = {'ckpt_path': str(checkpoint_path)} else: log.info(f'Checkpoint file {str(checkpoint_path)} not found.
Will start training from scratch') except (KeyError, FileNotFoundError): pass # Configure ddp automatically n_devices = config.trainer.get('devices', 1) if isinstance(n_devices, Sequence): # trainer.devices could be [1, 3] for example n_devices = len(n_devices) if n_devices > 1 and config.trainer.get('strategy', None) is None: config.trainer.strategy = dict( _target_='pytorch_lightning.strategies.DDPStrategy', find_unused_parameters=False, gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations ) # Init lightning trainer log.info(f"Instantiating trainer <{config.trainer._target_}>") trainer: Trainer = hydra.utils.instantiate( config.trainer, callbacks=callbacks, logger=logger) # Train the model log.info("Starting training!") trainer.fit(model=model, datamodule=datamodule, **ckpt_cfg) # Evaluate model on test set, using the best model achieved during training if config.get("test_after_training") and not config.trainer.get("fast_dev_run"): log.info("Starting testing!") trainer.test(model=model, datamodule=datamodule) # Make sure everything closed properly log.info("Finalizing!") utils.finish( config=config, model=model, datamodule=datamodule, trainer=trainer, callbacks=callbacks, logger=logger, ) # Print path to best checkpoint if not config.trainer.get("fast_dev_run"): log.info(f"Best model ckpt: {trainer.checkpoint_callback.best_model_path}") # Return metric score for hyperparameter optimization optimized_metric = config.get("optimized_metric") if optimized_metric: return trainer.callback_metrics[optimized_metric] ================================================ FILE: training/src/utils/checkpoint.py ================================================ import re from pathlib import Path import torch import math from einops import rearrange def load_checkpoint(path, device='cpu'): path = Path(path).expanduser() is_deepspeed = False if path.is_dir(): # DeepSpeed checkpoint is_deepspeed = True 
latest_path = path / 'latest' if latest_path.is_file(): with open(latest_path, 'r') as fd: tag = fd.read().strip() else: raise ValueError(f"Unable to find 'latest' file at {latest_path}") path /= f'{tag}/mp_rank_00_model_states.pt' state_dict = torch.load(path, map_location=device) if is_deepspeed: state_dict = state_dict['module'] # Replace the names of some of the submodules def key_mapping(key): return re.sub(r'^module.model.', '', key) state_dict = {key_mapping(k): v for k, v in state_dict.items()} return state_dict def blockdiag_to_dense_mlp_bert(state_dict): from src.ops.blockdiag_multiply import blockdiag_weight_to_dense_weight names = {name for name in state_dict if re.match(r'bert.encoder.layer.(\d+).(mlp.fc(1|2)|(intermediate|output).dense).weight', name)} for name in names: state_dict[name] = blockdiag_weight_to_dense_weight(state_dict[name]) return state_dict def interpolate_pos_embedding(state_dict, out_seqlen, pos_embedding_name='model.pos_encoder.pe', interleave=False): orig_emb = state_dict['state_dict'][pos_embedding_name] assert (out_seqlen % orig_emb.shape[1]) == 0, 'out_seqlen must be a multiple of the original sequence length' reps = [1 for i in orig_emb.shape] reps[1] = out_seqlen // orig_emb.shape[1] if interleave: assert math.isqrt(orig_emb.shape[1]) ** 2 == orig_emb.shape[1], 'interleave only works for square lengths' assert math.isqrt(out_seqlen) ** 2 == out_seqlen, 'interleave only works for square lengths' assert math.isqrt(reps[1]) ** 2 == reps[1], 'out_seqlen / seqlen must be a perfect square' emb_square = rearrange(orig_emb, 'b (h w) d -> b h w d', h = math.isqrt(orig_emb.shape[1])) emb_square_expanded = emb_square.repeat_interleave(math.isqrt(reps[1]), axis=1).repeat_interleave(math.isqrt(reps[1]), axis=2) new_emb = rearrange(emb_square_expanded, 'b h w d -> b (h w) d') state_dict['state_dict'][pos_embedding_name] = new_emb else: state_dict['state_dict'][pos_embedding_name] = orig_emb.repeat(*reps) ret =
remove_model_prefix(state_dict) # # HACK: this is a hack for block-sparse flash attention ret = { k: v for k, v in ret.items() if not k.endswith('inner_attn.layout') } return ret def remove_model_prefix(state_dict): # HACK: this is a hack to get the model to load properly, get rid of 'model.' prefix for key in list(state_dict['state_dict'].keys()): if key.startswith('model.'): new_key = key[len('model.'):] state_dict['state_dict'][new_key] = state_dict['state_dict'].pop(key) # HACK: something is wrong with the state dict being loaded... return state_dict['state_dict'] ================================================ FILE: training/src/utils/ddp_zero1.py ================================================ # Meant to work with Pytorch's ZeroRedundancyOptimizer from typing import Any, Callable, Dict, List, Optional, Union from pathlib import Path import torch from torch.optim.optimizer import Optimizer from torch.distributed.optim import ZeroRedundancyOptimizer from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.core.optimizer import LightningOptimizer try: # pytorch_lightning <= 1.7 from pytorch_lightning.utilities.types import _PATH except ImportError: # pytorch_lightning >= 1.8 try: from lightning_lite.utilities.types import _PATH except ImportError: # pytorch_lightning >= 1.9 from lightning_fabric.utilities.types import _PATH # Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get # the local state dict to avoid synchronization across GPUs. 
# https://github.com/pytorch/pytorch/blob/0c7ca2d97ba5980a2af7dcd6b8106dc915e591cd/torch/distributed/optim/zero_redundancy_optimizer.py#L1131 def get_zero_optimizer_state_dict_local(optimizer, global_rank): optimizer._check_overlap_initialized() # Sync the exposed `param_groups` attributes to the local optimizer in # case they have been updated optimizer._sync_param_groups(optimizer.param_groups, optimizer.optim.param_groups) local_state_dict = optimizer.optim.state_dict() state_dict = super(ZeroRedundancyOptimizer, optimizer).state_dict() # Update the global optimizer state with local state information, # factoring in the translation from local to global indexing rank = global_rank # TODO: recursive copy to device local_param_groups = local_state_dict["param_groups"] global_param_groups = optimizer._partition_parameters()[rank] assert len(local_param_groups) == len(global_param_groups), \ "Mismatch between number of local and global parameter groups" for local_param_group, global_param_group in zip(local_param_groups, global_param_groups): # `local_param_group` stores local indices, while # `global_param_group` stores the tensors directly local_param_indices = local_param_group["params"] global_params = global_param_group["params"] assert len(local_param_indices) == len(global_params), \ "Mismatch between number of local and global parameters in parameter group" for local_param_index, global_param in zip(local_param_indices, global_params): # Update the global parameter state, if any if local_param_index in local_state_dict["state"]: global_param_index = optimizer._param_to_index[global_param] state_dict["state"][global_param_index] = local_state_dict["state"][local_param_index] # Sort the parameters in the state state_dict["state"] = dict(sorted(state_dict["state"].items())) return state_dict class DDPStrategyZero1(DDPStrategy): """To use ZeroRedundancyOptimizer, we need to shard the optimizer states when saving/loading checkpoints. 
""" strategy_name = "ddp_zero1" def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]: if isinstance(optimizer, LightningOptimizer): optimizer = optimizer._optimizer if isinstance(optimizer, ZeroRedundancyOptimizer): return get_zero_optimizer_state_dict_local(optimizer, self.global_rank) else: return optimizer.state_dict() def save_checkpoint( self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: checkpoint: dict containing model and trainer state filepath: write-target file's path storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin """ filepath = Path(filepath) filepath.mkdir(parents=True, exist_ok=True) local_optimizer_states = checkpoint.pop('optimizer_states') if self.is_global_zero: self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt', storage_options=storage_options) self.checkpoint_io.save_checkpoint(local_optimizer_states, filepath / f'{self.global_rank:03d}_optim_states.pt', storage_options=storage_options) def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() checkpoint_path = Path(checkpoint_path) if checkpoint_path.is_file(): # super() is already bound to self, so self must not be passed again return super().load_checkpoint(str(checkpoint_path)) else: assert checkpoint_path.is_dir() global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt') local_optimizer_states = self.checkpoint_io.load_checkpoint(checkpoint_path / f'{self.global_rank:03d}_optim_states.pt') global_states['optimizer_states'] = local_optimizer_states return global_states ================================================ FILE: training/src/utils/ddp_zero2.py ================================================ # Meant to work with Apex's DistributedFusedAdam from typing import Any, Callable, Dict, List, Optional, Union from pathlib import Path import types import torch from
torch.optim.optimizer import Optimizer from torch.optim import LBFGS from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.plugins.precision import PrecisionPlugin, NativeMixedPrecisionPlugin from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.utilities.exceptions import MisconfigurationException try: # pytorch_lightning <= 1.7 from pytorch_lightning.utilities.types import _PATH except ImportError: # pytorch_lightning >= 1.8 try: from lightning_lite.utilities.types import _PATH except ImportError: # pytorch_lightning >= 1.9 from lightning_fabric.utilities.types import _PATH class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): def optimizer_step( # type: ignore[override] self, model: "pl.LightningModule", optimizer, optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, ) -> Any: if self.scaler is None: # skip scaler logic, as bfloat16 does not require scaler return NativeMixedPrecisionPlugin.optimizer_step( self, optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs ) if isinstance(optimizer, LBFGS): raise MisconfigurationException( f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." ) closure_result = closure() # HACK: we don't call self.scaler.unscale_ here. This is because DistributedFusedAdam # optimizer internally takes the scale into account. # If we call unscale_ here, it would be equivalent to unscaling the gradients twice. # Not unscaling has the side-effect that the NormMonitor callback will report the # gradient norm to be much larger than reality. # # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook. 
# self.scaler.unscale_(optimizer) # This will call gradient clipping self._after_closure(model, optimizer, optimizer_idx) skipped_backward = closure_result is None # in manual optimization, the closure does not return a value if not model.automatic_optimization or not skipped_backward: # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found step_output = self.scaler.step(optimizer, **kwargs) self.scaler.update() return step_output return closure_result def clip_grad_by_norm(self, optimizer: DistributedFusedAdam, clip_val: Union[int, float]) -> None: """Clip gradients by norm.""" # DistributedFusedAdam wants list, not generator # Gradients have not been unscaled, so we need to scale up the clip_val if self.scaler is not None: clip_val *= self.scaler.get_scale() return optimizer.clip_grad_norm(clip_val) class DDPStrategyZero2(DDPStrategy): """To use Apex's DistributedFusedAdam, we need to shard the optimizer states when saving/loading checkpoints. """ strategy_name = "ddp_zero2" def __init__( self, *args, precision_plugin: Optional[PrecisionPlugin] = DistAdamNativeMixedPrecisionPlugin, # precision_plugin: Optional[PrecisionPlugin] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: super().__init__( *args, precision_plugin=precision_plugin, **kwargs ) @property def precision_plugin(self) -> PrecisionPlugin: return self._precision_plugin if self._precision_plugin is not None else PrecisionPlugin() @precision_plugin.setter def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]) -> None: self._precision_plugin = precision_plugin # https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance self._precision_plugin.optimizer_step = types.MethodType( DistAdamNativeMixedPrecisionPlugin.optimizer_step, self._precision_plugin ) self._precision_plugin.clip_grad_by_norm = types.MethodType( DistAdamNativeMixedPrecisionPlugin.clip_grad_by_norm, self._precision_plugin ) def optimizer_state(self, optimizer:
Optimizer) -> Optional[dict]: if isinstance(optimizer, LightningOptimizer): optimizer = optimizer._optimizer if isinstance(optimizer, DistributedFusedAdam): return optimizer.state_dict(gather_on_root=False) else: return optimizer.state_dict() def save_checkpoint( self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None ) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: checkpoint: dict containing model and trainer state filepath: write-target file's path storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin """ filepath = Path(filepath) filepath.mkdir(parents=True, exist_ok=True) local_optimizer_states = checkpoint.pop('optimizer_states') if self.is_global_zero: self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt', storage_options=storage_options) self.checkpoint_io.save_checkpoint(local_optimizer_states, filepath / f'{self.global_rank:03d}_optim_states.pt', storage_options=storage_options) def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() checkpoint_path = Path(checkpoint_path) if checkpoint_path.is_file(): # super() is already bound to self, so self must not be passed again return super().load_checkpoint(str(checkpoint_path)) else: assert checkpoint_path.is_dir() global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt') local_optimizer_states = self.checkpoint_io.load_checkpoint( checkpoint_path / f'{self.global_rank:03d}_optim_states.pt', map_location='cuda' ) global_states['optimizer_states'] = local_optimizer_states return global_states ================================================ FILE: training/src/utils/distributed.py ================================================ # Copied from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/distributed.py # Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from contextlib import contextmanager import torch def init_distributed(cuda): """ Initializes distributed backend. :param cuda: (bool) if True initializes nccl backend, if False initializes gloo backend """ world_size = int(os.environ.get('WORLD_SIZE', 1)) distributed = (world_size > 1) if distributed: backend = 'nccl' if cuda else 'gloo' torch.distributed.init_process_group(backend=backend, init_method='env://') assert torch.distributed.is_initialized() return distributed def barrier(): """ Call torch.distributed.barrier() if distributed is in use """ if torch.distributed.is_available() and torch.distributed.is_initialized(): torch.distributed.barrier() def get_rank(): """ Gets distributed rank or returns zero if distributed is not initialized. """ if torch.distributed.is_available() and torch.distributed.is_initialized(): rank = torch.distributed.get_rank() else: rank = 0 return rank def get_world_size(): """ Gets total number of distributed workers or returns one if distributed is not initialized.
""" if torch.distributed.is_available() and torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() else: world_size = 1 return world_size def all_reduce_item(value, op='sum'): """ All-reduces single scalar value if distributed is in use """ if torch.distributed.is_available() and torch.distributed.is_initialized(): if op == 'sum' or op == 'mean': dop = torch.distributed.ReduceOp.SUM elif op == 'min': dop = torch.distributed.ReduceOp.MIN elif op == 'max': dop = torch.distributed.ReduceOp.MAX elif op == 'product': dop = torch.distributed.ReduceOp.PRODUCT else: raise RuntimeError('Unsupported reduce op') backend = torch.distributed.get_backend() if backend == torch.distributed.Backend.NCCL: device = torch.device('cuda') elif backend == torch.distributed.Backend.GLOO: device = torch.device('cpu') else: raise RuntimeError('Unsupported distributed backend') tensor = torch.tensor(value, device=device) torch.distributed.all_reduce(tensor, dop) if op == 'mean': tensor /= get_world_size() ret = tensor.item() else: ret = value return ret @contextmanager def sync_workers(): """ Yields distributed rank and synchronizes all workers on exit. """ rank = get_rank() yield rank barrier() ================================================ FILE: training/src/utils/ema.py ================================================ # Copied from https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py from __future__ import division from __future__ import unicode_literals from typing import Iterable, Optional import weakref import copy import contextlib import torch def to_float_maybe(x): return x.float() if x.dtype in [torch.float16, torch.bfloat16] else x # Partially based on: # https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py class ExponentialMovingAverage: """ Maintains (exponential) moving average of a set of parameters. Args: parameters: Iterable of `torch.nn.Parameter` (typically from `model.parameters()`). 
decay: The exponential decay. use_num_updates: Whether to use number of updates when computing averages. """ def __init__( self, parameters: Iterable[torch.nn.Parameter], decay: float, use_num_updates: bool = True ): if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') self.decay = decay self.num_updates = 0 if use_num_updates else None parameters = list(parameters) self.shadow_params = [to_float_maybe(p.clone().detach()) for p in parameters if p.requires_grad] self.collected_params = None # By maintaining only a weakref to each parameter, # we maintain the old GC behaviour of ExponentialMovingAverage: # if the model goes out of scope but the ExponentialMovingAverage # is kept, no references to the model or its parameters will be # maintained, and the model will be cleaned up. self._params_refs = [weakref.ref(p) for p in parameters] def _get_parameters( self, parameters: Optional[Iterable[torch.nn.Parameter]] ) -> Iterable[torch.nn.Parameter]: if parameters is None: parameters = [p() for p in self._params_refs] if any(p is None for p in parameters): raise ValueError( "(One of) the parameters with which this " "ExponentialMovingAverage " "was initialized no longer exists (was garbage collected);" " please either provide `parameters` explicitly or keep " "the model to which they belong from being garbage " "collected." ) return parameters else: parameters = list(parameters) if len(parameters) != len(self.shadow_params): raise ValueError( "Number of parameters passed as argument is different " "from number of shadow parameters maintained by this " "ExponentialMovingAverage" ) return parameters def update( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None ) -> None: """ Update currently maintained parameters. Call this every time the parameters are updated, such as the result of the `optimizer.step()` call. Args: parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to initialize this object. 
If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) decay = self.decay if self.num_updates is not None: self.num_updates += 1 decay = min( decay, (1 + self.num_updates) / (10 + self.num_updates) ) one_minus_decay = 1.0 - decay if parameters[0].device != self.shadow_params[0].device: self.to(device=parameters[0].device) with torch.no_grad(): parameters = [p for p in parameters if p.requires_grad] for s_param, param in zip(self.shadow_params, parameters): torch.lerp(s_param, param.to(dtype=s_param.dtype), one_minus_decay, out=s_param) def copy_to( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None ) -> None: """ Copy current averaged parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: param.data.copy_(s_param.data) def store( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None ) -> None: """ Save the current parameters for restoring later. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) self.collected_params = [ param.clone() for param in parameters if param.requires_grad ] def restore( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None ) -> None: """ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without affecting the original optimization process. Store the parameters before the `copy_to` method.
After validation (or model saving), use this to restore the former parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ if self.collected_params is None: raise RuntimeError( "This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`" ) parameters = self._get_parameters(parameters) for c_param, param in zip(self.collected_params, parameters): if param.requires_grad: param.data.copy_(c_param.data) @contextlib.contextmanager def average_parameters( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None ): r""" Context manager for validation/inference with averaged parameters. Equivalent to: ema.store() ema.copy_to() try: ... finally: ema.restore() Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) self.store(parameters) self.copy_to(parameters) try: yield finally: self.restore(parameters) def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. Args: device: like `device` argument to `torch.Tensor.to` """ # .to() on the tensors handles None correctly self.shadow_params = [ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) for p in self.shadow_params ] if self.collected_params is not None: self.collected_params = [ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) for p in self.collected_params ] return def state_dict(self) -> dict: r"""Returns the state of the ExponentialMovingAverage as a dict.""" # Following PyTorch conventions, references to tensors are returned: # "returns a reference to the state and not its copy!" 
- # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict return { "decay": self.decay, "num_updates": self.num_updates, "shadow_params": self.shadow_params, "collected_params": self.collected_params } def load_state_dict(self, state_dict: dict) -> None: r"""Loads the ExponentialMovingAverage state. Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`. """ # deepcopy, to be consistent with module API state_dict = copy.deepcopy(state_dict) self.decay = state_dict["decay"] if self.decay < 0.0 or self.decay > 1.0: raise ValueError('Decay must be between 0 and 1') self.num_updates = state_dict["num_updates"] assert self.num_updates is None or isinstance(self.num_updates, int), \ "Invalid num_updates" self.shadow_params = state_dict["shadow_params"] assert isinstance(self.shadow_params, list), \ "shadow_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.shadow_params ), "shadow_params must all be Tensors" self.collected_params = state_dict["collected_params"] if self.collected_params is not None: assert isinstance(self.collected_params, list), \ "collected_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.collected_params ), "collected_params must all be Tensors" assert len(self.collected_params) == len(self.shadow_params), \ "collected_params and shadow_params had different lengths" if len(self.shadow_params) == len(self._params_refs): # Consistent with torch.optim.Optimizer, cast things to consistent # device and dtype with the parameters params = [p() for p in self._params_refs] # If parameters have been garbage collected, just load the state # we were given without change. 
if not any(p is None for p in params): # ^ parameter references are still good for i, p in enumerate(params): self.shadow_params[i] = to_float_maybe(self.shadow_params[i].to( device=p.device, dtype=p.dtype )) if self.collected_params is not None: self.collected_params[i] = self.collected_params[i].to( device=p.device, dtype=p.dtype ) else: raise ValueError( "Tried to `load_state_dict()` with the wrong number of " "parameters in the saved state." ) ================================================ FILE: training/src/utils/flops.py ================================================ # Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py import torch try: from deepspeed.profiling.flops_profiler import get_model_profile has_deepspeed_profiling = True except ImportError as e: has_deepspeed_profiling = False try: from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table from fvcore.nn import ActivationCountAnalysis has_fvcore_profiling = True except ImportError as e: FlopCountAnalysis = None ActivationCountAnalysis = None has_fvcore_profiling = False def profile_deepspeed(model, input_size=(3, 224, 224), input_dtype=torch.float32, batch_size=1, detailed=False): device, dtype = next(model.parameters()).device, next(model.parameters()).dtype flops, macs, params = get_model_profile( model=model, args=torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype), print_profile=detailed, # prints the model graph with the measured profile attached to each module detailed=detailed, # print the detailed profile warm_up=10, # the number of warm-ups before measuring the time of each module as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) output_file=None, # path to the output file. If None, the profiler prints to stdout. 
ignore_modules=None) # the list of modules to ignore in the profiling return macs, 0 # no activation count in DS def profile_fvcore(model, input_size=(3, 224, 224), input_dtype=torch.float32, max_depth=4, batch_size=1, detailed=False, force_cpu=False): if force_cpu: model = model.to('cpu') device, dtype = next(model.parameters()).device, next(model.parameters()).dtype example_input = torch.zeros((batch_size,) + input_size, device=device, dtype=input_dtype) fca = FlopCountAnalysis(model, example_input) aca = ActivationCountAnalysis(model, example_input) if detailed: print(flop_count_table(fca, max_depth=max_depth)) return fca, fca.total(), aca, aca.total() ================================================ FILE: training/src/utils/gpu_affinity.py ================================================ import collections import math import os import pathlib import re import pynvml pynvml.nvmlInit() def systemGetDriverVersion(): return pynvml.nvmlSystemGetDriverVersion() def deviceGetCount(): return pynvml.nvmlDeviceGetCount() class device: # assume nvml returns list of 64 bit ints _nvml_affinity_elements = math.ceil(os.cpu_count() / 64) def __init__(self, device_idx): super().__init__() self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx) def getName(self): return pynvml.nvmlDeviceGetName(self.handle) def getCpuAffinity(self): affinity_string = '' for j in pynvml.nvmlDeviceGetCpuAffinity( self.handle, device._nvml_affinity_elements ): # assume nvml returns list of 64 bit ints affinity_string = '{:064b}'.format(j) + affinity_string affinity_list = [int(x) for x in affinity_string] affinity_list.reverse() # so core 0 is in 0th element of list ret = [i for i, e in enumerate(affinity_list) if e != 0] return ret def set_socket_affinity(gpu_id): dev = device(gpu_id) affinity = dev.getCpuAffinity() os.sched_setaffinity(0, affinity) def set_single_affinity(gpu_id): dev = device(gpu_id) affinity = dev.getCpuAffinity() os.sched_setaffinity(0, affinity[:1]) def 
set_single_unique_affinity(gpu_id, nproc_per_node): devices = [device(i) for i in range(nproc_per_node)] socket_affinities = [dev.getCpuAffinity() for dev in devices] siblings_list = get_thread_siblings_list() siblings_dict = dict(siblings_list) # remove siblings for idx, socket_affinity in enumerate(socket_affinities): socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values())) affinities = [] assigned = [] for socket_affinity in socket_affinities: for core in socket_affinity: if core not in assigned: affinities.append([core]) assigned.append(core) break os.sched_setaffinity(0, affinities[gpu_id]) def set_socket_unique_affinity(gpu_id, nproc_per_node, mode): device_ids = [device(i) for i in range(nproc_per_node)] socket_affinities = [dev.getCpuAffinity() for dev in device_ids] siblings_list = get_thread_siblings_list() siblings_dict = dict(siblings_list) # remove siblings for idx, socket_affinity in enumerate(socket_affinities): socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values())) socket_affinities_to_device_ids = collections.defaultdict(list) for idx, socket_affinity in enumerate(socket_affinities): socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx) for socket_affinity, device_ids in socket_affinities_to_device_ids.items(): devices_per_group = len(device_ids) cores_per_device = len(socket_affinity) // devices_per_group for group_id, device_id in enumerate(device_ids): if device_id == gpu_id: if mode == 'interleaved': affinity = list(socket_affinity[group_id::devices_per_group]) elif mode == 'continuous': affinity = list(socket_affinity[group_id*cores_per_device:(group_id+1)*cores_per_device]) else: raise RuntimeError('Unknown set_socket_unique_affinity mode') # reintroduce siblings affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict] os.sched_setaffinity(0, affinity) def get_thread_siblings_list(): path = '/sys/devices/system/cpu/cpu*/topology/thread_siblings_list' 
thread_siblings_list = [] pattern = re.compile(r'(\d+)\D(\d+)') for fname in pathlib.Path(path[0]).glob(path[1:]): with open(fname) as f: content = f.read().strip() res = pattern.findall(content) if res: pair = tuple(map(int, res[0])) thread_siblings_list.append(pair) return thread_siblings_list def set_affinity(gpu_id, nproc_per_node, mode='socket'): if mode == 'socket': set_socket_affinity(gpu_id) elif mode == 'single': set_single_affinity(gpu_id) elif mode == 'single_unique': set_single_unique_affinity(gpu_id, nproc_per_node) elif mode == 'socket_unique_interleaved': set_socket_unique_affinity(gpu_id, nproc_per_node, 'interleaved') elif mode == 'socket_unique_continuous': set_socket_unique_affinity(gpu_id, nproc_per_node, 'continuous') else: raise RuntimeError('Unknown affinity mode') affinity = os.sched_getaffinity(0) return affinity ================================================ FILE: training/src/utils/utils.py ================================================ import logging import warnings from typing import List, Sequence import pytorch_lightning as pl import rich.syntax import rich.tree from omegaconf import DictConfig, OmegaConf from pytorch_lightning.utilities import rank_zero_only # Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging class LoggingContext: def __init__(self, logger, level=None, handler=None, close=True): self.logger = logger self.level = level self.handler = handler self.close = close def __enter__(self): if self.level is not None: self.old_level = self.logger.level self.logger.setLevel(self.level) if self.handler: self.logger.addHandler(self.handler) def __exit__(self, et, ev, tb): if self.level is not None: self.logger.setLevel(self.old_level) if self.handler: self.logger.removeHandler(self.handler) if self.handler and self.close: self.handler.close() # implicit return of None => don't swallow exceptions def get_logger(name=__name__) -> logging.Logger: """Initializes 
multi-GPU-friendly python logger.""" logger = logging.getLogger(name) # this ensures all logging levels get marked with the rank zero decorator # otherwise logs would get multiplied for each GPU process in multi-GPU setup for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): setattr(logger, level, rank_zero_only(getattr(logger, level))) return logger def extras(config: DictConfig) -> None: """A couple of optional utilities, controlled by main config file: - disabling warnings - forcing debug friendly configuration - verifying experiment name is set when running in experiment mode Modifies DictConfig in place. Args: config (DictConfig): Configuration composed by Hydra. """ log = get_logger(__name__) # disable python warnings if config.ignore_warnings is set if config.get("ignore_warnings"): log.info("Disabling python warnings!") warnings.filterwarnings("ignore") # verify experiment name is set when running in experiment mode if config.get("experiment_mode") and not config.get("name"): log.info( "Running in experiment mode without the experiment name specified! " "Use `python run.py mode=exp name=experiment_name`" ) log.info("Exiting...") exit() # force debugger-friendly configuration if config.trainer.fast_dev_run is set # debuggers don't like GPUs and multiprocessing if config.trainer.get("fast_dev_run"): log.info("Forcing debugger friendly configuration!") if config.trainer.get("gpus"): config.trainer.gpus = 0 if config.datamodule.get("pin_memory"): config.datamodule.pin_memory = False if config.datamodule.get("num_workers"): config.datamodule.num_workers = 0 @rank_zero_only def print_config( config: DictConfig, fields: Sequence[str] = ( "trainer", "model", "datamodule", "train", "eval", "callbacks", "logger", "seed", "name", ), resolve: bool = True, ) -> None: """Prints content of DictConfig using Rich library and its tree structure. Args: config (DictConfig): Configuration composed by Hydra.
fields (Sequence[str], optional): Determines which main fields from config will be printed and in what order. resolve (bool, optional): Whether to resolve reference fields of DictConfig. """ style = "dim" tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) for field in fields: branch = tree.add(field, style=style, guide_style=style) config_section = config.get(field) branch_content = str(config_section) if isinstance(config_section, DictConfig): branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) branch.add(rich.syntax.Syntax(branch_content, "yaml")) rich.print(tree) with open("config_tree.txt", "w") as fp: rich.print(tree, file=fp) def finish( config: DictConfig, model: pl.LightningModule, datamodule: pl.LightningDataModule, trainer: pl.Trainer, callbacks: List[pl.Callback], logger: List[pl.loggers.LightningLoggerBase], ) -> None: """Makes sure everything closed properly.""" # without this sweeps with wandb logger might crash! for lg in logger: if isinstance(lg, pl.loggers.wandb.WandbLogger): import wandb wandb.finish() ================================================ FILE: training/tests/datamodules/test_language_modeling_hf.py ================================================ import os from pathlib import Path current_dir = Path(__file__).parent.absolute() import pytest import torch import dotenv from src.datamodules.language_modeling_hf import LMDataModule # load environment variables from `.env` file if it exists # recursively searches for `.env` in all folders starting from work dir dotenv.load_dotenv(override=True) def div_up(x: int, y: int) -> int: return (x + y - 1) // y # https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170 def num_cpu_cores(): try: import psutil return psutil.cpu_count(logical=False) except ImportError: return len(os.sched_getaffinity(0)) class TestLMDataModule: def test_wikitext2(self): batch_size = 7 dataset_name = 'wikitext' dataset_config_name = 
'wikitext-2-raw-v1' data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) cache_dir = data_dir / 'wikitext-2' / 'cache' max_length = 1024 datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', dataset_config_name=dataset_config_name, max_length=max_length, cache_dir=cache_dir, add_eos=False, batch_size=batch_size, num_workers=4) datamodule.prepare_data() datamodule.setup(stage='fit') train_loader = datamodule.train_dataloader() val_loader = datamodule.val_dataloader() datamodule.setup(stage='test') test_loader = datamodule.test_dataloader() train_len = 2391884 val_len = 247289 test_len = 283287 assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) for loader in [train_loader, val_loader, test_loader]: x, y = next(iter(loader)) assert x.dim() == 2 assert x.shape == (batch_size, max_length) assert x.dtype == torch.long assert torch.allclose(x[:, 1:], y[:, :-1]) def test_wikitext103(self): batch_size = 7 dataset_name = 'wikitext' dataset_config_name = 'wikitext-103-raw-v1' data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) cache_dir = data_dir / 'wikitext-103' / 'cache' max_length = 1024 datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', dataset_config_name=dataset_config_name, max_length=max_length, cache_dir=cache_dir, add_eos=False, batch_size=batch_size, num_workers=4) datamodule.prepare_data() datamodule.setup(stage='fit') train_loader = datamodule.train_dataloader() val_loader = datamodule.val_dataloader() datamodule.setup(stage='test') test_loader = datamodule.test_dataloader() train_len = 117920140 val_len = 247289 test_len = 283287 assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) assert len(test_loader) == div_up((test_len 
- 1) // max_length, batch_size) for loader in [train_loader, val_loader, test_loader]: x, y = next(iter(loader)) assert x.dim() == 2 assert x.shape == (batch_size, max_length) assert x.dtype == torch.long assert torch.allclose(x[:, 1:], y[:, :-1]) def test_openwebtext(self): batch_size = 8 dataset_name = 'openwebtext' dataset_config_name = None data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) cache_dir = data_dir / 'openwebtext' / 'cache' max_length = 1024 datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', dataset_config_name=dataset_config_name, max_length=max_length, cache_dir=cache_dir, add_eos=True, batch_size=batch_size, num_workers=num_cpu_cores() // 2) datamodule.prepare_data() datamodule.setup(stage='fit') train_loader = datamodule.train_dataloader() val_loader = datamodule.val_dataloader() datamodule.setup(stage='test') test_loader = datamodule.test_dataloader() train_len = 9035582198 val_len = 4434897 test_len = 4434897 assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) for loader in [train_loader, val_loader, test_loader]: x, y = next(iter(loader)) assert x.dim() == 2 assert x.shape == (batch_size, max_length) assert x.dtype == torch.long assert torch.allclose(x[:, 1:], y[:, :-1]) def test_lambada(self): batch_size = 8 dataset_name = 'lambada' dataset_config_name = None data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) cache_dir = data_dir / 'lambada' / 'cache' max_length = 1024 datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', dataset_config_name=dataset_config_name, max_length=max_length, cache_dir=cache_dir, add_eos=True, batch_size=batch_size, num_workers=64) datamodule.prepare_data() datamodule.setup(stage='fit') train_loader = datamodule.train_dataloader() val_loader = datamodule.val_dataloader() 
datamodule.setup(stage='test') test_loader = datamodule.test_dataloader() train_len = 9035582198 val_len = 4434897 test_len = 4434897 assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) for loader in [train_loader, val_loader, test_loader]: x, y = next(iter(loader)) assert x.dim() == 2 assert x.shape == (batch_size, max_length) assert x.dtype == torch.long assert torch.allclose(x[:, 1:], y[:, :-1]) def test_the_pile(self): batch_size = 8 dataset_name = 'the_pile' dataset_config_name = None data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data')) cache_dir = data_dir / 'the_pile' / 'cache' max_length = 2048 # Dataset is too large to fit into memory, need to use disk for concatenation datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', dataset_config_name=dataset_config_name, max_length=max_length, cache_dir=cache_dir, add_eos=True, batch_size=batch_size, num_workers=num_cpu_cores() // 2, use_shmem=False) datamodule.prepare_data() datamodule.setup(stage='fit') train_loader = datamodule.train_dataloader() val_loader = datamodule.val_dataloader() datamodule.setup(stage='test') test_loader = datamodule.test_dataloader() train_len = 374337375694 val_len = 383326395 test_len = 373297018 assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) for loader in [train_loader, val_loader, test_loader]: x, y = next(iter(loader)) assert x.dim() == 2 assert x.shape == (batch_size, max_length) assert x.dtype == torch.long assert torch.allclose(x[:, 1:], y[:, :-1]) def test_pg19(self): batch_size = 8 dataset_name = 'pg19' dataset_config_name = None data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 
'data')) cache_dir = data_dir / 'pg19' / 'cache' max_length = 2048 # Dataset is too large to fit into memory, need to use disk for concatenation datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2', dataset_config_name=dataset_config_name, max_length=max_length, cache_dir=cache_dir, add_eos=True, batch_size=batch_size, num_workers=num_cpu_cores() // 2) datamodule.prepare_data() datamodule.setup(stage='fit') train_loader = datamodule.train_dataloader() val_loader = datamodule.val_dataloader() datamodule.setup(stage='test') test_loader = datamodule.test_dataloader() train_len = 3066544128 val_len = 4653056 test_len = 10584064 assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size) assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size) assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size) for loader in [train_loader, val_loader, test_loader]: x, y = next(iter(loader)) assert x.dim() == 2 assert x.shape == (batch_size, max_length) assert x.dtype == torch.long assert torch.allclose(x[:, 1:], y[:, :-1]) ================================================ FILE: usage.md ================================================ # FlashAttention adoption We've been very happy to see FlashAttention being adopted by many organizations and research labs to speed up their training / inference. This page contains a partial list of places where FlashAttention is being used. If you'd like to add links to your organization / product / codebase, please open a PR or email us. We'd very much like to hear from you! ## Integrated into machine learning frameworks - PyTorch: [integrated](https://github.com/pytorch/pytorch/pull/81434) into core PyTorch in nn.Transformer. - Huggingface's [transformers](https://github.com/huggingface/transformers) library. [Ongoing](https://github.com/huggingface/transformers/pull/18439), blog post coming soon.
- Microsoft's [DeepSpeed](https://github.com/microsoft/DeepSpeed): FlashAttention is [integrated](https://github.com/microsoft/DeepSpeed/blob/ec13da6ba7cabc44bb4745a64a208b8580792954/deepspeed/ops/transformer/inference/triton_ops.py) into DeepSpeed's inference engine. - Nvidia's [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/pull/267). This library is a popular framework for training large transformer language models at scale. - MosaicML [Composer](https://github.com/mosaicml/composer) [library](https://www.mosaicml.com/blog/gpt-3-quality-for-500k). Composer is a library for efficient neural network training. - EleutherAI's [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/pull/725). This is a research library for training large transformer language models at scale based on NVIDIA's Megatron-LM and Microsoft's DeepSpeed. - PaddlePaddle: integrated into the framework with [API](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/nn/functional/flash_attention.py) `paddle.nn.functional.flash_attention`. ## MLPerf benchmarks [MLPerf](https://mlcommons.org/en/) is a competitive machine learning performance benchmark. FlashAttention yields the fastest BERT training on cloud instances in MLPerf training 2.0 (June 2022) and MLPerf training 2.1 (November 2022). - MLPerf 2.0: [IEEE Spectrum](https://spectrum.ieee.org/mlperf-rankings-2022) and [Forbes](https://www.forbes.com/sites/moorinsights/2022/07/12/google-dethrones-nvidia-in-latest-artificial-intelligence-benchmarking-tests/) articles about our submission to the MLPerf 2.0 benchmark using FlashAttention. - MLPerf 2.1 - collaboration between [Azure and Hazy Research](https://techcommunity.microsoft.com/t5/azure-high-performance-computing/azure-collaborates-with-hazy-research-and-nvidia-to-achieve/ba-p/3667511): for the first time, we can train MLPerf BERT in under 2 minutes on 16 nodes.
- MLPerf 2.1 - [Nvidia](https://developer.nvidia.com/blog/leading-mlperf-training-2-1-with-full-stack-optimizations-for-ai/): Nvidia uses techniques from FlashAttention to make their (already extremely optimized) BERT implementation go even faster. - MLPerf 2.1 - [MosaicML](https://www.mosaicml.com/blog/mlperf-nlp-nov2022): FlashAttention helps train BERT 2.7x faster in the open division. ## Language model training & inference - [PubMedGPT 2.7B](https://crfm.stanford.edu/2022/12/15/pubmedgpt.html), a domain-specific LLM for biomedicine, by Stanford CRFM, trained on [MosaicML](https://www.mosaicml.com/blog/introducing-pubmed-gpt) Cloud. Just using FlashAttention nearly halves the total training time. - Meta's [AITemplate](https://ai.facebook.com/blog/gpu-inference-engine-nvidia-amd-open-source/) uses FlashAttention as part of their approach to speed up Transformer inference (up to 5.3x on BERT). - Nvidia's [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) is a state-of-the-art Transformer inference library. As of version [5.2](https://github.com/NVIDIA/FasterTransformer/commit/b672f49e256ba7a2d4fc9691d270b60b7fc1a2ff), FlashAttention is used as a component of FasterTransformer to speed up GPT inference. - [Kernl](https://github.com/ELS-RD/kernl) is a library for fast Transformer inference. They use FlashAttention as part of their [approach](https://twitter.com/pommedeterre33/status/1585284221014245377) to speed up Transformers by up to 12x. ## Diffusion model training and inference - Huggingface's [diffusers](https://github.com/huggingface/diffusers) library for diffusion models. FlashAttention is integrated into [diffusers v0.7.0](https://github.com/huggingface/diffusers/releases/tag/v0.7.0). Up to 2x faster inference and lower memory usage. 
- Colossal-AI's [implementation](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) of Stable Diffusion: with FlashAttention as one of its components, it speeds up pretraining by up to 6.5x, and reduces the hardware cost of fine-tuning by 7x. - Meta's [AITemplate](https://ai.facebook.com/blog/gpu-inference-engine-nvidia-amd-open-source/), with FlashAttention as one of its components, is currently the [fastest](https://twitter.com/bing_xu_/status/1590447334055632897) Stable Diffusion inference engine that we know of. - Stable Diffusion inference from [Labml.ai](https://twitter.com/labmlai/status/1573634095732490240): 50% speedup. - Our own Stable Diffusion [fork](https://twitter.com/realDanFu/status/1580641495991754752) uses FlashAttention to get 3-4x speedup compared to the original version. ## Other models - [Uni-Fold](https://github.com/dptech-corp/Uni-Fold): Uni-Fold is an open-source platform for developing protein models beyond AlphaFold. With FlashAttention, Uni-Fold is 2.6x [faster](https://twitter.com/guolin_ke/status/1580532071901995008) than AlphaFold. - [OpenFold](https://github.com/aqlaboratory/openfold): a trainable, memory-efficient, and GPU-friendly PyTorch reproduction of AlphaFold 2. With FlashAttention as one of its [components](https://twitter.com/gahdritz/status/1595420944880779266), it runs inference up to 3x faster than AlphaFold2 on short sequences, and can predict 2x longer structures. ## Different implementations - [Triton](https://github.com/openai/triton): an [implementation](https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py) of FlashAttention in Triton by Phil Tillet from OpenAI. Triton is a Python-based language and compiler for parallel programming. - [xformers](https://github.com/facebookresearch/xformers): The xformers team has implemented [memory-efficient attention](https://twitter.com/fvsmassa/status/1580229170629849089) in a similar spirit to FlashAttention.
xformers dynamically dispatches to whichever implementation is available / faster. - [Jax](https://github.com/google/jax): an [implementation](https://github.com/lucidrains/flash-attention-jax) in Jax by [lucidrains](https://github.com/lucidrains/). - [Metal](https://developer.apple.com/metal): an [implementation](https://github.com/philipturner/metal-flash-attention) in Metal by Philip Turner. This ports FlashAttention to mobile GPU architectures such as Apple silicon.