SYMBOL INDEX (2142 symbols across 245 files) FILE: AI/racecheck_repro_1d_bulk.py function kernel (line 24) | def kernel(g_src: cute.Tensor, g_dst: cute.Tensor): function go (line 63) | def go(g_src, g_dst, stream): FILE: AI/racecheck_repro_1d_tensor.py function kernel (line 24) | def kernel(g_dst: cute.Tensor, tma_atom: cute.CopyAtom, tma_tensor: cute... function go (line 68) | def go(g_src, g_dst, stream): FILE: benchmarks/bench_sm90.py function parse_int_k (line 45) | def parse_int_k(s): function csv_ints (line 53) | def csv_ints(s): function parse_headdims (line 58) | def parse_headdims(s): function nheads_for_hdim (line 78) | def nheads_for_hdim(h): function fwd_flops (line 82) | def fwd_flops(batch, nheads, seqlen, hdim, hdim_v=None, causal=False): function bwd_flops (line 89) | def bwd_flops(batch, nheads, seqlen, hdim, causal=False, hdim_v=None): function get_causals (line 93) | def get_causals(args): function auto_batch (line 101) | def auto_batch(seqlen, batch_arg, total_tokens=32768): function bench_fwd (line 107) | def bench_fwd(batch, seqlen, nheads, hdim, causal, tile_m=None, tile_n=N... function bench_bwd (line 154) | def bench_bwd(batch, seqlen, nheads, hdim, causal, warmup=5, rep=30, hdi... function _get_default_bwd_config (line 257) | def _get_default_bwd_config(headdim, causal=False): function run_default (line 334) | def run_default(args): function run_sweep_tiles (line 367) | def run_sweep_tiles(args): function run_sweep_rs_overlap (line 397) | def run_sweep_rs_overlap(args): function run_compare_configs (line 427) | def run_compare_configs(args): function run_sweep_bwd_opts (line 452) | def run_sweep_bwd_opts(args): function main (line 489) | def main(): FILE: benchmarks/benchmark_alibi.py function generate_cos_sin (line 23) | def generate_cos_sin(seqlen, rotary_dim, device, dtype): function flash_rotary (line 31) | def flash_rotary(q, k, v, cos, sin, causal=False): function attn_bias_from_alibi_slopes (line 43) | def attn_bias_from_alibi_slopes( function flops (line 68) | def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): function efficiency (line 74) | def efficiency(flop, time): function attention_pytorch (line 78) | def attention_pytorch(q, k, v, dropout_p=0.0, causal=True, attn_bias=None): function time_fwd_bwd (line 110) | def time_fwd_bwd(func, *args, **kwargs): FILE: benchmarks/benchmark_attn.py function _make_bwd_fn (line 43) | def _make_bwd_fn(fwd_fn, g, inputs): function setup_standard (line 67) | def setup_standard(ctx): function setup_fa2 (line 76) | def setup_fa2(ctx): function setup_cudnn (line 95) | def setup_cudnn(ctx): function setup_fa3 (line 110) | def setup_fa3(ctx): function setup_fa4 (line 134) | def setup_fa4(ctx): function parse_int_k (line 173) | def parse_int_k(s): function csv_ints (line 181) | def csv_ints(s): function parse_headdims (line 186) | def parse_headdims(s): function csv_strs (line 206) | def csv_strs(s): function parse_args (line 211) | def parse_args(): function main (line 242) | def main(): FILE: benchmarks/benchmark_causal.py function attention_pytorch (line 20) | def attention_pytorch(qkv, dropout_p=0.0, causal=True): function time_fwd_bwd (line 122) | def time_fwd_bwd(func, *args, **kwargs): FILE: benchmarks/benchmark_flash_attention.py function flops (line 27) | def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): function efficiency (line 32) | def efficiency(flop, time): function attention_pytorch (line 36) | def attention_pytorch(qkv, dropout_p=0.0, causal=True): function time_fwd_bwd (line 65) | def time_fwd_bwd(func, *args, **kwargs): FILE: benchmarks/benchmark_gemm.py function benchmark_forward (line 12) | def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **... FILE: csrc/flash_attn/flash_api.cpp type FLASH_NAMESPACE (line 24) | namespace FLASH_NAMESPACE { function set_params_fprop (line 26) | void set_params_fprop(Flash_fwd_params ¶ms, function set_params_dgrad (line 161) | void set_params_dgrad(Flash_bwd_params ¶ms, function run_mha_fwd (line 243) | void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool f... function num_splits_heuristic (line 263) | inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs,... function set_params_splitkv (line 299) | std::tuple set_params_splitkv(Flash_fwd_params... function set_params_alibi (line 331) | void set_params_alibi(Flash_fwd_params ¶ms, std::optional function mha_varlen_fwd (line 514) | std::vector function run_mha_bwd (line 757) | void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { function mha_bwd (line 767) | std::vector function mha_varlen_bwd (line 973) | std::vector function mha_fwd_kvcache (line 1202) | std::vector function PYBIND11_MODULE (line 1478) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { FILE: csrc/flash_attn/src/alibi.h function namespace (line 11) | namespace FLASH_NAMESPACE { FILE: csrc/flash_attn/src/block_info.h function namespace (line 8) | namespace FLASH_NAMESPACE { FILE: csrc/flash_attn/src/dropout.h function namespace (line 11) | namespace FLASH_NAMESPACE { FILE: csrc/flash_attn/src/flash.h function namespace (line 14) | namespace FLASH_NAMESPACE { FILE: csrc/flash_attn/src/flash_bwd_kernel.h function namespace (line 23) | namespace FLASH_NAMESPACE { FILE: csrc/flash_attn/src/flash_bwd_launch_template.h function namespace (line 16) | namespace FLASH_NAMESPACE { FILE: csrc/flash_attn/src/flash_bwd_preprocess_kernel.h function namespace (line 18) | namespace FLASH_NAMESPACE { FILE: csrc/flash_attn/src/flash_fwd_kernel.h function namespace (line 24) | namespace FLASH_NAMESPACE { FILE: csrc/flash_attn/src/generate_kernels.py function get_fwd_template (line 17) | def get_fwd_template() -> str: function get_fwd_split_template (line 29) | def get_fwd_split_template() -> str: function get_bwd_template (line 38) | def get_bwd_template() -> str: class Kernel (line 51) | class Kernel: method template (line 59) | def template(self) -> str: method filename (line 73) | def filename(self) -> str: function get_all_kernels (line 76) | def get_all_kernels() -> List[Kernel]: function write_kernel (line 81) | def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: function main (line 88) | def main(output_dir: Optional[str]) -> None: FILE: csrc/flash_attn/src/hardware_info.h function get_current_device (line 24) | inline int get_current_device() { function get_num_sm (line 37) | inline int get_num_sm(int device) { FILE: csrc/flash_attn/src/mask.h function namespace (line 10) | namespace FLASH_NAMESPACE { function apply_mask_local (line 39) | void apply_mask_local(Tensor &tensor, const int col_idx_... FILE: csrc/flash_attn/src/rotary.h function namespace (line 14) | namespace FLASH_NAMESPACE { FILE: csrc/flash_attn/src/softmax.h function namespace (line 17) | namespace FLASH_NAMESPACE { function __forceinline__ (line 134) | __forceinline__ __device__ Softmax() {} FILE: csrc/flash_attn/src/utils.h function namespace (line 28) | namespace FLASH_NAMESPACE { FILE: csrc/flash_attn_ck/flash_api.cpp function PYBIND11_MODULE (line 114) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) FILE: csrc/flash_attn_ck/flash_common.cpp type flash (line 7) | namespace flash { function override_num_splits_if_necessary (line 8) | int override_num_splits_if_necessary(int batch, int nhead, int max_seq... FILE: csrc/flash_attn_ck/flash_common.hpp type flash (line 24) | namespace flash { function __global__ (line 25) | inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, u... function num_splits_heuristic_ck (line 38) | inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_S... FILE: csrc/flash_attn_ck/mha_bwd.cpp function fmha_bwd_traits (line 10) | fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask, function fmha_bwd_args (line 41) | fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, function mha_bwd (line 211) | std::vector FILE: csrc/flash_attn_ck/mha_fwd.cpp function fmha_fwd_traits (line 10) | fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, function fmha_fwd_args (line 30) | fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, function mha_fwd (line 165) | std::vector FILE: csrc/flash_attn_ck/mha_fwd_kvcache.cpp function fmha_fwd_appendkv_traits (line 10) | fmha_fwd_appendkv_traits get_ck_fmha_fwd_appendkv_traits(std::string dtype, function fmha_fwd_splitkv_traits (line 26) | fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &... function fmha_fwd_appendkv_args (line 45) | fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b, function fmha_fwd_splitkv_args (line 138) | fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse, function mha_fwd_kvcache (line 272) | std::vector FILE: csrc/flash_attn_ck/mha_varlen_bwd.cpp function fmha_bwd_traits (line 10) | fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask, function fmha_bwd_args (line 42) | fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, function mha_varlen_bwd (line 218) | std::vector FILE: csrc/flash_attn_ck/mha_varlen_fwd.cpp function fmha_fwd_traits (line 10) | fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, function fmha_fwd_splitkv_traits (line 30) | fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask... function fmha_fwd_args (line 49) | fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, function fmha_fwd_splitkv_args (line 187) | fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, function mha_varlen_fwd (line 320) | std::vector FILE: csrc/fused_dense_lib/fused_dense.cpp function linear_bias_wgrad (line 40) | std::vector linear_bias_wgrad(at::Tensor input, at::Tensor d... function linear_act_forward (line 92) | std::vector linear_act_forward(at::Tensor input, at::Tensor ... function bias_act_linear_dgrad_bgrad (line 154) | std::vector bias_act_linear_dgrad_bgrad( function PYBIND11_MODULE (line 209) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { FILE: csrc/fused_dense_lib/setup.py function get_cuda_bare_metal_version (line 10) | def get_cuda_bare_metal_version(cuda_dir): function append_nvcc_threads (line 19) | def append_nvcc_threads(nvcc_extra_args): FILE: csrc/layer_norm/ln.h function namespace (line 13) | namespace layer_norm { FILE: csrc/layer_norm/ln_api.cpp type layer_norm (line 27) | namespace layer_norm { function get_type_id (line 36) | uint32_t get_type_id(torch::Dtype dtype){ function get_key (line 50) | uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype ... function dropout_add_ln_fwd (line 105) | std::vector dropout_add_ln_fwd(const at::Tensor &x0, //... function dropout_add_ln_bwd (line 282) | std::vector dropout_add_ln_bwd(const at::Tensor &dz, // ... function dropout_add_ln_parallel_residual_fwd (line 482) | std::vector dropout_add_ln_parallel_residual_fwd( function dropout_add_ln_parallel_residual_bwd (line 649) | std::vector dropout_add_ln_parallel_residual_bwd( function PYBIND11_MODULE (line 826) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { FILE: csrc/layer_norm/ln_kernel_traits.h function namespace (line 5) | namespace layer_norm { function Base (line 110) | struct Kernel_traits : public Base { FILE: csrc/layer_norm/setup.py function get_cuda_bare_metal_version (line 16) | def get_cuda_bare_metal_version(cuda_dir): function check_cuda_torch_binary_vs_bare_metal (line 25) | def check_cuda_torch_binary_vs_bare_metal(cuda_dir): function raise_if_cuda_home_none (line 43) | def raise_if_cuda_home_none(global_option: str) -> None: function append_nvcc_threads (line 53) | def append_nvcc_threads(nvcc_extra_args): FILE: flash_attn/bert_padding.py class IndexFirstAxis (line 8) | class IndexFirstAxis(torch.autograd.Function): method forward (line 10) | def forward(ctx, input, indices): method backward (line 22) | def backward(ctx, grad_output): class IndexPutFirstAxis (line 41) | class IndexPutFirstAxis(torch.autograd.Function): method forward (line 43) | def forward(ctx, values, indices, first_axis_dim): method backward (line 56) | def backward(ctx, grad_output): class IndexFirstAxisResidual (line 67) | class IndexFirstAxisResidual(torch.autograd.Function): method forward (line 69) | def forward(ctx, input, indices): method backward (line 82) | def backward(ctx, grad_output, grad_residual): function unpad_input (line 98) | def unpad_input(hidden_states, attention_mask, unused_mask=None): function unpad_input_for_concatenated_sequences (line 131) | def unpad_input_for_concatenated_sequences(hidden_states, attention_mask... function pad_input (line 204) | def pad_input(hidden_states, indices, batch, seqlen): FILE: flash_attn/cute/ampere_helpers.py function get_smem_layout_atom (line 8) | def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cu... function gemm (line 35) | def gemm( function gemm_rs (line 87) | def gemm_rs( FILE: flash_attn/cute/barrier.py function ld_acquire (line 9) | def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.... function red_relaxed (line 24) | def red_relaxed( function red_release (line 40) | def red_release( function wait_eq (line 56) | def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset... function arrive_inc (line 65) | def arrive_inc( FILE: flash_attn/cute/bench_utils.py function flops (line 15) | def flops( function attention_ref (line 46) | def attention_ref(q, k, v, causal=False): function _build_cudnn_graph (line 79) | def _build_cudnn_graph(io_dtype, tensors, build_fn): function cudnn_fwd_setup (line 99) | def cudnn_fwd_setup(q, k, v, causal=False, window_size_left=None): function cudnn_bwd_setup (line 141) | def cudnn_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=N... FILE: flash_attn/cute/benchmark.py function benchmark_forward (line 8) | def benchmark_forward( function benchmark_backward (line 30) | def benchmark_backward( function benchmark_combined (line 72) | def benchmark_combined( function benchmark_fwd_bwd (line 117) | def benchmark_fwd_bwd( function benchmark_all (line 154) | def benchmark_all( function pytorch_profiler (line 202) | def pytorch_profiler( function benchmark_memory (line 258) | def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): FILE: flash_attn/cute/blackwell_helpers.py function gemm_w_idx (line 14) | def gemm_w_idx( function gemm_ptx_w_idx (line 42) | def gemm_ptx_w_idx( function gemm (line 77) | def gemm( function i64_to_i32x2 (line 90) | def i64_to_i32x2(i: int) -> Tuple[int, int]: function gemm_ptx (line 96) | def gemm_ptx( function gemm_ptx_loop (line 211) | def gemm_ptx_loop( function gemm_ptx_partial (line 374) | def gemm_ptx_partial( function gemm_ptx_partial1 (line 594) | def gemm_ptx_partial1( function gemm_ptx_precomputed (line 773) | def gemm_ptx_precomputed( function declare_ptx_smem_desc (line 952) | def declare_ptx_smem_desc( function declare_ptx_idesc (line 996) | def declare_ptx_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp, var_name: str = ... function gemm_ptx_precomputed_varname (line 1011) | def gemm_ptx_precomputed_varname( FILE: flash_attn/cute/block_info.py class BlockInfo (line 13) | class BlockInfo: method get_n_block_min_max (line 24) | def get_n_block_min_max( method get_m_block_min_max (line 58) | def get_m_block_min_max(self, seqlen_info: SeqlenInfoQK, n_block: Int3... method get_n_block_k_new_min_max (line 74) | def get_n_block_k_new_min_max( method get_n_block_min_causal_local_mask (line 105) | def get_n_block_min_causal_local_mask( method get_n_block_min_before_local_mask (line 124) | def get_n_block_min_before_local_mask( FILE: flash_attn/cute/block_sparse_utils.py function load_block_list (line 72) | def load_block_list( function finish_overlap_v_load (line 127) | def finish_overlap_v_load( function sparse_tensor_m_block (line 145) | def sparse_tensor_m_block( function produce_block_sparse_loads (line 160) | def produce_block_sparse_loads( function consume_block_sparse_loads (line 303) | def consume_block_sparse_loads( function load_block_list_sm100 (line 487) | def load_block_list_sm100( function produce_block_sparse_loads_sm100 (line 529) | def produce_block_sparse_loads_sm100( function get_total_block_count (line 622) | def get_total_block_count( function handle_block_sparse_empty_tile_correction_sm100 (line 643) | def handle_block_sparse_empty_tile_correction_sm100( function softmax_block_sparse_sm100 (line 756) | def softmax_block_sparse_sm100( function get_total_q_block_count_bwd (line 896) | def get_total_q_block_count_bwd( function produce_block_sparse_q_loads_bwd_sm100 (line 913) | def produce_block_sparse_q_loads_bwd_sm100( function get_block_sparse_iteration_info_bwd (line 1037) | def get_block_sparse_iteration_info_bwd( function get_m_block_from_iter_bwd (line 1069) | def get_m_block_from_iter_bwd( function _load_q_do_block_sm90 (line 1102) | def _load_q_do_block_sm90( function produce_block_sparse_q_loads_bwd_sm90 (line 1145) | def produce_block_sparse_q_loads_bwd_sm90( function consume_block_sparse_mma_bwd_sm90 (line 1241) | def consume_block_sparse_mma_bwd_sm90( function _store_one_dQaccum_sm90 (line 1347) | def _store_one_dQaccum_sm90( function dQaccum_store_block_sparse_bwd_sm90 (line 1377) | def dQaccum_store_block_sparse_bwd_sm90( FILE: flash_attn/cute/block_sparsity.py function ceildiv (line 13) | def ceildiv(a: int, b: int) -> int: class BlockSparseTensors (line 17) | class BlockSparseTensors(NamedTuple): method __new_from_mlir_values__ (line 23) | def __new_from_mlir_values__(self, values): class BlockSparseTensorsTorch (line 29) | class BlockSparseTensorsTorch(NamedTuple): function _expand_sparsity_tensor (line 37) | def _expand_sparsity_tensor( function _check_and_expand_block (line 60) | def _check_and_expand_block( function get_block_sparse_expected_shapes (line 90) | def get_block_sparse_expected_shapes( function infer_block_sparse_expected_shapes (line 108) | def infer_block_sparse_expected_shapes( function get_block_sparse_expected_shapes_bwd (line 202) | def get_block_sparse_expected_shapes_bwd( function normalize_block_sparse_tensors (line 225) | def normalize_block_sparse_tensors( function is_block_sparsity_enabled (line 269) | def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool: function get_block_sparse_broadcast_pattern (line 273) | def get_block_sparse_broadcast_pattern( function normalize_block_sparse_config (line 305) | def normalize_block_sparse_config( function normalize_block_sparse_config_bwd (line 351) | def normalize_block_sparse_config_bwd( function to_cute_block_sparse_tensors (line 398) | def to_cute_block_sparse_tensors( function fast_sampling (line 437) | def fast_sampling(mask_mod): FILE: flash_attn/cute/cache_utils.py function get_cache_path (line 49) | def get_cache_path() -> Path: function _compute_source_fingerprint (line 59) | def _compute_source_fingerprint() -> str: class FileLock (line 88) | class FileLock: method __init__ (line 99) | def __init__( method _lock_label (line 120) | def _lock_label(self) -> str: method __enter__ (line 124) | def __enter__(self) -> "FileLock": method __exit__ (line 149) | def __exit__(self, exc_type, exc_val, exc_tb) -> None: class JITCache (line 156) | class JITCache: method __init__ (line 161) | def __init__(self): method __setitem__ (line 164) | def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) ->... method __getitem__ (line 167) | def __getitem__(self, key: CompileKeyType) -> CallableFunction: method __contains__ (line 170) | def __contains__(self, key: CompileKeyType) -> bool: method clear (line 173) | def clear(self) -> None: class JITPersistentCache (line 180) | class JITPersistentCache(JITCache): method __init__ (line 189) | def __init__(self, cache_path: Path): method __setitem__ (line 194) | def __setitem__(self, key: CompileKeyType, fn: JitCompiledFunction) ->... method __getitem__ (line 198) | def __getitem__(self, key: CompileKeyType) -> CallableFunction: method __contains__ (line 203) | def __contains__(self, key: CompileKeyType) -> bool: method _try_load_from_storage (line 210) | def _try_load_from_storage(self, key: CompileKeyType) -> bool: method _try_export_to_storage (line 234) | def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledF... method _key_to_hash (line 255) | def _key_to_hash(self, key: CompileKeyType) -> str: method _lock_path (line 258) | def _lock_path(self, sha256_hex: str) -> Path: method clear (line 261) | def clear(self) -> None: function get_jit_cache (line 271) | def get_jit_cache(name: str | None = None) -> JITCache: FILE: flash_attn/cute/compute_block_sparsity.py class BlockSparsityKernel (line 18) | class BlockSparsityKernel: method __init__ (line 35) | def __init__( method __call__ (line 50) | def __call__( method kernel (line 87) | def kernel( function compute_block_sparsity (line 277) | def compute_block_sparsity( FILE: flash_attn/cute/copy_utils.py function cvt_copy (line 17) | def cvt_copy( function load_s2r (line 36) | def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: function get_copy_atom (line 43) | def get_copy_atom( function make_tmem_copy (line 52) | def make_tmem_copy( function copy (line 66) | def copy( function tiled_copy_1d (line 81) | def tiled_copy_1d( function tiled_copy_2d (line 92) | def tiled_copy_2d( function atomic_add_fp32x4 (line 110) | def atomic_add_fp32x4( function set_block_rank (line 148) | def set_block_rank( function store_shared_remote_fp32x4 (line 167) | def store_shared_remote_fp32x4( function cpasync_bulk_s2cluster (line 211) | def cpasync_bulk_s2cluster( function cpasync_bulk_g2s (line 243) | def cpasync_bulk_g2s( function cpasync_reduce_bulk_add_f32 (line 267) | def cpasync_reduce_bulk_add_f32( function cpasync_bulk_get_copy_fn (line 291) | def cpasync_bulk_get_copy_fn( function tma_get_copy_fn (line 324) | def tma_get_copy_fn( function tma_producer_copy_fn (line 363) | def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.Pipe... FILE: flash_attn/cute/cute_dsl_ptxas.py function _log (line 25) | def _log(msg): function _get_ptx (line 30) | def _get_ptx(compiled_func) -> tuple[str, Path] | None: function _compile_ptx (line 45) | def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes: function _patched_load_cuda_library (line 81) | def _patched_load_cuda_library(self): function patch (line 132) | def patch(): FILE: flash_attn/cute/cute_dsl_utils.py function get_max_active_clusters (line 37) | def get_max_active_clusters(cluster_size): function get_device_capacity (line 42) | def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: class ArgumentsBase (line 47) | class ArgumentsBase(JitArgument): method __c_pointers__ (line 48) | def __c_pointers__(self): method __get_mlir_types__ (line 57) | def __get_mlir_types__(self): method __new_from_mlir_values__ (line 70) | def __new_from_mlir_values__(self, values): function load_cubin_module_data_patched (line 82) | def load_cubin_module_data_patched(cubin_data, filepath): function cute_compile_patched (line 87) | def cute_compile_patched(*args, **kwargs): function assume_strides_aligned (line 103) | def assume_strides_aligned(t): function assume_tensor_aligned (line 114) | def assume_tensor_aligned(t): function to_cute_tensor (line 121) | def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=Fa... function to_cute_aux_tensor (line 131) | def to_cute_aux_tensor(t, enable_tvm_ffi=True): function get_aux_tensor_metadata (line 149) | def get_aux_tensor_metadata(aux_tensors): function get_broadcast_dims (line 160) | def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]: FILE: flash_attn/cute/fa_logging.py function _parse_log_level (line 38) | def _parse_log_level(raw: str) -> int: function _configure_default_handler (line 55) | def _configure_default_handler() -> None: function get_fa_log_level (line 73) | def get_fa_log_level() -> int: function set_fa_log_level (line 77) | def set_fa_log_level(level: int | str) -> None: function fa_log (line 90) | def fa_log(level: int, msg: str): function fa_printf (line 95) | def fa_printf(level: int, fmt, *args): FILE: flash_attn/cute/fast_math.py function clz (line 9) | def clz(x: Int32) -> Int32: FILE: flash_attn/cute/flash_bwd.py class FlashAttentionBackwardSm80 (line 28) | class FlashAttentionBackwardSm80: method __init__ (line 29) | def __init__( method can_implement (line 95) | def can_implement( method _check_type (line 141) | def _check_type( method _setup_attributes (line 183) | def _setup_attributes(self): method _get_tiled_mma (line 300) | def _get_tiled_mma(self): method _get_shared_storage_cls (line 322) | def _get_shared_storage_cls(self): method __call__ (line 364) | def __call__( method kernel (line 477) | def kernel( method compute_one_m_block (line 851) | def compute_one_m_block( method epilogue (line 1008) | def epilogue( method advance_pipeline (line 1137) | def advance_pipeline(self, pipeline_index, num_stages: cutlass.Constex... method load_K (line 1141) | def load_K( method load_V (line 1170) | def load_V( method load_Q_LSE (line 1198) | def load_Q_LSE( method load_dO_dPsum (line 1242) | def load_dO_dPsum( FILE: flash_attn/cute/flash_bwd_postprocess.py class FlashAttentionBackwardPostprocess (line 34) | class FlashAttentionBackwardPostprocess: method __init__ (line 35) | def __init__( method can_implement (line 70) | def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: method _get_tiled_mma (line 91) | def _get_tiled_mma(self): method _setup_attributes (line 132) | def _setup_attributes(self): method __call__ (line 211) | def __call__( method kernel (line 290) | def kernel( FILE: flash_attn/cute/flash_bwd_preprocess.py class FlashAttentionBackwardPreprocess (line 38) | class FlashAttentionBackwardPreprocess: method __init__ (line 39) | def __init__( method can_implement (line 69) | def can_implement(dtype, head_dim, tile_m, num_threads) -> bool: method _setup_attributes (line 94) | def _setup_attributes(self): method __call__ (line 125) | def __call__( method kernel (line 221) | def kernel( FILE: flash_attn/cute/flash_bwd_sm100.py class FlashAttentionBackwardSm100 (line 47) | class FlashAttentionBackwardSm100: method __init__ (line 50) | def __init__( method _setup_attributes (line 236) | def _setup_attributes(self): method _get_tiled_mma (line 266) | def _get_tiled_mma(self): method _setup_smem_layout (line 320) | def _setup_smem_layout(self): method __call__ (line 443) | def __call__( method kernel (line 1009) | def kernel( method relay (line 1621) | def relay( method load (line 1665) | def load( method mma (line 2196) | def mma( method split_wg (line 2699) | def split_wg( method apply_score_mod (line 2741) | def apply_score_mod( method apply_score_mod_bwd (line 2780) | def apply_score_mod_bwd( method compute_loop (line 2812) | def compute_loop( method dQacc_reduce (line 3417) | def dQacc_reduce( method epilogue_dKV (line 3657) | def epilogue_dKV( method epilogue_dK_or_dV_tma (line 3794) | def epilogue_dK_or_dV_tma( FILE: flash_attn/cute/flash_bwd_sm120.py class FlashAttentionBackwardSm120 (line 14) | class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): method can_implement (line 16) | def can_implement( FILE: flash_attn/cute/flash_bwd_sm90.py class FlashAttentionBackwardSm90 (line 45) | class FlashAttentionBackwardSm90: method __init__ (line 48) | def __init__( method can_implement (line 145) | def can_implement( method _check_type (line 169) | def _check_type( method _setup_attributes (line 200) | def _setup_attributes(self): method _get_tiled_mma (line 249) | def _get_tiled_mma(self): method _get_shared_storage_cls (line 299) | def _get_shared_storage_cls(self): method __call__ (line 337) | def __call__( method kernel (line 613) | def kernel( method load (line 833) | def load( method apply_score_mod (line 1001) | def apply_score_mod( method apply_score_mod_bwd (line 1044) | def apply_score_mod_bwd( method mma (line 1088) | def mma( method _get_stat (line 1428) | def _get_stat(tSrS: cute.Tensor, row: Int32, lane: Int32, shuffle: boo... method mma_one_m_block (line 1443) | def mma_one_m_block( method epilogue_dKV (line 1603) | def epilogue_dKV( method dQaccum_store (line 1738) | def dQaccum_store( FILE: flash_attn/cute/flash_fwd.py class FlashAttentionForwardBase (line 39) | class FlashAttentionForwardBase: method __init__ (line 41) | def __init__( method can_implement (line 113) | def can_implement( method _check_type (line 170) | def _check_type( method _setup_attributes (line 199) | def _setup_attributes(self): method _get_smem_layout_atom (line 295) | def _get_smem_layout_atom(self): method _get_tiled_mma (line 298) | def _get_tiled_mma(self): method _get_shared_storage_cls (line 301) | def _get_shared_storage_cls(self): method __call__ (line 305) | def __call__( method epilogue (line 324) | def epilogue( method advance_pipeline (line 449) | def advance_pipeline(self, pipeline_index): method load_Q (line 453) | def load_Q( method load_K (line 480) | def load_K( method load_V (line 526) | def load_V( class FlashAttentionForwardSm80 (line 576) | class FlashAttentionForwardSm80(FlashAttentionForwardBase): method _get_smem_layout_atom (line 577) | def _get_smem_layout_atom(self): method _get_tiled_mma (line 585) | def _get_tiled_mma(self): method _get_shared_storage_cls (line 598) | def _get_shared_storage_cls(self): method __call__ (line 620) | def __call__( method kernel (line 743) | def kernel( method compute_one_n_block (line 1082) | def compute_one_n_block( function __getattr__ (line 1195) | def __getattr__(name): FILE: flash_attn/cute/flash_fwd_combine.py class FlashAttentionForwardCombine (line 21) | class FlashAttentionForwardCombine: method __init__ (line 22) | def __init__( method can_implement (line 57) | def can_implement( method _setup_attributes (line 84) | def _setup_attributes(self): method __call__ (line 191) | def __call__( method kernel (line 324) | def kernel( method load_O_partial (line 668) | def load_O_partial( FILE: flash_attn/cute/flash_fwd_sm100.py class FlashAttentionForwardSm100 (line 64) | class FlashAttentionForwardSm100: method __init__ (line 66) | def __init__( method _setup_attributes (line 243) | def _setup_attributes(self): method __call__ (line 284) | def __call__( method kernel (line 665) | def kernel( method load (line 1122) | def load( method mma (line 1323) | def mma( method softmax_loop (line 1614) | def softmax_loop( method softmax_step (line 1952) | def softmax_step( method correction_loop (line 2089) | def correction_loop( method correction_rescale (line 2368) | def correction_rescale( method correction_epilogue (line 2419) | def correction_epilogue( method _store_O_to_gmem (line 2508) | def _store_O_to_gmem( method epilogue_s2g (line 2556) | def epilogue_s2g( method load_Q (line 2626) | def load_Q( method load_KV (line 2638) | def load_KV( method offset_kv_smem (line 2687) | def offset_kv_smem(self, sX: cute.Tensor, stage: Int32, phase: Int32): method apply_score_mod (line 2722) | def apply_score_mod( FILE: flash_attn/cute/flash_fwd_sm120.py class FlashAttentionForwardSm120 (line 14) | class FlashAttentionForwardSm120(FlashAttentionForwardSm80): method can_implement (line 20) | def can_implement( FILE: flash_attn/cute/flash_fwd_sm90.py class FlashAttentionForwardSm90 (line 51) | class FlashAttentionForwardSm90(FlashAttentionForwardBase): method __init__ (line 52) | def __init__( method _get_smem_layout_atom (line 71) | def _get_smem_layout_atom(self): method _get_tiled_mma (line 95) | def _get_tiled_mma(self): method _get_shared_storage_cls (line 119) | def _get_shared_storage_cls(self): method __call__ (line 157) | def __call__( method kernel (line 398) | def kernel( method load (line 625) | def load( method load_KV (line 862) | def load_KV( method mma (line 883) | def mma( method first_half_block_overlap (line 1236) | def first_half_block_overlap( method last_half_block_overlap (line 1292) | def last_half_block_overlap( method mma_one_n_block (line 1315) | def mma_one_n_block( method mma_one_n_block_intrawg_overlap (line 1377) | def mma_one_n_block_intrawg_overlap( method mma_init (line 1446) | def mma_init(self): method apply_score_mod (line 1456) | def apply_score_mod( method warp_scheduler_barrier_sync (line 1490) | def warp_scheduler_barrier_sync(self): method warp_scheduler_barrier_arrive (line 1499) | def warp_scheduler_barrier_arrive(self): FILE: flash_attn/cute/interface.py function _parse_arch_str (line 72) | def _parse_arch_str(arch_str): function _get_device_arch (line 83) | def _get_device_arch(): function _validate_head_dims (line 101) | def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capabili... class FwdConfig (line 120) | class FwdConfig: function _tile_size_fwd_sm90 (line 127) | def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, use_b... class BwdConfig (line 156) | class BwdConfig: function _tile_size_bwd_sm90 (line 172) | def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local): function maybe_contiguous (line 236) | def maybe_contiguous(x): function _validate_tensor (line 240) | def _validate_tensor(t, name, expected_shape, expected_dtype, expected_d... function num_splits_heuristic (line 255) | def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): function _resolve_causal_local_window (line 265) | def _resolve_causal_local_window(causal, window_size_left, window_size_r... function _flash_attn_fwd (line 288) | def _flash_attn_fwd( function make_fake_bwd_tensors (line 822) | def make_fake_bwd_tensors(dtype, has_gqa, varlen_q, varlen_k): function _compile_bwd_preprocess (line 866) | def _compile_bwd_preprocess( function _bwd_preprocess (line 886) | def _bwd_preprocess( function _compile_bwd_postprocess (line 907) | def _compile_bwd_postprocess( function _bwd_postprocess_convert (line 932) | def _bwd_postprocess_convert( function _flash_attn_bwd (line 956) | def _flash_attn_bwd( class FlashAttnFunc (line 1564) | class FlashAttnFunc(torch.autograd.Function): method forward (line 1566) | def forward( method backward (line 1624) | def backward(ctx, dout, dlse): class FlashAttnVarlenFunc (line 1648) | class FlashAttnVarlenFunc(torch.autograd.Function): method forward (line 1650) | def forward( method backward (line 1710) | def backward(ctx, dout, dlse): function flash_attn_func (line 1742) | def flash_attn_func( function flash_attn_varlen_func (line 1784) | def flash_attn_varlen_func( function _compile_fwd_combine (line 1832) | def _compile_fwd_combine( function _flash_attn_fwd_combine (line 1889) | def _flash_attn_fwd_combine( function flash_attn_combine (line 1983) | def flash_attn_combine( FILE: flash_attn/cute/mask.py function r2p_bitmask_below (line 19) | def r2p_bitmask_below(limit: Int32, s: int) -> Uint32: function r2p_bitmask_above (line 30) | def r2p_bitmask_above(limit: Int32, s: int) -> Uint32: function mask_r2p_lambda (line 41) | def mask_r2p_lambda( function sm90_col_to_r2p_idx (line 69) | def sm90_col_to_r2p_idx(col_limit: Int32) -> Int32: function row_to_r2p_idx (line 80) | def row_to_r2p_idx(x: Int32, num_rep: int, num_wg: int) -> Int32: class AttentionMask (line 103) | class AttentionMask: method seqlen_q (line 113) | def seqlen_q(self) -> Int32: method seqlen_k (line 117) | def seqlen_k(self) -> Int32: method apply_mask (line 121) | def apply_mask( method apply_mask_sm100 (line 370) | def apply_mask_sm100( method apply_mask_sm100_transposed (line 530) | def apply_mask_sm100_transposed( FILE: flash_attn/cute/mma_sm100_desc.py class Major (line 16) | class Major(IntEnum): # matrix “layout” in the ISA docs class ScaleIn (line 21) | class ScaleIn(IntEnum): # negate flags class Saturate (line 26) | class Saturate(IntEnum): class CFormat (line 31) | class CFormat(IntEnum): # 2-bit field (bits 4-5) class F16F32Format (line 37) | class F16F32Format(IntEnum): # 3-bit field (A/B element type) class S8Format (line 43) | class S8Format(IntEnum): class MXF8F6F4Format (line 48) | class MXF8F6F4Format(IntEnum): class MaxShift (line 56) | class MaxShift(IntEnum): function to_UMMA_format (line 68) | def to_UMMA_format(cutlass_type) -> int: function to_C_format (line 93) | def to_C_format(cutlass_type) -> int: function make_instr_desc (line 111) | def make_instr_desc( function mma_op_to_idesc (line 165) | def mma_op_to_idesc(op: cute.nvgpu.tcgen05.mma.MmaOp): class LayoutType (line 177) | class LayoutType(IntEnum): # occupies the top-3 bits [61:64) function _layout_type (line 191) | def _layout_type(swizzle: cute.Swizzle) -> LayoutType: function make_smem_desc_base (line 212) | def make_smem_desc_base(layout: cute.Layout, swizzle: cute.Swizzle, majo... function make_smem_desc_start_addr (line 285) | def make_smem_desc_start_addr(start_addr: cute.Pointer) -> cutlass.Int32: function smem_desc_base_from_tensor (line 290) | def smem_desc_base_from_tensor(sA: cute.Tensor, major: Major) -> int: FILE: flash_attn/cute/named_barrier.py class NamedBarrierFwd (line 6) | class NamedBarrierFwd(enum.IntEnum): class NamedBarrierFwdSm100 (line 15) | class NamedBarrierFwdSm100(enum.IntEnum): class NamedBarrierBwd (line 28) | class NamedBarrierBwd(enum.IntEnum): class NamedBarrierBwdSm100 (line 42) | class NamedBarrierBwdSm100(enum.IntEnum): FILE: flash_attn/cute/pack_gqa.py function pack_gqa_layout (line 14) | def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): function make_packgqa_tiled_tma_atom (line 42) | def make_packgqa_tiled_tma_atom( function unpack_gqa_layout (line 85) | def unpack_gqa_layout(T, qhead_per_kvhead, head_idx): class PackGQA (line 114) | class PackGQA: method __init__ (line 115) | def __init__( method compute_ptr (line 128) | def compute_ptr( method load_Q (line 148) | def load_Q( method store_LSE (line 193) | def store_LSE( method store_O (line 228) | def store_O( FILE: flash_attn/cute/paged_kv.py class PagedKVManager (line 17) | class PagedKVManager(ParamsBase): method create (line 46) | def create( method load_page_table (line 136) | def load_page_table(self, n_block: Int32): method compute_X_ptr (line 157) | def compute_X_ptr(self, K_or_V: str): method load_KV (line 173) | def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str): FILE: flash_attn/cute/pipeline.py class PipelineStateSimple (line 20) | class PipelineStateSimple: method __init__ (line 27) | def __init__(self, stages: int, phase_index: Int32): method clone (line 34) | def clone(self) -> "PipelineStateSimple": method stages (line 38) | def stages(self) -> int: method index (line 43) | def index(self) -> Int32: method phase (line 52) | def phase(self) -> Int32: method advance (line 63) | def advance(self): method __extract_mlir_values__ (line 84) | def __extract_mlir_values__(self): method __new_from_mlir_values__ (line 88) | def __new_from_mlir_values__(self, values): function make_pipeline_state (line 92) | def make_pipeline_state(type: PipelineUserType, stages: int): class NamedBarrier (line 106) | class NamedBarrier(NamedBarrierOg): method create (line 108) | def create(*args, **kwargs): method arrive_w_index (line 115) | def arrive_w_index(self, index: Int32, *, loc=None, ip=None) -> None: method arrive_and_wait_w_index (line 128) | def arrive_and_wait_w_index(self, index: Int32, *, loc=None, ip=None) ... class PipelineAsync (line 138) | class PipelineAsync(PipelineAsyncOg): method create (line 140) | def create(*args, **kwargs): method producer_acquire_w_index_phase (line 148) | def producer_acquire_w_index_phase( method producer_commit_w_index (line 165) | def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): method consumer_wait_w_index_phase (line 169) | def consumer_wait_w_index_phase( method consumer_release_w_index (line 186) | def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): class PipelineTmaAsync (line 191) | class PipelineTmaAsync(PipelineTmaAsyncOg): method create (line 197) | def create(*args, **kwargs): method producer_acquire (line 204) | def producer_acquire( method producer_acquire_w_index_phase (line 229) | def producer_acquire_w_index_phase( method consumer_wait_w_index_phase (line 247) | def consumer_wait_w_index_phase( method consumer_release_w_index (line 264) | def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): class PipelineTmaUmma (line 275) | class PipelineTmaUmma(PipelineTmaUmmaOg): method create (line 281) | def create(*args, **kwargs): method producer_acquire (line 289) | def producer_acquire( method producer_acquire_w_index_phase (line 328) | def producer_acquire_w_index_phase( method consumer_wait_w_index_phase (line 354) | def consumer_wait_w_index_phase( method consumer_release_w_index (line 371) | def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): class PipelineUmmaAsync (line 379) | class PipelineUmmaAsync(PipelineUmmaAsyncOg): method create (line 381) | def create(*args, **kwargs): method producer_acquire_w_index_phase (line 388) | def producer_acquire_w_index_phase( method producer_commit_w_index (line 405) | def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): method consumer_wait_w_index_phase (line 412) | def consumer_wait_w_index_phase( method consumer_release_w_index (line 429) | def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): class PipelineAsyncUmma (line 434) | class PipelineAsyncUmma(PipelineAsyncUmmaOg): method create (line 436) | def create(*args, **kwargs): method producer_acquire_w_index_phase (line 443) | def producer_acquire_w_index_phase( method producer_commit_w_index (line 460) | def producer_commit_w_index(self, index: Int32, *, loc=None, ip=None): method consumer_wait_w_index_phase (line 464) | def consumer_wait_w_index_phase( method consumer_release_w_index (line 481) | def consumer_release_w_index(self, index: Int32, *, loc=None, ip=None): FILE: flash_attn/cute/seqlen_info.py class SeqlenInfo (line 18) | class SeqlenInfo: method create (line 25) | def create( method offset_batch (line 47) | def offset_batch( class SeqlenInfoQK (line 67) | class SeqlenInfoQK: method create (line 80) | def create( method offset_batch_Q (line 132) | def offset_batch_Q( method offset_batch_K (line 171) | def offset_batch_K( class SeqlenInfoQKNewK (line 204) | class SeqlenInfoQKNewK: method create (line 227) | def create( FILE: flash_attn/cute/sm90_config_search.py function _divisors (line 20) | def _divisors(n): function _acc_regs (line 24) | def _acc_regs(M, N, num_wg): function _check_mma (line 29) | def _check_mma(M, N, num_wg, atom_layout_m, swap_AB): function _mma_traffic (line 44) | def _mma_traffic(M_eff, N_eff, K_red, num_wg, wg_n, is_rs=False): function _check_bwd_config (line 61) | def _check_bwd_config( function find_feasible_bwd_configs (line 174) | def find_feasible_bwd_configs( function print_bwd_configs (line 224) | def print_bwd_configs(configs, max_results=20): function _check_fwd_config (line 260) | def _check_fwd_config(hdim, hdimv, tile_n, num_wg, pv_is_rs, overlap_wg): function find_feasible_fwd_configs (line 315) | def find_feasible_fwd_configs( function print_fwd_configs (line 336) | def print_fwd_configs(configs, max_results=20): FILE: flash_attn/cute/softmax.py class Softmax (line 19) | class Softmax(ParamsBase): method create (line 28) | def create( method reset (line 38) | def reset(self) -> None: method _compute_row_max (line 42) | def _compute_row_max( method _compute_row_sum (line 47) | def _compute_row_sum( method online_softmax (line 53) | def online_softmax( method finalize (line 119) | def finalize( method rescale_O (line 156) | def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: class SoftmaxSm100 (line 170) | class SoftmaxSm100(Softmax): method create (line 174) | def create( method update_row_max (line 194) | def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> ... method update_row_sum (line 213) | def update_row_sum( method scale_subtract_rowmax (line 223) | def scale_subtract_rowmax( method apply_exp2_convert (line 238) | def apply_exp2_convert( method scale_apply_exp2_convert (line 282) | def scale_apply_exp2_convert( function floor_if_packed (line 333) | def floor_if_packed( function apply_score_mod_inner (line 344) | def apply_score_mod_inner( function apply_score_mod_bwd_inner (line 474) | def apply_score_mod_bwd_inner( FILE: flash_attn/cute/testing.py class IndexFirstAxis (line 13) | class IndexFirstAxis(torch.autograd.Function): method forward (line 15) | def forward(ctx, input, indices): method backward (line 27) | def backward(ctx, grad_output): class IndexPutFirstAxis (line 44) | class IndexPutFirstAxis(torch.autograd.Function): method forward (line 46) | def forward(ctx, values, indices, first_axis_dim): method backward (line 57) | def backward(ctx, grad_output): function unpad_input (line 66) | def unpad_input(hidden_states, attention_mask, unused_mask=None): function pad_input (line 89) | def pad_input(hidden_states, indices, batch, seqlen): function generate_random_padding_mask (line 94) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r... function generate_qkv (line 124) | def generate_qkv( function construct_local_mask (line 249) | def construct_local_mask( function construct_chunk_mask (line 291) | def construct_chunk_mask( function attention_ref (line 323) | def attention_ref( function maybe_fake_tensor_mode (line 437) | def maybe_fake_tensor_mode(fake: bool = True): function is_fake_mode (line 455) | def is_fake_mode() -> bool: FILE: flash_attn/cute/tile_scheduler.py class WorkTileInfo (line 23) | class WorkTileInfo(cutlass.utils.WorkTileInfo): method __new_from_mlir_values__ (line 27) | def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTil... class TileSchedulerArguments (line 35) | class TileSchedulerArguments(ParamsBase): class SingleTileScheduler (line 56) | class SingleTileScheduler: class Params (line 58) | class Params(ParamsBase): method create (line 68) | def create( method __init__ (line 81) | def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None,... method to_underlying_arguments (line 89) | def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None,... method create (line 93) | def create(params: Params, *, loc=None, ip=None) -> "SingleTileSchedul... method get_grid_shape (line 105) | def get_grid_shape( method get_current_work (line 119) | def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: method initial_work_tile_info (line 130) | def initial_work_tile_info(self, *, loc=None, ip=None): method prefetch_next_work (line 133) | def prefetch_next_work(self, *, loc=None, ip=None): method advance_to_next_work (line 136) | def advance_to_next_work(self, *, loc=None, ip=None): method __extract_mlir_values__ (line 139) | def __extract_mlir_values__(self): method __new_from_mlir_values__ (line 147) | def __new_from_mlir_values__(self, values): class StaticPersistentTileScheduler (line 155) | class StaticPersistentTileScheduler: class Params (line 157) | class Params(ParamsBase): method create (line 164) | def create( method __init__ (line 176) | def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=No... method to_underlying_arguments (line 183) | def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None,... method create (line 187) | def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentT... method get_grid_shape (line 196) | def get_grid_shape( method get_current_work (line 210) | def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: method initial_work_tile_info (line 220) | def initial_work_tile_info(self, *, loc=None, ip=None): method prefetch_next_work (line 223) | def prefetch_next_work(self, *, loc=None, ip=None): method advance_to_next_work (line 226) | def advance_to_next_work(self, *, loc=None, ip=None): method __extract_mlir_values__ (line 232) | def __extract_mlir_values__(self): method __new_from_mlir_values__ (line 240) | def __new_from_mlir_values__(self, values): class SingleTileLPTScheduler (line 251) | class SingleTileLPTScheduler: class Params (line 253) | class Params(ParamsBase): method create (line 268) | def create( method __init__ (line 303) | def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, ... method to_underlying_arguments (line 311) | def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None,... method create (line 316) | def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTSche... method get_grid_shape (line 322) | def get_grid_shape( method get_current_work (line 331) | def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: method initial_work_tile_info (line 351) | def initial_work_tile_info(self, *, loc=None, ip=None): method prefetch_next_work (line 354) | def prefetch_next_work(self, *, loc=None, ip=None): method advance_to_next_work (line 357) | def advance_to_next_work(self, *, loc=None, ip=None): method __extract_mlir_values__ (line 361) | def __extract_mlir_values__(self): method __new_from_mlir_values__ (line 369) | def __new_from_mlir_values__(self, values): class SingleTileLPTBwdScheduler (line 377) | class SingleTileLPTBwdScheduler: class Params (line 379) | class Params(ParamsBase): method create (line 393) | def create( method __init__ (line 426) | def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=No... method to_underlying_arguments (line 433) | def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None,... method create (line 438) | def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTBwdS... method get_grid_shape (line 444) | def get_grid_shape( method get_current_work (line 453) | def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.Work... method initial_work_tile_info (line 475) | def initial_work_tile_info(self, *, loc=None, ip=None): method prefetch_next_work (line 478) | def prefetch_next_work(self, *, loc=None, ip=None): method advance_to_next_work (line 481) | def advance_to_next_work(self, *, loc=None, ip=None): method __extract_mlir_values__ (line 485) | def __extract_mlir_values__(self): method __new_from_mlir_values__ (line 493) | def __new_from_mlir_values__(self, values): class SingleTileVarlenScheduler (line 501) | class SingleTileVarlenScheduler: class Params (line 503) | class Params(ParamsBase): method create (line 520) | def create( method __init__ (line 547) | def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, ... method to_underlying_arguments (line 556) | def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None,... method create (line 560) | def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenS... method get_grid_shape (line 566) | def get_grid_shape( method _get_num_m_blocks (line 581) | def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: method get_current_work (line 604) | def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: method initial_work_tile_info (line 701) | def initial_work_tile_info(self, *, loc=None, ip=None): method prefetch_next_work (line 704) | def prefetch_next_work(self, *, loc=None, ip=None): method advance_to_next_work (line 707) | def advance_to_next_work(self, *, loc=None, ip=None): method __extract_mlir_values__ (line 711) | def __extract_mlir_values__(self): method __new_from_mlir_values__ (line 719) | def __new_from_mlir_values__(self, values): FILE: flash_attn/cute/utils.py function _compute_base_hash (line 59) | def _compute_base_hash(func: Callable) -> str: function hash_callable (line 78) | def hash_callable( function create_softcap_scoremod (line 116) | def create_softcap_scoremod(softcap_val): function compute_softmax_scale_log2 (line 130) | def compute_softmax_scale_log2(softmax_scale, score_mod): function compute_fastdiv_mods (line 145) | def compute_fastdiv_mods(mQ, mK, qhead_per_kvhead, pack_gqa, aux_tensors... function convert_from_dlpack (line 161) | def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) ->... function convert_from_dlpack_leading_static (line 171) | def convert_from_dlpack_leading_static( function make_tiled_copy_A (line 183) | def make_tiled_copy_A( function make_tiled_copy_B (line 192) | def make_tiled_copy_B( function mma_make_fragment_A (line 201) | def mma_make_fragment_A( function mma_make_fragment_B (line 210) | def mma_make_fragment_B( function get_smem_store_atom (line 219) | def get_smem_store_atom( function warp_reduce (line 236) | def warp_reduce( function fmax (line 254) | def fmax( function fmax_reduce (line 286) | def fmax_reduce( function fadd_reduce (line 337) | def fadd_reduce( function atomic_add_fp32 (line 378) | def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=N... function elem_pointer (line 400) | def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None... function predicate_k (line 405) | def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: function canonical_warp_group_idx (line 420) | def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32: function shuffle_sync (line 448) | def shuffle_sync( function shl_u32 (line 468) | def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=... function shr_u32 (line 502) | def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=... function warp_prefix_sum (line 526) | def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = ... function cvt_f16x2_f32 (line 541) | def cvt_f16x2_f32( function cvt_f16 (line 559) | def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ... function cvt_f16 (line 563) | def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor:... function cvt_f16 (line 567) | def cvt_f16(src: cute.Tensor, dst_or_dtype): function evaluate_polynomial (line 600) | def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=No... function evaluate_polynomial_2 (line 610) | def evaluate_polynomial_2( function add_round_down (line 621) | def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ... function combine_int_frac_ex2 (line 637) | def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=N... function ex2_emulation (line 664) | def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None... function ex2_emulation_2 (line 681) | def ex2_emulation_2( function e2e_asm2 (line 702) | def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Floa... function domain_offset_aligned (line 749) | def domain_offset_aligned( function scalar_to_ssa (line 764) | def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA: function ssa_to_scalar (line 771) | def ssa_to_scalar(val): FILE: flash_attn/flash_attn_interface.py function maybe_contiguous (line 27) | def maybe_contiguous(x): function _get_block_size_n (line 31) | def _get_block_size_n(device, head_dim, is_dropout, is_causal): function round_multiple (line 57) | def round_multiple(x, m): function noop_custom_op_wrapper (line 68) | def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_typ... function noop_register_fake_wrapper (line 74) | def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): function _flash_attn_forward (line 85) | def _flash_attn_forward( function _flash_attn_forward_fake (line 118) | def _flash_attn_forward_fake( function _flash_attn_varlen_forward (line 154) | def _flash_attn_varlen_forward( function _flash_attn_varlen_forward_fake (line 205) | def _flash_attn_varlen_forward_fake( function _flash_attn_backward (line 250) | def _flash_attn_backward( function _flash_attn_backward_fake (line 302) | def _flash_attn_backward_fake( function _flash_attn_varlen_backward (line 345) | def _flash_attn_varlen_backward( function _flash_attn_varlen_backward_fake (line 409) | def _flash_attn_varlen_backward_fake( class FlashAttnQKVPackedFunc (line 458) | class FlashAttnQKVPackedFunc(torch.autograd.Function): method forward (line 460) | def forward( method backward (line 508) | def backward(ctx, dout, *args): class FlashAttnVarlenQKVPackedFunc (line 540) | class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): method forward (line 542) | def forward( method backward (line 598) | def backward(ctx, dout, *args): class FlashAttnKVPackedFunc (line 634) | class FlashAttnKVPackedFunc(torch.autograd.Function): method forward (line 636) | def forward( method backward (line 687) | def backward(ctx, dout, *args): class FlashAttnVarlenKVPackedFunc (line 721) | class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): method forward (line 723) | def forward( method backward (line 787) | def backward(ctx, dout, *args): class FlashAttnFunc (line 825) | class FlashAttnFunc(torch.autograd.Function): method forward (line 827) | def forward( method backward (line 878) | def backward(ctx, dout, *args): class FlashAttnVarlenFunc (line 911) | class FlashAttnVarlenFunc(torch.autograd.Function): method forward (line 913) | def forward( method backward (line 979) | def backward(ctx, dout, *args): function flash_attn_qkvpacked_func (line 1016) | def flash_attn_qkvpacked_func( function flash_attn_kvpacked_func (line 1075) | def flash_attn_kvpacked_func( function flash_attn_func (line 1153) | def flash_attn_func( function flash_attn_varlen_qkvpacked_func (line 1230) | def flash_attn_varlen_qkvpacked_func( function flash_attn_varlen_kvpacked_func (line 1296) | def flash_attn_varlen_kvpacked_func( function flash_attn_varlen_func (line 1388) | def flash_attn_varlen_func( function flash_attn_with_kvcache (line 1482) | def flash_attn_with_kvcache( FILE: flash_attn/flash_attn_triton.py function _fwd_kernel (line 66) | def _fwd_kernel( function _bwd_preprocess_do_o_dot (line 288) | def _bwd_preprocess_do_o_dot( function _bwd_store_dk_dv (line 333) | def _bwd_store_dk_dv( function _bwd_kernel_one_col_block (line 365) | def _bwd_kernel_one_col_block( function init_to_zero (line 633) | def init_to_zero(name): function _bwd_kernel (line 668) | def _bwd_kernel( function _flash_attn_forward (line 812) | def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=... function _flash_attn_backward (line 894) | def _flash_attn_backward( class FlashAttnQKVPackedFunc (line 1013) | class FlashAttnQKVPackedFunc(torch.autograd.Function): method forward (line 1015) | def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): method backward (line 1038) | def backward(ctx, do): class FlashAttnKVPackedFunc (line 1065) | class FlashAttnKVPackedFunc(torch.autograd.Function): method forward (line 1067) | def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): method backward (line 1085) | def backward(ctx, do): class FlashAttnFunc (line 1114) | class FlashAttnFunc(torch.autograd.Function): method forward (line 1116) | def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): method backward (line 1134) | def backward(ctx, do): FILE: flash_attn/flash_attn_triton_og.py function _fwd_kernel (line 19) | def _fwd_kernel( function _bwd_preprocess (line 121) | def _bwd_preprocess( function _bwd_kernel (line 145) | def _bwd_kernel( class _attention (line 248) | class _attention(torch.autograd.Function): method forward (line 250) | def forward(ctx, q, k, v, sm_scale): method backward (line 307) | def backward(ctx, do): FILE: flash_attn/flash_blocksparse_attention.py class FlashBlocksparseAttention (line 15) | class FlashBlocksparseAttention(nn.Module): method __init__ (line 26) | def __init__( method forward (line 48) | def forward( class FlashBlocksparseMHA (line 154) | class FlashBlocksparseMHA(nn.Module): method __init__ (line 155) | def __init__( method forward (line 189) | def forward( FILE: flash_attn/flash_blocksparse_attn_interface.py function convert_blockmask (line 7) | def convert_blockmask(blockmask, causal): function _flash_blocksparse_attn_forward (line 42) | def _flash_blocksparse_attn_forward( function _flash_blocksparse_attn_backward (line 54) | def _flash_blocksparse_attn_backward( class FlashBlocksparseAttnFun (line 86) | class FlashBlocksparseAttnFun(torch.autograd.Function): method forward (line 88) | def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax... method backward (line 111) | def backward(ctx, dout): class FlashBlocksparseAttnFunWithS (line 137) | class FlashBlocksparseAttnFunWithS(torch.autograd.Function): method forward (line 139) | def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax... method backward (line 162) | def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored): function flash_blocksparse_attn_func (line 185) | def flash_blocksparse_attn_func( FILE: flash_attn/layers/patch_embed.py class PatchEmbed (line 17) | class PatchEmbed(nn.Module): method __init__ (line 20) | def __init__( method forward (line 46) | def forward(self, x): FILE: flash_attn/layers/rotary.py function rotate_half (line 14) | def rotate_half(x, interleaved=False): function apply_rotary_emb_torch (line 23) | def apply_rotary_emb_torch(x, cos, sin, interleaved=False): class ApplyRotaryEmb (line 38) | class ApplyRotaryEmb(torch.autograd.Function): method forward (line 40) | def forward( method backward (line 73) | def backward(ctx, do): function apply_rotary_emb (line 93) | def apply_rotary_emb( function _apply_rotary_emb_qkv (line 130) | def _apply_rotary_emb_qkv( class ApplyRotaryEmbQKV_ (line 194) | class ApplyRotaryEmbQKV_(torch.autograd.Function): method forward (line 196) | def forward( method backward (line 223) | def backward(ctx, dqkv): function apply_rotary_emb_qkv_ (line 236) | def apply_rotary_emb_qkv_( class ApplyRotaryEmbKV_ (line 267) | class ApplyRotaryEmbKV_(torch.autograd.Function): method forward (line 270) | def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Unio... method backward (line 287) | def backward(ctx, dkv): function apply_rotary_emb_kv_ (line 308) | def apply_rotary_emb_kv_( class RotaryEmbedding (line 331) | class RotaryEmbedding(torch.nn.Module): method __init__ (line 349) | def __init__( method _compute_inv_freq (line 382) | def _compute_inv_freq(self, device=None): method _update_cos_sin_cache (line 388) | def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): method forward (line 429) | def forward( FILE: flash_attn/losses/cross_entropy.py class CrossEntropyLoss (line 9) | class CrossEntropyLoss(nn.Module): method __init__ (line 10) | def __init__( method forward (line 47) | def forward(self, input, target, precomputed_lse=None): FILE: flash_attn/models/baichuan.py function remap_state_dict_hf_baichuan (line 17) | def remap_state_dict_hf_baichuan(state_dict, config): function baichuan_config_to_gpt2_config (line 115) | def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) ->... FILE: flash_attn/models/bert.py function create_mixer_cls (line 57) | def create_mixer_cls(config, cross_attn=False, return_residual=False): function create_mlp_cls (line 80) | def create_mlp_cls(config, layer_idx=None, return_residual=False): function create_block (line 116) | def create_block(config, layer_idx=None): function _init_weights (line 141) | def _init_weights(module, initializer_range=0.02): class BertEncoder (line 152) | class BertEncoder(nn.Module): method __init__ (line 153) | def __init__(self, config: BertConfig): method forward (line 160) | def forward(self, hidden_states, key_padding_mask=None, subset_mask=No... class BertPooler (line 215) | class BertPooler(nn.Module): method __init__ (line 216) | def __init__(self, config): method forward (line 225) | def forward(self, hidden_states, pool=True): class BertPredictionHeadTransform (line 234) | class BertPredictionHeadTransform(nn.Module): method __init__ (line 235) | def __init__(self, config): method forward (line 253) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertLMPredictionHead (line 265) | class BertLMPredictionHead(nn.Module): method __init__ (line 266) | def __init__(self, config): method forward (line 279) | def forward(self, hidden_states): class BertPreTrainingHeads (line 285) | class BertPreTrainingHeads(nn.Module): method __init__ (line 286) | def __init__(self, config): method forward (line 291) | def forward(self, sequence_output, pooled_output): class BertPreTrainedModel (line 297) | class BertPreTrainedModel(nn.Module): method __init__ (line 302) | def __init__(self, config, *inputs, **kwargs): method from_pretrained (line 315) | def from_pretrained(cls, model_name, config, *inputs, **kwargs): class BertModel (line 340) | class BertModel(BertPreTrainedModel): method __init__ (line 341) | def __init__(self, config: BertConfig, add_pooling_layer=True): method forward (line 367) | def forward( class BertForPreTraining (line 427) | class BertForPreTraining(BertPreTrainedModel): method __init__ (line 428) | def __init__(self, config: BertConfig): method tie_weights (line 456) | def tie_weights(self): method forward (line 459) | def forward( function remap_state_dict (line 524) | def remap_state_dict(state_dict, config: PretrainedConfig): function inv_remap_state_dict (line 637) | def inv_remap_state_dict(state_dict, config: PretrainedConfig): FILE: flash_attn/models/bigcode.py function remap_state_dict_hf_bigcode (line 10) | def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig): function inv_remap_state_dict_hf_bigcode (line 112) | def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig): function bigcode_config_to_gpt2_config (line 206) | def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> G... FILE: flash_attn/models/btlm.py function remap_state_dict_hf_btlm (line 17) | def remap_state_dict_hf_btlm(state_dict, config): function btlm_config_to_gpt2_config (line 78) | def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Con... FILE: flash_attn/models/falcon.py function remap_state_dict_hf_falcon (line 13) | def remap_state_dict_hf_falcon(state_dict, config): function falcon_config_to_gpt2_config (line 106) | def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Con... FILE: flash_attn/models/gpt.py function create_mixer_cls (line 62) | def create_mixer_cls(config, layer_idx=None, process_group=None, device=... function create_mlp_cls (line 123) | def create_mlp_cls(config, layer_idx=None, process_group=None, device=No... function create_block (line 262) | def create_block(config, layer_idx=None, process_group=None, device=None... class GPTPreTrainedModel (line 311) | class GPTPreTrainedModel(nn.Module): method __init__ (line 316) | def __init__(self, config, *inputs, **kwargs): method from_pretrained (line 329) | def from_pretrained( function _init_weights (line 380) | def _init_weights( class GPTModel (line 409) | class GPTModel(GPTPreTrainedModel): method __init__ (line 410) | def __init__(self, config: GPT2Config, process_group=None, device=None... method tie_weights (line 504) | def tie_weights(self): method allocate_inference_cache (line 508) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method forward (line 514) | def forward(self, input_ids, position_ids=None, inference_params=None): class GPTLMHeadModel (line 577) | class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): method __init__ (line 578) | def __init__(self, config: GPT2Config, process_group=None, device=None... method tie_weights (line 624) | def tie_weights(self): method allocate_inference_cache (line 630) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method forward (line 635) | def forward(self, input_ids, position_ids=None, inference_params=None,... method load_state_dict (line 671) | def load_state_dict(self, state_dict, strict=True): function shard_state_dict_tp (line 698) | def shard_state_dict_tp(state_dict, config, world_size, rank): function combine_state_dicts_tp (line 814) | def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], c... function remap_state_dict_hf_gpt2 (line 930) | def remap_state_dict_hf_gpt2(state_dict, config): function remap_state_dict_megatron (line 987) | def remap_state_dict_megatron(state_dict, config): FILE: flash_attn/models/gpt_neox.py function remap_state_dict_hf_gpt_neox (line 13) | def remap_state_dict_hf_gpt_neox(state_dict, config): function gpt_neox_config_to_gpt2_config (line 101) | def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GP... FILE: flash_attn/models/gptj.py function remap_state_dict_hf_gptj (line 12) | def remap_state_dict_hf_gptj(state_dict, config): function gptj_config_to_gpt2_config (line 82) | def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config: FILE: flash_attn/models/llama.py function remap_state_dict_meta_llama (line 19) | def remap_state_dict_meta_llama( function remap_state_dict_hf_llama (line 115) | def remap_state_dict_hf_llama( function inv_remap_state_dict_hf_llama (line 219) | def inv_remap_state_dict_hf_llama( function config_from_meta_checkpoint (line 329) | def config_from_meta_checkpoint( function config_from_hf_checkpoint (line 368) | def config_from_hf_checkpoint( function config_from_checkpoint (line 374) | def config_from_checkpoint( function state_dicts_from_checkpoint (line 383) | def state_dicts_from_checkpoint( function llama_config_to_gpt2_config (line 393) | def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config: FILE: flash_attn/models/opt.py function remap_state_dict_hf_opt (line 12) | def remap_state_dict_hf_opt(state_dict, config): function opt_config_to_gpt2_config (line 90) | def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config: FILE: flash_attn/models/vit.py function create_mixer_cls (line 28) | def create_mixer_cls( function create_mlp_cls (line 43) | def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp): function create_block (line 52) | def create_block( class VisionTransformer (line 97) | class VisionTransformer(nn.Module): method __init__ (line 103) | def __init__( method init_weights (line 240) | def init_weights(self, mode=""): method _init_weights (line 247) | def _init_weights(self, m): method no_weight_decay (line 252) | def no_weight_decay(self): method _pos_embed (line 255) | def _pos_embed(self, x): method forward_features (line 270) | def forward_features(self, x, all_tokens=True): method forward_head (line 317) | def forward_head(self, x, pre_logits: bool = False): method forward (line 322) | def forward(self, x): method load_state_dict (line 327) | def load_state_dict(self, state_dict, strict=True): function init_weights_vit_timm (line 356) | def init_weights_vit_timm(module: nn.Module, name: str = ""): function vit_base_patch16_224 (line 366) | def vit_base_patch16_224(pretrained=False, **kwargs): FILE: flash_attn/modules/block.py class Block (line 21) | class Block(nn.Module): method __init__ (line 22) | def __init__( method allocate_inference_cache (line 105) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method forward (line 108) | def forward( class ParallelBlock (line 259) | class ParallelBlock(nn.Module): method __init__ (line 264) | def __init__( method allocate_inference_cache (line 332) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method forward (line 335) | def forward( FILE: flash_attn/modules/embedding.py class GPT2Embeddings (line 11) | class GPT2Embeddings(nn.Module): method __init__ (line 12) | def __init__( method forward (line 47) | def forward(self, input_ids, position_ids=None): class BertEmbeddings (line 64) | class BertEmbeddings(nn.Module): method __init__ (line 65) | def __init__( method forward (line 93) | def forward(self, input_ids, position_ids=None, token_type_ids=None): class VocabParallelEmbedding (line 114) | class VocabParallelEmbedding(nn.Embedding): method __init__ (line 115) | def __init__(self, num_embeddings, *args, process_group=None, padding_... method forward (line 130) | def forward(self, input: Tensor) -> Tensor: class ColumnParallelEmbedding (line 146) | class ColumnParallelEmbedding(nn.Embedding): method __init__ (line 147) | def __init__(self, num_embeddings, embedding_dim, *args, process_group... class ParallelGPT2Embeddings (line 161) | class ParallelGPT2Embeddings(nn.Module): method __init__ (line 162) | def __init__( method forward (line 193) | def forward(self, input_ids, position_ids=None, combine_batch_seqlen_d... FILE: flash_attn/modules/mha.py function get_alibi_slopes (line 37) | def get_alibi_slopes(nheads): class FlashSelfAttention (line 53) | class FlashSelfAttention(nn.Module): method __init__ (line 64) | def __init__( method forward (line 83) | def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): class FlashCrossAttention (line 133) | class FlashCrossAttention(nn.Module): method __init__ (line 144) | def __init__( method forward (line 163) | def forward( class SelfAttention (line 230) | class SelfAttention(nn.Module): method __init__ (line 241) | def __init__(self, causal=False, softmax_scale=None, attention_dropout... method forward (line 247) | def forward(self, qkv, causal=None, key_padding_mask=None): class CrossAttention (line 282) | class CrossAttention(nn.Module): method __init__ (line 293) | def __init__(self, causal=False, softmax_scale=None, attention_dropout... method forward (line 299) | def forward(self, q, kv, causal=None, key_padding_mask=None): function _update_kv_cache (line 344) | def _update_kv_cache(kv, inference_params, layer_idx): class MHA (line 373) | class MHA(nn.Module): method __init__ (line 376) | def __init__( method allocate_inference_cache (line 483) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): method _update_kv_cache (line 496) | def _update_kv_cache(self, kv, inference_params): method _apply_rotary_update_kvcache_attention (line 502) | def _apply_rotary_update_kvcache_attention(self, q, kv, inference_para... method _update_kvcache_attention (line 542) | def _update_kvcache_attention(self, q, kv, inference_params): method forward (line 573) | def forward( class ParallelMHA (line 707) | class ParallelMHA(nn.Module): method __init__ (line 710) | def __init__( method allocate_inference_cache (line 824) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): method _update_kv_cache (line 837) | def _update_kv_cache(self, kv, inference_params): method _apply_rotary_update_kvcache_attention (line 842) | def _apply_rotary_update_kvcache_attention(self, q, kv, inference_para... method _update_kvcache_attention (line 882) | def _update_kvcache_attention(self, q, kv, inference_params): method forward (line 910) | def forward(self, x, seqlen=None, inference_params=None, **kwargs): FILE: flash_attn/modules/mlp.py class Mlp (line 25) | class Mlp(nn.Module): method __init__ (line 26) | def __init__( method forward (line 47) | def forward(self, x): class ParallelMLP (line 54) | class ParallelMLP(nn.Module): method __init__ (line 55) | def __init__( method forward (line 92) | def forward(self, x): class GatedMlp (line 99) | class GatedMlp(nn.Module): method __init__ (line 100) | def __init__( method forward (line 125) | def forward(self, x): class ParallelGatedMlp (line 139) | class ParallelGatedMlp(nn.Module): method __init__ (line 142) | def __init__( method forward (line 183) | def forward(self, x): FILE: flash_attn/ops/activations.py function bias_gelu (line 16) | def bias_gelu(y, bias): function bias_gelu_back (line 25) | def bias_gelu_back(g, y, bias): class GeLUFunction (line 37) | class GeLUFunction(torch.autograd.Function): method forward (line 40) | def forward(ctx, input, bias): method backward (line 45) | def backward(ctx, grad_output): function gelu_fwd (line 57) | def gelu_fwd(x): function gelu_bwd (line 65) | def gelu_bwd(g, x): class FastGeLUFunction (line 74) | class FastGeLUFunction(torch.autograd.Function): method forward (line 77) | def forward(ctx, input): method backward (line 82) | def backward(ctx, grad_output): function relu_bwd (line 92) | def relu_bwd(g, x): function sqrelu_fwd (line 97) | def sqrelu_fwd(x): function sqrelu_bwd (line 103) | def sqrelu_bwd(g, x): class SwiGLUFunction (line 123) | class SwiGLUFunction(torch.autograd.Function): method forward (line 126) | def forward(ctx, x, y): method backward (line 131) | def backward(ctx, dout): FILE: flash_attn/ops/fused_dense.py class FusedDenseFunc (line 27) | class FusedDenseFunc(torch.autograd.Function): method forward (line 30) | def forward( method backward (line 71) | def backward(ctx, grad_output, *args): function fused_dense_func (line 118) | def fused_dense_func( class FusedDense (line 139) | class FusedDense(nn.Linear): method __init__ (line 140) | def __init__( method forward (line 152) | def forward(self, x, process_group=None): class ColumnParallelLinear (line 166) | class ColumnParallelLinear(nn.Linear): method __init__ (line 167) | def __init__( method forward (line 193) | def forward(self, x): class RowParallelLinear (line 206) | class RowParallelLinear(nn.Linear): method __init__ (line 207) | def __init__( method forward (line 239) | def forward(self, x): class FusedMLPFunc (line 249) | class FusedMLPFunc(torch.autograd.Function): method forward (line 252) | def forward( method backward (line 349) | def backward(ctx, grad_output, *args): function fused_mlp_func (line 475) | def fused_mlp_func( class FusedMLP (line 531) | class FusedMLP(nn.Module): method __init__ (line 532) | def __init__( method forward (line 580) | def forward(self, x, process_group=None): class ParallelFusedMLP (line 613) | class ParallelFusedMLP(nn.Module): method __init__ (line 614) | def __init__( method forward (line 664) | def forward(self, x): FILE: flash_attn/ops/layer_norm.py function maybe_align (line 9) | def maybe_align(x, alignment_in_bytes=16): function _dropout_add_layer_norm_forward (line 16) | def _dropout_add_layer_norm_forward( function _dropout_add_layer_norm_backward (line 55) | def _dropout_add_layer_norm_backward( function _dropout_add_layer_norm_subset_forward (line 110) | def _dropout_add_layer_norm_subset_forward( function _dropout_add_layer_norm_subset_backward (line 153) | def _dropout_add_layer_norm_subset_backward( function _dropout_add_layer_norm_parallel_residual_forward (line 212) | def _dropout_add_layer_norm_parallel_residual_forward( function _dropout_add_layer_norm_parallel_residual_backward (line 257) | def _dropout_add_layer_norm_parallel_residual_backward( class DropoutAddLayerNormFn (line 311) | class DropoutAddLayerNormFn(torch.autograd.Function): method forward (line 313) | def forward( method backward (line 374) | def backward(ctx, dz, *args): class DropoutAddLayerNormSubsetFn (line 416) | class DropoutAddLayerNormSubsetFn(torch.autograd.Function): method forward (line 418) | def forward( method backward (line 483) | def backward(ctx, dz, *args): class DropoutAddLayerNormParallelResidualFn (line 531) | class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): method forward (line 533) | def forward( method backward (line 605) | def backward(ctx, dz0, dz1, *args): function layer_norm (line 657) | def layer_norm(x, weight, bias, epsilon): function dropout_add_layer_norm (line 661) | def dropout_add_layer_norm( function dropout_add_layer_norm_subset (line 693) | def dropout_add_layer_norm_subset( function dropout_add_layer_norm_parallel_residual (line 731) | def dropout_add_layer_norm_parallel_residual( class DropoutAddLayerNorm (line 765) | class DropoutAddLayerNorm(torch.nn.Module): method __init__ (line 766) | def __init__( method reset_parameters (line 786) | def reset_parameters(self): method forward (line 790) | def forward(self, x0, residual=None): FILE: flash_attn/ops/rms_norm.py function rms_norm (line 14) | def rms_norm(x, weight, epsilon): function dropout_add_rms_norm (line 20) | def dropout_add_rms_norm( function dropout_add_rms_norm_subset (line 52) | def dropout_add_rms_norm_subset( function dropout_add_rms_norm_parallel_residual (line 90) | def dropout_add_rms_norm_parallel_residual( class RMSNorm (line 124) | class RMSNorm(torch.nn.Module): method __init__ (line 125) | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): method reset_parameters (line 133) | def reset_parameters(self): method forward (line 136) | def forward(self, x): class DropoutAddRMSNorm (line 140) | class DropoutAddRMSNorm(torch.nn.Module): method __init__ (line 141) | def __init__( method reset_parameters (line 161) | def reset_parameters(self): method forward (line 164) | def forward(self, x0, residual=None): FILE: flash_attn/ops/triton/cross_entropy.py function cross_entropy_fwd_kernel (line 25) | def cross_entropy_fwd_kernel( function cross_entropy_bwd_kernel (line 104) | def cross_entropy_bwd_kernel( class CrossEntropyLoss (line 149) | class CrossEntropyLoss(torch.autograd.Function): method forward (line 152) | def forward( method backward (line 258) | def backward(ctx, grad_losses, grad_z_losses): function cross_entropy_loss (line 292) | def cross_entropy_loss( FILE: flash_attn/ops/triton/k_activations.py class Activation (line 19) | class Activation(str, Enum): function get_triton_activation_kernel (line 27) | def get_triton_activation_kernel(activation: Optional[Activation]): function get_triton_activation_bwd_kernel (line 41) | def get_triton_activation_bwd_kernel(activation: Optional[Activation]): function tanh (line 56) | def tanh(x): function cosh (line 62) | def cosh(x): function relu (line 72) | def relu(x): function relu_grad (line 83) | def relu_grad(x): function squared_relu (line 93) | def squared_relu(x): function squared_relu_grad (line 104) | def squared_relu_grad(x): function leaky_relu (line 110) | def leaky_relu(x): function leaky_relu_grad (line 122) | def leaky_relu_grad(x): function gelu (line 133) | def gelu(x): function gelu_grad (line 139) | def gelu_grad(x): function gelu_approx (line 146) | def gelu_approx(x): function gelu_approx_grad (line 156) | def gelu_approx_grad(x): FILE: flash_attn/ops/triton/layer_norm.py function maybe_contiguous_lastdim (line 23) | def maybe_contiguous_lastdim(x): function maybe_contiguous (line 27) | def maybe_contiguous(x): function triton_autotune_configs (line 31) | def triton_autotune_configs(): function layer_norm_ref (line 44) | def layer_norm_ref( function rms_norm_ref (line 104) | def rms_norm_ref( function _layer_norm_fwd_1pass_kernel (line 174) | def _layer_norm_fwd_1pass_kernel( function _layer_norm_fwd (line 290) | def _layer_norm_fwd( function _layer_norm_fwd_impl (line 355) | def _layer_norm_fwd_impl( function _layer_norm_bwd_kernel (line 485) | def _layer_norm_bwd_kernel( function _layer_norm_bwd (line 643) | def _layer_norm_bwd( function _layer_norm_bwd_impl (line 702) | def _layer_norm_bwd_impl( class LayerNormFn (line 846) | class LayerNormFn(torch.autograd.Function): method forward (line 849) | def forward( method backward (line 951) | def backward(ctx, dy, *args): function layer_norm_fn (line 1010) | def layer_norm_fn( function rms_norm_fn (line 1052) | def rms_norm_fn( class RMSNorm (line 1093) | class RMSNorm(torch.nn.Module): method __init__ (line 1095) | def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered... method reset_parameters (line 1109) | def reset_parameters(self): method forward (line 1115) | def forward(self, x, residual=None, prenorm=False, residual_in_fp32=Fa... class LayerNormLinearFn (line 1129) | class LayerNormLinearFn(torch.autograd.Function): method forward (line 1133) | def forward( method backward (line 1187) | def backward(ctx, dout, *args): function layer_norm_linear_fn (line 1229) | def layer_norm_linear_fn( FILE: flash_attn/ops/triton/linear.py function init_to_zero (line 22) | def init_to_zero(name): function get_configs_io_bound (line 26) | def get_configs_io_bound(): function kernel_fwd (line 131) | def kernel_fwd( function triton_linear_act (line 258) | def triton_linear_act( function kernel_bwd (line 428) | def kernel_bwd( function triton_dgrad_act (line 529) | def triton_dgrad_act( FILE: flash_attn/ops/triton/mlp.py class FusedDenseSqreluDenseFunc (line 13) | class FusedDenseSqreluDenseFunc(torch.autograd.Function): method forward (line 16) | def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0): method backward (line 66) | def backward(ctx, grad_output): class FusedDenseSqreluDense (line 116) | class FusedDenseSqreluDense(nn.Module): method __init__ (line 117) | def __init__( method forward (line 145) | def forward(self, x): FILE: flash_attn/ops/triton/rotary.py function rotary_kernel (line 13) | def rotary_kernel( function apply_rotary (line 102) | def apply_rotary( FILE: flash_attn/utils/benchmark.py function benchmark_forward (line 8) | def benchmark_forward( function benchmark_backward (line 30) | def benchmark_backward( function benchmark_combined (line 72) | def benchmark_combined( function benchmark_fwd_bwd (line 117) | def benchmark_fwd_bwd( function benchmark_all (line 154) | def benchmark_all( function pytorch_profiler (line 202) | def pytorch_profiler( function benchmark_memory (line 258) | def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs): FILE: flash_attn/utils/distributed.py function all_gather_raw (line 18) | def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op... function reduce_scatter_raw (line 30) | def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, asyn... function all_reduce_raw (line 43) | def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op... class AllGatherFunc (line 49) | class AllGatherFunc(torch.autograd.Function): method forward (line 53) | def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: method backward (line 59) | def backward(ctx, grad_output: Tensor): class ReduceScatterFunc (line 68) | class ReduceScatterFunc(torch.autograd.Function): method forward (line 72) | def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: method backward (line 78) | def backward(ctx, grad_output: Tensor): class AllReduceFunc (line 87) | class AllReduceFunc(torch.autograd.Function): method forward (line 91) | def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: method backward (line 97) | def backward(ctx, grad_output: Tensor): function sync_shared_params (line 105) | def sync_shared_params(model: torch.nn.Module, process_group: ProcessGro... function allreduce_sequence_parallel_grad (line 120) | def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_gro... function get_dim_for_local_rank (line 135) | def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, m... FILE: flash_attn/utils/generation.py class InferenceParams (line 24) | class InferenceParams: method reset (line 35) | def reset(self, max_seqlen, max_batch_size): function modify_logits_for_top_k_filtering (line 45) | def modify_logits_for_top_k_filtering(logits, top_k): function modify_logits_for_top_p_filtering (line 53) | def modify_logits_for_top_p_filtering(logits, top_p): function sample (line 69) | def sample(logits, top_k=1, top_p=0.0, temperature=1.0): function decode (line 99) | def decode( function sample_speculative (line 209) | def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_... function decode_speculative (line 269) | def decode_speculative( class GenerationMixin (line 566) | class GenerationMixin: method allocate_inference_cache (line 567) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method generate (line 570) | def generate( function allocate_inference_cache (line 589) | def allocate_inference_cache( class DecodingCGCache (line 606) | class DecodingCGCache: function update_graph_cache (line 618) | def update_graph_cache( function capture_graph (line 693) | def capture_graph( FILE: flash_attn/utils/library.py function triton_op (line 10) | def triton_op( FILE: flash_attn/utils/pretrained.py function state_dict_from_pretrained (line 15) | def state_dict_from_pretrained(model_name, device=None, dtype=None): FILE: flash_attn/utils/testing.py function generate_random_padding_mask (line 11) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r... function generate_qkv (line 34) | def generate_qkv( function construct_local_mask (line 159) | def construct_local_mask( function construct_chunk_mask (line 195) | def construct_chunk_mask( function attention_ref (line 228) | def attention_ref( FILE: flash_attn/utils/torch.py function custom_amp_decorator (line 5) | def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): FILE: hopper/benchmark_attn.py function time_fwd (line 41) | def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): function flops (line 62) | def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=... function convert_to_cudnn_type (line 76) | def convert_to_cudnn_type(torch_type): function cudnn_spda_setup (line 91) | def cudnn_spda_setup(q, k, v, causal=False, window_size_left=-1): function cudnn_spda_bwd_setup (line 146) | def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_l... FILE: hopper/benchmark_flash_attention_fp8.py function convert_to_cudnn_type (line 34) | def convert_to_cudnn_type(torch_type): function cudnn_spda_setup (line 52) | def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False): function attention_pytorch (line 173) | def attention_pytorch(qkv, dropout_p=0.0, causal=True): function flops (line 201) | def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"): function efficiency (line 206) | def efficiency(flop, time): function time_fwd (line 209) | def time_fwd(func, *args, **kwargs): FILE: hopper/benchmark_split_kv.py function round_up_to_power_of_2 (line 10) | def round_up_to_power_of_2(x): function timeit (line 15) | def timeit(fn, *args, **kwargs): function main (line 35) | def main(): FILE: hopper/block.h function namespace (line 7) | namespace flash { function CUTLASS_DEVICE (line 104) | static function CUTLASS_DEVICE (line 121) | static FILE: hopper/copy_sm90_bulk_reduce.hpp type cute (line 9) | namespace cute type SM90_BULK_REDUCE_ADD (line 14) | struct SM90_BULK_REDUCE_ADD method CUTE_HOST_DEVICE (line 16) | CUTE_HOST_DEVICE static void method CUTE_HOST_DEVICE (line 31) | CUTE_HOST_DEVICE static void FILE: hopper/epilogue_bwd.hpp type flash (line 17) | namespace flash { type CollectiveEpilogueBwd (line 23) | struct CollectiveEpilogueBwd { type TensorStorage (line 85) | struct TensorStorage : cute::aligned_struct { type Arguments (line 105) | struct Arguments { type Params (line 121) | struct Params { method Params (line 133) | static Params method CUTLASS_DEVICE (line 156) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 165) | CUTLASS_DEVICE void method CUTLASS_DEVICE (line 273) | CUTLASS_DEVICE void method CUTLASS_DEVICE (line 279) | CUTLASS_DEVICE void type CollectiveEpilogueBwdGQA (line 323) | struct CollectiveEpilogueBwdGQA { type TensorStorageTMA (line 354) | struct TensorStorageTMA : cute::aligned_struct { type TensorStorageSTG (line 357) | struct TensorStorageSTG { type Arguments (line 366) | struct Arguments { type Params (line 382) | struct Params { method Params (line 397) | static Params method CUTLASS_DEVICE (line 410) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 415) | CUTLASS_DEVICE void method CUTLASS_DEVICE (line 520) | CUTLASS_DEVICE void method CUTLASS_DEVICE (line 525) | CUTLASS_DEVICE void FILE: hopper/epilogue_fwd.hpp type flash (line 19) | namespace flash { type CollectiveEpilogueFwd (line 25) | struct CollectiveEpilogueFwd { type TensorStorage (line 106) | struct TensorStorage : cute::aligned_struct<128> { type Arguments (line 122) | struct Arguments { type Params (line 138) | struct Params { method Params (line 160) | static Params method CUTLASS_DEVICE (line 206) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 214) | CUTLASS_DEVICE void method CUTLASS_DEVICE (line 404) | CUTLASS_DEVICE void method CUTLASS_DEVICE (line 410) | CUTLASS_DEVICE void FILE: hopper/flash.h type Qkv_params (line 12) | struct Qkv_params { function Qkv_params (line 37) | struct Flash_fwd_params : public Qkv_params { function Flash_fwd_params (line 172) | struct Flash_bwd_params : public Flash_fwd_params { FILE: hopper/flash_api.cpp function PyObject (line 24) | PyObject* PyInit__C(void) function make_cuda_guard_from_tensor (line 45) | inline at::cuda::CUDAGuard make_cuda_guard_from_tensor(const at::Tensor&... function set_params_fprop (line 50) | void set_params_fprop(Flash_fwd_params ¶ms, function set_params_dgrad (line 170) | void set_params_dgrad(Flash_bwd_params ¶ms, function run_mha_fwd_constexpr (line 256) | void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { function run_mha_fwd (line 367) | void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { function run_mha_fwd_combine (line 387) | void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, ... function get_pagedkv_tma (line 415) | inline bool get_pagedkv_tma(Flash_fwd_params const& params) { function get_pack_gqa (line 426) | inline bool get_pack_gqa(Flash_fwd_params const& params) { function get_num_splits (line 442) | inline int get_num_splits(Flash_fwd_params const& params) { function get_max_headdim (line 473) | inline int get_max_headdim() { function round_up_headdim (line 492) | inline int round_up_headdim(int head_size) { function round_up_headdimv (line 511) | inline int round_up_headdimv(int head_size) { function mha_fwd_get_scheduler_metadata (line 521) | at::Tensor function mha_fwd (line 672) | std::tuple function run_mha_bwd (line 1201) | void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { function run_mha_bwd_constexpr (line 1206) | void run_mha_bwd_constexpr(Flash_bwd_params ¶ms, cudaStream_t stream) { function run_mha_bwd (line 1246) | void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { function mha_bwd (line 1267) | std::tuple m... function mha_combine (line 1570) | std::tuple function TORCH_LIBRARY (line 1673) | TORCH_LIBRARY(flash_attn_3, m) { function TORCH_LIBRARY_IMPL (line 1764) | TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { FILE: hopper/flash_api_stable.cpp function make_device_guard (line 36) | inline tsa::DeviceGuard make_device_guard(const Tensor& t) { function initVectors (line 42) | void initVectors() { function initDeviceProperty (line 56) | void initDeviceProperty(int device_index) { function cudaDeviceProp (line 67) | cudaDeviceProp* get_device_prop() { function PyObject (line 87) | PyObject* PyInit__C(void) function set_params_fprop (line 114) | void set_params_fprop(Flash_fwd_params ¶ms, function set_params_dgrad (line 235) | void set_params_dgrad(Flash_bwd_params ¶ms, function run_mha_fwd_constexpr (line 321) | void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { function run_mha_fwd (line 432) | void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { function run_mha_fwd_combine (line 452) | void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, ... function get_pagedkv_tma (line 480) | inline bool get_pagedkv_tma(Flash_fwd_params const& params) { function get_pack_gqa (line 491) | inline bool get_pack_gqa(Flash_fwd_params const& params) { function get_num_splits (line 507) | inline int get_num_splits(Flash_fwd_params const& params) { function get_max_headdim (line 538) | inline int get_max_headdim() { function round_up_headdim (line 557) | inline int round_up_headdim(int head_size) { function round_up_headdimv (line 576) | inline int round_up_headdimv(int head_size) { function Tensor (line 586) | Tensor function mha_fwd (line 741) | std::tuple function run_mha_bwd (line 1272) | void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { function run_mha_bwd_constexpr (line 1277) | void run_mha_bwd_constexpr(Flash_bwd_params ¶ms, cudaStream_t stream) { function run_mha_bwd (line 1317) | void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { function mha_bwd (line 1338) | std::tuple mha_bwd( function mha_combine (line 1648) | std::tuple function boxed_mha_fwd (line 1755) | void boxed_mha_fwd( function boxed_mha_bwd (line 1804) | void boxed_mha_bwd( function boxed_mha_combine (line 1841) | void boxed_mha_combine( function boxed_mha_fwd_get_scheduler_metadata (line 1857) | void boxed_mha_fwd_get_scheduler_metadata( function STABLE_TORCH_LIBRARY (line 1892) | STABLE_TORCH_LIBRARY(flash_attn_3, m) { function STABLE_TORCH_LIBRARY_IMPL (line 1983) | STABLE_TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) { FILE: hopper/flash_attn_interface.py function maybe_contiguous (line 30) | def maybe_contiguous(x): function round_multiple (line 34) | def round_multiple(x, m): function round_up_headdim (line 38) | def round_up_headdim(head_size: int) -> int: function _flash_attn_forward (line 60) | def _flash_attn_forward( function _flash_attn_forward_fake (line 154) | def _flash_attn_forward_fake( function _flash_attn_backward (line 259) | def _flash_attn_backward( function _flash_attn_backward_fake (line 313) | def _flash_attn_backward_fake( function setup_context (line 410) | def setup_context(ctx, inputs, output): function _backward (line 422) | def _backward(ctx, dout, *grads): class FlashAttnQKVPackedFunc (line 453) | class FlashAttnQKVPackedFunc(torch.autograd.Function): method forward (line 455) | def forward( method backward (line 514) | def backward(ctx, dout, *args): class FlashAttnFunc (line 552) | class FlashAttnFunc(torch.autograd.Function): method forward (line 555) | def forward( method backward (line 611) | def backward(ctx, dout, *args): class FlashAttnVarlenFunc (line 642) | class FlashAttnVarlenFunc(torch.autograd.Function): method forward (line 645) | def forward( method backward (line 713) | def backward(ctx, dout, *args): function flash_attn_qkvpacked_func (line 747) | def flash_attn_qkvpacked_func( function flash_attn_func (line 809) | def flash_attn_func( function flash_attn_varlen_func (line 890) | def flash_attn_varlen_func( function flash_attn_combine (line 938) | def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None): function flash_attn_with_kvcache (line 942) | def flash_attn_with_kvcache( function get_scheduler_metadata (line 1106) | def get_scheduler_metadata( FILE: hopper/flash_bwd_kernel_sm80.h function namespace (line 16) | namespace flash { type Params (line 80) | struct Params { function EpilogueParams (line 82) | EpilogueParams epilogue{} function TileSchedulerParams (line 84) | TileSchedulerParams scheduler{} function Params (line 92) | static FILE: hopper/flash_bwd_kernel_sm90.h function namespace (line 20) | namespace flash { type Params (line 100) | struct Params { function EpilogueParams (line 102) | EpilogueParams epilogue{} function TileSchedulerParams (line 104) | TileSchedulerParams scheduler{} function Params (line 112) | static FILE: hopper/flash_bwd_launch_template.h function dim3 (line 74) | dim3 grid_m(num_m_block, params.h, params.b); function typename (line 230) | typename PostprocessKernel::Arguments postprocess_args { FILE: hopper/flash_bwd_postprocess_kernel.h function namespace (line 18) | namespace flash { FILE: hopper/flash_bwd_preprocess_kernel.h function namespace (line 17) | namespace flash { FILE: hopper/flash_fwd_combine_kernel.h function namespace (line 20) | namespace flash { FILE: hopper/flash_fwd_combine_launch_template.h function typename (line 28) | typename CombineKernel::Arguments args { function run_mha_fwd_combine_ (line 56) | void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream,... FILE: hopper/flash_fwd_kernel_sm80.h function namespace (line 18) | namespace flash { type Arguments (line 91) | struct Arguments { function EpilogueArguments (line 93) | EpilogueArguments epilogue{} function TileSchedulerArguments (line 95) | TileSchedulerArguments scheduler{} type Params (line 99) | struct Params { function EpilogueParams (line 101) | EpilogueParams epilogue{} function TileSchedulerParams (line 103) | TileSchedulerParams scheduler{} function Params (line 111) | static FILE: hopper/flash_fwd_kernel_sm90.h function namespace (line 23) | namespace flash { type PipelineStorage (line 105) | struct PipelineStorage type Arguments (line 122) | struct Arguments { function EpilogueArguments (line 124) | EpilogueArguments epilogue{} function TileSchedulerArguments (line 126) | TileSchedulerArguments scheduler{} type Params (line 130) | struct Params { function EpilogueParams (line 132) | EpilogueParams epilogue{} function TileSchedulerParams (line 134) | TileSchedulerParams scheduler{} function dim3 (line 167) | static dim3 function dim3 (line 172) | static dim3 function CUTLASS_DEVICE (line 177) | CUTLASS_DEVICE FILE: hopper/flash_fwd_launch_template.h function typename (line 93) | typename CollectiveMainloop::Arguments mainloop_args { FILE: hopper/generate_kernels.py class Kernel (line 84) | class Kernel: method template (line 96) | def template(self) -> str: method filename (line 127) | def filename(self) -> str: function get_all_kernels (line 131) | def get_all_kernels() -> List[Kernel]: function batch_hdim (line 148) | def batch_hdim(kernels_all) -> List[KERNEL_BATCH]: function batch_softcap (line 166) | def batch_softcap(kernels_all) -> List[KERNEL_BATCH]: function write_kernel (line 187) | def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: function main (line 195) | def main(output_dir: Optional[str]) -> None: FILE: hopper/heuristics.h function should_pack_gqa (line 9) | inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_k... function num_splits_heuristic (line 25) | inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_... FILE: hopper/mainloop_bwd_sm80.hpp type flash (line 20) | namespace flash { type CollectiveMainloopBwdSm80 (line 29) | struct CollectiveMainloopBwdSm80 { type TensorStorageSharedQV (line 252) | struct TensorStorageSharedQV : cute::aligned_struct<128> { type TensorStorageSeparateQV (line 265) | struct TensorStorageSeparateQV : cute::aligned_struct<128> { type Arguments (line 279) | struct Arguments { type Params (line 312) | struct Params { method Params (line 346) | static Params method CUTLASS_DEVICE (line 377) | CUTLASS_DEVICE bool FILE: hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp type flash (line 26) | namespace flash { type CollectiveMainloopBwdSm90 (line 35) | struct CollectiveMainloopBwdSm90 { type TensorStorage (line 280) | struct TensorStorage : cute::aligned_struct { type TensorStorageSeparateQV (line 167) | struct TensorStorageSeparateQV : cute::aligned_struct<128> { type Arguments (line 176) | struct Arguments { type Params (line 219) | struct Params { method Params (line 264) | static Params method CUTLASS_DEVICE (line 308) | CUTLASS_DEVICE bool method CUTLASS_DEVICE (line 660) | CUTLASS_DEVICE bool FILE: hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp type flash (line 27) | namespace flash { type CollectiveMainloopFwdSm90 (line 34) | struct CollectiveMainloopFwdSm90 { type TensorStorageWithoutPNoTranspose (line 308) | struct TensorStorageWithoutPNoTranspose : cute::aligned_struct None: function check_env_flag (line 347) | def check_env_flag(name: str, default: str = "") -> bool: function is_offline_build (line 352) | def is_offline_build() -> bool: function get_flashattn_cache_path (line 369) | def get_flashattn_cache_path(): function open_url (line 378) | def open_url(url): function download_and_copy (line 388) | def download_and_copy(name, src_func, dst_path, version, url_func): function nvcc_threads_args (line 415) | def nvcc_threads_args(): function get_package_version (line 638) | def get_package_version(): function get_wheel_url (line 649) | def get_wheel_url(): class CachedWheelsCommand (line 672) | class CachedWheelsCommand(_bdist_wheel): method run (line 680) | def run(self): FILE: hopper/sm90_pipeline_no_cluster.hpp class PipelineTmaAsyncNoCluster (line 22) | class PipelineTmaAsyncNoCluster: public Base { method CUTLASS_DEVICE (line 33) | static method if (line 62) | if constexpr (cute::is_same_v) { FILE: hopper/softmax.h function namespace (line 15) | namespace flash { function CUTLASS_DEVICE (line 99) | CUTLASS_DEVICE Softmax(float const softmax_scale_log2_) : softmax_scale_... function online_softmax (line 127) | void online_softmax(Tensor0 &acc_s) { FILE: hopper/test_attn_kvcache.py function construct_local_mask (line 10) | def construct_local_mask( function attention_ref (line 45) | def attention_ref( function test_flash_attn_kvcache_nosplit (line 155) | def test_flash_attn_kvcache_nosplit(nheads_kv, gqa_ratio, num_requests, ... function test_flash_attn_kvcache_nosplit_fp8 (line 217) | def test_flash_attn_kvcache_nosplit_fp8(nheads_kv, gqa_ratio, num_reques... function test_flash_attn_kvcache_output (line 292) | def test_flash_attn_kvcache_output(nheads_kv, gqa_ratio, num_requests, q... function test_flash_attn_kvcache_output_fp8 (line 399) | def test_flash_attn_kvcache_output_fp8(nheads_kv, gqa_ratio, num_request... FILE: hopper/test_flash_attn.py function should_test_backward (line 58) | def should_test_backward(args, kwargs): function should_run_schema_check (line 80) | def should_run_schema_check(args, kwargs): function should_run_fake_check (line 87) | def should_run_fake_check(args, kwargs): function run_opcheck (line 93) | def run_opcheck(fn): function test_flash_attn_output (line 167) | def test_flash_attn_output( function test_flash_attn_varlen_output (line 404) | def test_flash_attn_varlen_output( function test_flash_attn_kvcache (line 715) | def test_flash_attn_kvcache( function _generate_block_kvcache (line 1054) | def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d... function test_flash_attn_cluster (line 1090) | def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype): function test_flash_attn_race_condition (line 1133) | def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): function attention_combine_ref (line 1167) | def attention_combine_ref(out_partial, lse_partial): function test_flash_attn_combine (line 1190) | def test_flash_attn_combine(num_splits, seqlen, d, dtype): function test_flash3_bw_compatibility (line 1225) | def test_flash3_bw_compatibility() -> None: FILE: hopper/test_flash_attn_bwd_determinism.py function test_flash_attn_output (line 110) | def test_flash_attn_output( function test_flash_attn_varlen_output (line 391) | def test_flash_attn_varlen_output( FILE: hopper/test_flash_attn_triton_amd.py function test_flash_attn_output (line 105) | def test_flash_attn_output( function test_flash_attn_varlen_output (line 334) | def test_flash_attn_varlen_output( function test_flash_attn_kvcache (line 628) | def test_flash_attn_kvcache( function _generate_block_kvcache (line 962) | def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d... function test_flash_attn_cluster (line 998) | def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype): function test_flash_attn_race_condition (line 1042) | def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): function attention_combine_ref (line 1076) | def attention_combine_ref(out_partial, lse_partial): function test_flash_attn_combine (line 1099) | def test_flash_attn_combine(num_splits, seqlen, d, dtype): function test_flash3_bw_compatibility (line 1135) | def test_flash3_bw_compatibility() -> None: FILE: hopper/test_kvcache.py function benchmark_fa_kv_old (line 20) | def benchmark_fa_kv_old(fn, repeats=10, desc='', verbose=True, **kwinputs): function benchmark_fa_kv (line 34) | def benchmark_fa_kv(fn, repeats=10, *args, **kwargs): function main (line 47) | def main(): FILE: hopper/test_torch_compile_and_export.py class EfficienctMultiHeadAttention (line 6) | class EfficienctMultiHeadAttention(nn.Module): method __init__ (line 7) | def __init__(self, embed_size, num_heads, dropout=0.0, use_flash_attn=... method forward (line 20) | def forward(self, x, attention_mask=None): function create_model (line 39) | def create_model(batch_size=16, sequence_length=256, embedding_dim=2048,... function test_export_model (line 45) | def test_export_model(): function test_compile_and_package_model (line 61) | def test_compile_and_package_model(): FILE: hopper/test_util.py function generate_random_padding_mask (line 9) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r... function generate_qkv (line 32) | def generate_qkv( function construct_local_mask (line 157) | def construct_local_mask( function construct_chunk_mask (line 193) | def construct_chunk_mask( function attention_ref (line 226) | def attention_ref( FILE: hopper/tile_scheduler.hpp type TileSchedulerArguments (line 18) | struct TileSchedulerArguments { class SingleTileScheduler (line 37) | class SingleTileScheduler { type Params (line 44) | struct Params { method Params (line 54) | static Params method dim3 (line 65) | static dim3 type WorkTileInfo (line 70) | struct WorkTileInfo { method CUTLASS_DEVICE (line 76) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 90) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 94) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 121) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 125) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 130) | CUTLASS_DEVICE class StaticPersistentTileScheduler (line 141) | class StaticPersistentTileScheduler { type Params (line 148) | struct Params { method Params (line 154) | static Params method dim3 (line 161) | static dim3 type WorkTileInfo (line 166) | struct WorkTileInfo { method CUTLASS_DEVICE (line 169) | CUTLASS_DEVICE method if (line 181) | if constexpr (Split) { method CUTLASS_DEVICE (line 193) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 199) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 203) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 208) | CUTLASS_DEVICE class DynamicPersistentTileScheduler (line 220) | class DynamicPersistentTileScheduler { type Params (line 243) | struct Params { method Params (line 252) | static Params method dim3 (line 277) | static dim3 type WorkTileInfo (line 282) | struct WorkTileInfo { method CUTLASS_DEVICE (line 285) | CUTLASS_DEVICE method if (line 299) | if (bidhb < params.num_hb_quotient) { method if (line 306) | if constexpr (Split) { method CUTLASS_DEVICE (line 320) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 326) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 334) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 343) | CUTLASS_DEVICE class SingleTileBwdLPTScheduler (line 368) | class SingleTileBwdLPTScheduler { type Params (line 375) | struct Params { method Params (line 386) | static Params method dim3 (line 412) | static dim3 type WorkTileInfo (line 417) | struct WorkTileInfo { method CUTLASS_DEVICE (line 422) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 436) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 440) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 472) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 476) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 481) | CUTLASS_DEVICE class VarlenDynamicPersistentTileScheduler (line 493) | class VarlenDynamicPersistentTileScheduler { type Params (line 507) | struct Params { method Params (line 524) | static Params method dim3 (line 547) | static dim3 type WorkTileInfo (line 552) | struct WorkTileInfo { method CUTLASS_DEVICE (line 555) | CUTLASS_DEVICE method if (line 572) | if constexpr (!Split) { method else (line 574) | else { method tile_idx_to_work_tile (line 597) | tile_idx_to_work_tile(Params const& params, int next_tile_idx, WorkT... method CUTLASS_DEVICE (line 761) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 776) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 782) | CUTLASS_DEVICE method CUTLASS_DEVICE (line 791) | CUTLASS_DEVICE FILE: hopper/tile_size.h function else (line 31) | else if (headdim <= 128) { function else (line 36) | else if (headdim <= 192) { FILE: hopper/utils.h function namespace (line 26) | namespace flash { function CUTLASS_DEVICE (line 675) | CUTLASS_DEVICE FILE: setup.py function cuda_archs (line 72) | def cuda_archs() -> str: function get_platform (line 76) | def get_platform(): function get_cuda_bare_metal_version (line 91) | def get_cuda_bare_metal_version(cuda_dir): function add_cuda_gencodes (line 100) | def add_cuda_gencodes(cc_flag, archs, bare_metal_version): function get_hip_version (line 153) | def get_hip_version(): function check_if_cuda_home_none (line 157) | def check_if_cuda_home_none(global_option: str) -> None: function check_if_rocm_home_none (line 169) | def check_if_rocm_home_none(global_option: str) -> None: function detect_hipify_v2 (line 179) | def detect_hipify_v2(): function append_nvcc_threads (line 191) | def append_nvcc_threads(nvcc_extra_args): function rename_cpp_to_cu (line 195) | def rename_cpp_to_cu(cpp_files): function validate_and_update_archs (line 200) | def validate_and_update_archs(archs): function get_package_version (line 508) | def get_package_version(): function get_wheel_url (line 519) | def get_wheel_url(): class CachedWheelsCommand (line 550) | class CachedWheelsCommand(_bdist_wheel): method run (line 558) | def run(self): class NinjaBuildExtension (line 585) | class NinjaBuildExtension(BuildExtension): method __init__ (line 586) | def __init__(self, *args, **kwargs) -> None: FILE: tests/cute/benchmark_block_sparsity.py class BenchmarkConfig (line 32) | class BenchmarkConfig: class BenchmarkResult (line 47) | class BenchmarkResult: function benchmark_pytorch_block_sparsity (line 56) | def benchmark_pytorch_block_sparsity( function benchmark_cute_block_sparsity (line 91) | def benchmark_cute_block_sparsity( function run_benchmark (line 195) | def run_benchmark( function generate_configs (line 220) | def generate_configs( function print_results (line 243) | def print_results(results: List[BenchmarkResult]): function main (line 296) | def main(): FILE: tests/cute/benchmark_mask_mod.py class BenchmarkConfig (line 30) | class BenchmarkConfig: class FlashAttentionBenchmark (line 83) | class FlashAttentionBenchmark: method __init__ (line 84) | def __init__(self, config: BenchmarkConfig): method _validate_config (line 112) | def _validate_config(self): method _generate_varlen_seqlens (line 139) | def _generate_varlen_seqlens(self, min_len: int, max_len: int) -> Tupl... method _create_tensors (line 154) | def _create_tensors(self) -> Dict[str, torch.Tensor]: method _compile_kernel (line 307) | def _compile_kernel(self, tensors: Dict[str, torch.Tensor]) -> Tuple[A... method _calculate_flops (line 448) | def _calculate_flops(self, tensors: Dict[str, torch.Tensor]) -> float: method benchmark (line 525) | def benchmark(self) -> Dict[str, Any]: method _print_results (line 604) | def _print_results(self, results: Dict[str, Any]): FILE: tests/cute/conftest.py function _get_gpu_ids (line 11) | def _get_gpu_ids(): function pytest_configure (line 32) | def pytest_configure(config): function pytest_collection_finish (line 61) | def pytest_collection_finish(session): FILE: tests/cute/mask_mod_definitions.py function cute_causal_mask (line 26) | def cute_causal_mask( function get_cute_causal_mask (line 39) | def get_cute_causal_mask(offset: int): function get_cute_block_causal_mask (line 43) | def get_cute_block_causal_mask(offset: int): function get_cute_sliding_window_mask (line 60) | def get_cute_sliding_window_mask(window_left: int, window_right: int, of... function cute_block_diagonal_mask (line 85) | def cute_block_diagonal_mask( function cute_mini_causal_mask (line 98) | def cute_mini_causal_mask( function cute_prefix_lm_mask (line 114) | def cute_prefix_lm_mask( function cute_dilated_sliding_window_mask (line 130) | def cute_dilated_sliding_window_mask( function cute_document_mask (line 148) | def cute_document_mask( function cute_ima_mask (line 164) | def cute_ima_mask( function get_flex_causal_mask (line 191) | def get_flex_causal_mask(offset: int): function get_flex_block_causal_mask (line 198) | def get_flex_block_causal_mask(offset: int): function get_flex_sliding_window_mask (line 205) | def get_flex_sliding_window_mask(window_left: int, window_right: int, of... function flex_block_diagonal_mask (line 215) | def flex_block_diagonal_mask(b, h, q_idx, kv_idx): function flex_mini_causal_mask (line 220) | def flex_mini_causal_mask(b, h, q_idx, kv_idx): function flex_prefix_lm_mask (line 224) | def flex_prefix_lm_mask(b, h, q_idx, kv_idx): function flex_dilated_sliding_window_mask (line 232) | def flex_dilated_sliding_window_mask(b, h, q_idx, kv_idx): function flex_document_mask (line 241) | def flex_document_mask(b, h, q_idx, kv_idx, doc_id): function flex_ima_mask (line 245) | def flex_ima_mask(b, h, q_idx, kv_idx, bias): function random_doc_id_tensor (line 254) | def random_doc_id_tensor(nheads, batch, seqlen_q, device="cpu"): function get_mask_pair (line 298) | def get_mask_pair(mask_name, seqlen_q=None, seqlen_k=None, window_size=N... FILE: tests/cute/score_mod_definitions.py function score_mod_identity (line 14) | def score_mod_identity(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_inf... function score_mod_identity_vectorized (line 19) | def score_mod_identity_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx,... function score_mod_causal (line 24) | def score_mod_causal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info,... function score_mod_causal_vectorized (line 30) | def score_mod_causal_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, s... function score_mod_rel_bias (line 41) | def score_mod_rel_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_inf... function score_mod_rel_bias_vectorized (line 48) | def score_mod_rel_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx,... function score_mod_rel_bias_x2 (line 60) | def score_mod_rel_bias_x2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_... function score_mod_rel_bias_x2_vectorized (line 68) | def score_mod_rel_bias_x2_vectorized( function score_mod_times_two (line 82) | def score_mod_times_two(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_in... function score_mod_alibi (line 88) | def score_mod_alibi(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, ... function score_mod_alibi_vectorized (line 100) | def score_mod_alibi_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, se... function score_mod_sliding_window (line 116) | def score_mod_sliding_window(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seql... function score_mod_block_diagonal (line 124) | def score_mod_block_diagonal(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seql... function score_mod_causal_v2 (line 132) | def score_mod_causal_v2(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_in... function score_mod_batch_bias (line 139) | def score_mod_batch_bias(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_i... function score_mod_batch_bias_vectorized (line 150) | def score_mod_batch_bias_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_id... function score_mod_dual_buffer (line 161) | def score_mod_dual_buffer(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_... function score_mod_dual_buffer_vectorized (line 181) | def score_mod_dual_buffer_vectorized(tSrS_ssa, b_idx, h_idx, q_idx, kv_i... function score_mod_global_kv_bias (line 205) | def score_mod_global_kv_bias( function score_mod_global_q_bias (line 222) | def score_mod_global_q_bias( function score_mod_global_rel_plus_kv_bias (line 238) | def score_mod_global_rel_plus_kv_bias( function score_mod_global_q_and_kv_bias (line 260) | def score_mod_global_q_and_kv_bias( function score_mod_global_logical_rel_plus_kv_bias (line 290) | def score_mod_global_logical_rel_plus_kv_bias( function score_mod_stress_complex_arithmetic (line 314) | def score_mod_stress_complex_arithmetic( function score_mod_stress_conditional_mask (line 342) | def score_mod_stress_conditional_mask( function score_mod_stress_multi_buffer (line 372) | def score_mod_stress_multi_buffer( function score_mod_stress_global_offset (line 431) | def score_mod_stress_global_offset( function score_mod_stress_xor_pattern (line 449) | def score_mod_stress_xor_pattern( function score_mod_debug_global_idx (line 475) | def score_mod_debug_global_idx( function identity_eager (line 490) | def identity_eager(score, b, h, q_idx, kv_idx): function causal_eager (line 494) | def causal_eager(score, b, h, q_idx, kv_idx): function rel_bias_eager (line 498) | def rel_bias_eager(score, b, h, q_idx, kv_idx): function rel_bias_x2_eager (line 502) | def rel_bias_x2_eager(score, b, h, q_idx, kv_idx): function times_two_eager (line 506) | def times_two_eager(score, b, h, q_idx, kv_idx): function alibi_eager (line 510) | def alibi_eager(score, b, h, q_idx, kv_idx): function sliding_window_eager (line 515) | def sliding_window_eager(score, b, h, q_idx, kv_idx): function block_diagonal_eager (line 519) | def block_diagonal_eager(score, b, h, q_idx, kv_idx): function causal_v2_eager (line 523) | def causal_v2_eager(score, b, h, q_idx, kv_idx): function batch_bias_factory (line 527) | def batch_bias_factory(bias_tensor): function dual_buffer_factory (line 534) | def dual_buffer_factory(head_bias, pos_bias): function packed_kv_bias_factory (line 541) | def packed_kv_bias_factory(bias_tensor, cu_seqlens_k): function packed_q_bias_factory (line 554) | def packed_q_bias_factory(bias_tensor, cu_seqlens_q): function packed_rel_plus_kv_bias_factory (line 566) | def packed_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k): function packed_q_and_kv_bias_factory (line 580) | def packed_q_and_kv_bias_factory(q_bias, kv_bias, cu_seqlens_q, cu_seqle... function packed_logical_rel_plus_kv_bias_factory (line 597) | def packed_logical_rel_plus_kv_bias_factory(bias_tensor, cu_seqlens_k): function stress_complex_arithmetic_factory (line 605) | def stress_complex_arithmetic_factory(bias, cu_seqlens_q): function stress_conditional_mask_factory (line 618) | def stress_conditional_mask_factory(token_bias, cu_seqlens_q, cu_seqlens... function stress_multi_buffer_factory (line 632) | def stress_multi_buffer_factory( function stress_global_offset_factory (line 654) | def stress_global_offset_factory(token_bias, cu_seqlens_k): function stress_xor_pattern_factory (line 661) | def stress_xor_pattern_factory(token_bias, cu_seqlens_q, cu_seqlens_k): function debug_global_idx_factory (line 670) | def debug_global_idx_factory(bias, cu_seqlens_k): FILE: tests/cute/test_block_sparsity.py function _call_compute_block_sparsity (line 11) | def _call_compute_block_sparsity( function _compare_block_sparsity (line 43) | def _compare_block_sparsity( function test_fixed_length_masks (line 213) | def test_fixed_length_masks( function test_parameterized_masks (line 292) | def test_parameterized_masks( function test_edge_cases (line 364) | def test_edge_cases(seqlen_q, seqlen_k, tile_m, tile_n): function test_fast_sampling (line 426) | def test_fast_sampling(seqlen_q, seqlen_k, tile_m, tile_n, nheads, mask_... FILE: tests/cute/test_flash_attn.py function test_flash_attn_output (line 100) | def test_flash_attn_output( function test_flash_attn_varlen_output (line 453) | def test_flash_attn_varlen_output( function test_flash_attn_kvcache (line 921) | def test_flash_attn_kvcache( function test_flash_attn_bwd_preallocated_outputs (line 1430) | def test_flash_attn_bwd_preallocated_outputs(seqlen_q, seqlen_k, d, caus... function test_flash_attn_lse_grad (line 1469) | def test_flash_attn_lse_grad(seqlen_q, seqlen_k, d, causal, dtype): function test_flash_attn_lse_grad_unused (line 1549) | def test_flash_attn_lse_grad_unused(seqlen_q, seqlen_k, d, causal, dtype): function _generate_block_kvcache (line 1599) | def _generate_block_kvcache( function test_flash_attn_paged_deepseek (line 1634) | def test_flash_attn_paged_deepseek(seqlen_q, page_size): function test_flash_attn_invalid_head_dim (line 1684) | def test_flash_attn_invalid_head_dim(head_dim): FILE: tests/cute/test_flash_attn_combine.py function attention_combine_ref (line 19) | def attention_combine_ref(out_partial, lse_partial): function check_combine_results (line 33) | def check_combine_results(out, lse, out_ref, lse_ref, dtype): function test_flash_attn_combine (line 58) | def test_flash_attn_combine(num_splits, seqlen, d, dtype): function test_flash_attn_combine_varlen (line 115) | def test_flash_attn_combine_varlen(varlen_mode, num_splits, seqlen, d, d... function test_flash_attn_combine_varlen_batch_idx (line 231) | def test_flash_attn_combine_varlen_batch_idx(num_splits, seqlen, d, dtype): FILE: tests/cute/test_flash_attn_fast.py function test_flash_attn_output (line 49) | def test_flash_attn_output(seqlen_q, seqlen_k, d, causal, num_splits, mh... function test_flash_attn_varlen_output (line 116) | def test_flash_attn_varlen_output(seqlen, d, causal, mha_type, dtype): function test_flash_attn_varlen_unpad_output (line 189) | def test_flash_attn_varlen_unpad_output(seqlen, d, causal, mha_type, unp... function attention_combine_ref (line 287) | def attention_combine_ref(out_partial, lse_partial): function test_flash_attn_combine (line 300) | def test_flash_attn_combine(num_splits, seqlen, d, dtype): FILE: tests/cute/test_flash_attn_race_condition.py function test_flash_attn_output (line 63) | def test_flash_attn_output( function test_flash_attn_varlen_output (line 393) | def test_flash_attn_varlen_output( FILE: tests/cute/test_flash_attn_varlen.py function test_varlen (line 17) | def test_varlen( function check_varlen_vs_torch_flash (line 51) | def check_varlen_vs_torch_flash( function generate_varlen_args (line 147) | def generate_varlen_args( function torch_flash_ref (line 192) | def torch_flash_ref( function _stats (line 296) | def _stats(name, a, b, atol, rtol): FILE: tests/cute/test_mask_mod.py function reset_torch_state (line 38) | def reset_torch_state(): function create_tensors (line 48) | def create_tensors( function compute_reference_flex_attn (line 73) | def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tupl... function get_coarse_block_mask_pair (line 111) | def get_coarse_block_mask_pair(sparse_tile_m: int, tile_n: int, last_blo... function _run_mask_test (line 171) | def _run_mask_test( function test_mask_mod_ima_partial_block (line 480) | def test_mask_mod_ima_partial_block(): function test_q_boundary_masking_block_sparse_bwd (line 514) | def test_q_boundary_masking_block_sparse_bwd(seqlen_q, seqlen_k, mask_na... function test_single_doc_bwd_minimal (line 549) | def test_single_doc_bwd_minimal(): function test_static_masks (line 677) | def test_static_masks( function test_parameterized_masks (line 725) | def test_parameterized_masks( function test_sm100_block_sparse_sink_all_masked (line 758) | def test_sm100_block_sparse_sink_all_masked(): function test_sm100_block_sparse_q_stage1 (line 804) | def test_sm100_block_sparse_q_stage1(): function test_sm100_block_sparse_coarse_blocks (line 846) | def test_sm100_block_sparse_coarse_blocks(): function test_sm100_block_sparse_coarse_blocks_mismatch (line 943) | def test_sm100_block_sparse_coarse_blocks_mismatch(): function run_cute_mask_bwd (line 1053) | def run_cute_mask_bwd( function run_flex_reference_bwd (line 1090) | def run_flex_reference_bwd(q, k, v, block_mask, grad_out, dtype=None): function test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message (line 1127) | def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_mess... function test_gqa_block_sparse_broadcast_pattern_recompilation (line 1199) | def test_gqa_block_sparse_broadcast_pattern_recompilation(): function test_gqa_expand_stride_zero_bug (line 1301) | def test_gqa_expand_stride_zero_bug(): function test_persistent_blocksparse_empty_tiles (line 1416) | def test_persistent_blocksparse_empty_tiles(): FILE: tests/cute/test_score_mod.py function create_tensors (line 113) | def create_tensors( function run_cute_flash (line 122) | def run_cute_flash(q, k, v, cute_score_mod, aux_tensors=None, pack_gqa=F... function run_flex_reference (line 139) | def run_flex_reference(q, k, v, eager_score_mod, dtype=None) -> torch.Te... function test_cute_vs_flex_attention (line 149) | def test_cute_vs_flex_attention( function test_cute_score_mod_vectorized (line 202) | def test_cute_score_mod_vectorized( function test_cute_vs_flex_attention_with_aux_tensors (line 235) | def test_cute_vs_flex_attention_with_aux_tensors( function test_cute_score_mod_with_aux_tensors_vectorized (line 306) | def test_cute_score_mod_with_aux_tensors_vectorized( function _generate_block_kvcache (line 354) | def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d... function test_score_mod_with_paged_kvcache (line 396) | def test_score_mod_with_paged_kvcache( function test_score_mod_with_paged_kvcache_aux_tensors (line 545) | def test_score_mod_with_paged_kvcache_aux_tensors( function score_mod_bwd_5 (line 694) | def score_mod_bwd_5(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_inf... function score_mod_bwd_3 (line 700) | def score_mod_bwd_3(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_inf... function score_mod_bwd_identity (line 706) | def score_mod_bwd_identity(grad, score, b_idx, h_idx, q_idx, kv_idx, seq... function score_mod_bwd_causal (line 711) | def score_mod_bwd_causal(grad, score, b_idx, h_idx, q_idx, kv_idx, seqle... function score_mod_squared (line 721) | def score_mod_squared(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info... function score_mod_bwd_squared (line 727) | def score_mod_bwd_squared(grad, score, b_idx, h_idx, q_idx, kv_idx, seql... function score_squared_eager (line 732) | def score_squared_eager(score, b, h, q_idx, kv_idx): function run_cute_flash_bwd (line 754) | def run_cute_flash_bwd( function run_flex_reference_bwd (line 796) | def run_flex_reference_bwd(q, k, v, eager_score_mod, grad_out, dtype=None): function test_cute_vs_flex_attention_backward (line 832) | def test_cute_vs_flex_attention_backward(seqlen_q, seqlen_kv, dim, dtype... function make_aux_tensors_for_bwd (line 881) | def make_aux_tensors_for_bwd(cute_score_mod, eager_factory, seqlen_q, nu... function test_cute_vs_flex_attention_backward_with_aux (line 901) | def test_cute_vs_flex_attention_backward_with_aux( function test_cute_vs_flex_attention_backward_pack_gqa (line 962) | def test_cute_vs_flex_attention_backward_pack_gqa( FILE: tests/cute/test_score_mod_varlen.py function run_cute_flash (line 183) | def run_cute_flash( function run_flex_varlen_ref (line 232) | def run_flex_varlen_ref(q, k, v, cu_seqlens_q, cu_seqlens_k, score_mod, ... function setup_tensors (line 283) | def setup_tensors(seqlens_q, seqlens_k, varlen_q, varlen_k, num_heads, h... function prepare_ref_tensors (line 324) | def prepare_ref_tensors( function check_results (line 346) | def check_results( function test_varlen_with_score_mod (line 412) | def test_varlen_with_score_mod( function test_varlen_with_score_mod_vectorized (line 521) | def test_varlen_with_score_mod_vectorized( function test_varlen_with_global_idx_score_mod (line 602) | def test_varlen_with_global_idx_score_mod( function test_varlen_score_mod_kvcache (line 791) | def test_varlen_score_mod_kvcache( function test_varlen_score_mod_with_paged_kvcache_global (line 950) | def test_varlen_score_mod_with_paged_kvcache_global( FILE: tests/cute/test_utils.py class TestHashCallable (line 9) | class TestHashCallable: method test_returns_cute_hash_when_set_on_function (line 12) | def test_returns_cute_hash_when_set_on_function(self): method test_returns_cute_hash_from_wrapped_function (line 23) | def test_returns_cute_hash_from_wrapped_function(self): method test_prefers_wrapper_cute_hash_over_wrapped (line 39) | def test_prefers_wrapper_cute_hash_over_wrapped(self): method test_fallback_to_source_hashing (line 56) | def test_fallback_to_source_hashing(self): method test_same_function_produces_same_hash (line 67) | def test_same_function_produces_same_hash(self): method test_different_functions_produce_different_hashes (line 77) | def test_different_functions_produce_different_hashes(self): method test_fast_path_skips_expensive_hashing (line 90) | def test_fast_path_skips_expensive_hashing(self): method test_fast_path_on_wrapped_skips_expensive_hashing (line 125) | def test_fast_path_on_wrapped_skips_expensive_hashing(self): method test_closure_values_affect_hash (line 163) | def test_closure_values_affect_hash(self): class TestHashCallableIntegration (line 182) | class TestHashCallableIntegration: method test_repeated_calls_use_cached_hash (line 185) | def test_repeated_calls_use_cached_hash(self): FILE: tests/layers/test_rotary.py function test_rotary (line 21) | def test_rotary(rotary_emb_fraction, seqlen_offset): function test_rotary_interleaved (line 95) | def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset): FILE: tests/losses/test_cross_entropy.py function test_cross_entropy_loss (line 28) | def test_cross_entropy_loss( FILE: tests/losses/test_cross_entropy_parallel.py function test_cross_entropy_loss_parallel (line 32) | def test_cross_entropy_loss_parallel( FILE: tests/models/test_baichuan.py function test_baichuan_state_dict (line 36) | def test_baichuan_state_dict(model_name): function test_baichuan_optimized (line 60) | def test_baichuan_optimized(model_name): function test_baichuan_parallel_forward (line 144) | def test_baichuan_parallel_forward(model_name, world_size): function test_baichuan_generation (line 233) | def test_baichuan_generation(model_name): function test_baichuan_parallel_generation (line 345) | def test_baichuan_parallel_generation(model_name, world_size): FILE: tests/models/test_bert.py function test_bert_state_dict (line 23) | def test_bert_state_dict(model_name): function get_hf_models (line 33) | def get_hf_models(model_name, config, dtype): function test_bert_non_optimized (line 53) | def test_bert_non_optimized(model_name): function test_bert_optimized (line 100) | def test_bert_optimized(model_name): function test_bert_dense_seq_output (line 207) | def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_la... function test_inv_remap_state_dict (line 309) | def test_inv_remap_state_dict(model_name: str): FILE: tests/models/test_bigcode.py function test_bigcode_state_dict (line 15) | def test_bigcode_state_dict(model_name): function test_bigcode_optimized (line 28) | def test_bigcode_optimized(model_name): function test_bigcode_generation (line 88) | def test_bigcode_generation(model_name): function test_inv_remap_state_dict (line 189) | def test_inv_remap_state_dict(model_name: str): FILE: tests/models/test_btlm.py function test_btlm_state_dict (line 16) | def test_btlm_state_dict(model_name): function test_btlm_optimized (line 30) | def test_btlm_optimized(model_name): function test_btlm_generation (line 100) | def test_btlm_generation(model_name): function test_btlm_init (line 206) | def test_btlm_init(model_name): FILE: tests/models/test_falcon.py function test_falcon_state_dict (line 21) | def test_falcon_state_dict(model_name): function test_falcon_optimized (line 36) | def test_falcon_optimized(model_name): function test_falcon_parallel_forward (line 104) | def test_falcon_parallel_forward(model_name, world_size): function test_falcon_generation (line 186) | def test_falcon_generation(model_name): function test_falcon_parallel_generation (line 294) | def test_falcon_parallel_generation(model_name, world_size): FILE: tests/models/test_gpt.py function test_gpt2_state_dict (line 20) | def test_gpt2_state_dict(model_name): function test_gpt2_non_optimized (line 32) | def test_gpt2_non_optimized(model_name): function test_gpt2_optimized (line 82) | def test_gpt2_optimized(model_name): function test_gpt2_generation (line 142) | def test_gpt2_generation(model_name, rotary, optimized): function get_logits (line 264) | def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwa... function test_gpt2_generation_cg (line 282) | def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen): function test_gpt2_multiple_token_generation (line 345) | def test_gpt2_multiple_token_generation(model_name, optimized): function test_gpt2_speculative_decoding (line 391) | def test_gpt2_speculative_decoding(model_name, optimized, cg): function test_gpt2_shard_unshard (line 460) | def test_gpt2_shard_unshard(n_heads_q_kv): FILE: tests/models/test_gpt_generation_parallel.py function test_tensor_parallel (line 21) | def test_tensor_parallel(model_name, rotary, world_size): FILE: tests/models/test_gpt_neox.py function test_gptj_state_dict (line 15) | def test_gptj_state_dict(model_name): function test_gpt_neox_optimized (line 36) | def test_gpt_neox_optimized(model_name): FILE: tests/models/test_gpt_parallel.py function test_gpt_parallel (line 29) | def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, d... FILE: tests/models/test_gptj.py function test_gptj_state_dict (line 16) | def test_gptj_state_dict(model_name): function test_gptj_optimized (line 27) | def test_gptj_optimized(model_name): function test_gptj_generation (line 87) | def test_gptj_generation(model_name): FILE: tests/models/test_llama.py function _pretrained_state_dict_from_checkpoint (line 36) | def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, ... function test_llama_state_dict (line 50) | def test_llama_state_dict(model_name): function test_inv_remap_state_dict_hf_llama (line 68) | def test_inv_remap_state_dict_hf_llama(model_name): function test_llama_optimized (line 95) | def test_llama_optimized(model_name): function test_llama_parallel (line 186) | def test_llama_parallel(model_name, world_size): function test_llama_generation (line 289) | def test_llama_generation(model_name, checkpoint_format): function test_llama_parallel_generation (line 402) | def test_llama_parallel_generation(model_name, world_size): function test_llama_parallel_uneven_num_heads (line 537) | def test_llama_parallel_uneven_num_heads(world_size): FILE: tests/models/test_opt.py function test_opt_state_dict (line 19) | def test_opt_state_dict(model_name): function test_opt_optimized (line 33) | def test_opt_optimized(model_name): function test_opt_generation (line 100) | def test_opt_generation(model_name): FILE: tests/models/test_vit.py function test_vit (line 13) | def test_vit(optimized, fused_mlp): FILE: tests/modules/test_block_parallel.py function test_block_parallel (line 28) | def test_block_parallel(dim, sequence_parallel, world_size, dtype): FILE: tests/modules/test_embedding_parallel.py function test_embedding_parallel (line 24) | def test_embedding_parallel(dim, has_pos_emb, sequence_parallel, world_s... FILE: tests/modules/test_mha_parallel.py function test_mha_parallel (line 26) | def test_mha_parallel(embed_dim, head_dim, sequence_parallel, world_size... FILE: tests/modules/test_mlp_parallel.py function test_mlp_parallel (line 24) | def test_mlp_parallel(dim, activation, sequence_parallel, world_size, dt... FILE: tests/ops/test_dropout_layer_norm.py function test_dropout_layer_norm_training (line 52) | def test_dropout_layer_norm_training( function test_dropout_layer_norm_eval (line 177) | def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtyp... function test_dropout_layer_norm_prenorm_training (line 239) | def test_dropout_layer_norm_prenorm_training( function test_dropout_layer_norm_prenorm_eval (line 371) | def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, resid... function test_dropout_layer_norm_subset_training (line 435) | def test_dropout_layer_norm_subset_training( function test_dropout_layer_norm_subset_prenorm_training (line 592) | def test_dropout_layer_norm_subset_prenorm_training( function test_dropout_layer_norm_parallel_residual_training (line 762) | def test_dropout_layer_norm_parallel_residual_training( function test_dropout_layer_norm_parallel_residual_prenorm_training (line 971) | def test_dropout_layer_norm_parallel_residual_prenorm_training( function test_dropout_layer_norm_randomness (line 1161) | def test_dropout_layer_norm_randomness(): FILE: tests/ops/test_fused_dense.py function test_fused_linear_bias (line 16) | def test_fused_linear_bias(in_features, out_features, has_bias, return_r... function test_fused_mlp (line 92) | def test_fused_mlp( FILE: tests/ops/test_fused_dense_parallel.py function test_fused_linear_bias (line 25) | def test_fused_linear_bias( function test_fused_mlp (line 124) | def test_fused_mlp(in_features, out_features, has_bias2, sequence_parall... FILE: tests/ops/triton/test_layer_norm.py function test_layer_norm (line 47) | def test_layer_norm( function test_layer_norm_linear (line 263) | def test_layer_norm_linear( FILE: tests/test_flash_attn.py function attn_bias_from_alibi_slopes (line 29) | def attn_bias_from_alibi_slopes( function generate_random_padding_mask (line 58) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r... function generate_qkv (line 74) | def generate_qkv( function construct_local_mask (line 182) | def construct_local_mask( function attention_ref (line 217) | def attention_ref( function attention_kvpacked_ref (line 307) | def attention_kvpacked_ref( function attention_qkvpacked_ref (line 340) | def attention_qkvpacked_ref( function generate_sparsity_mask (line 369) | def generate_sparsity_mask(seqlen, sparsity=0.3): function attention_blocksparse_ref (line 382) | def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, drop... function convert_flash_attn_S_to_softmax (line 411) | def convert_flash_attn_S_to_softmax( function normalize_flash_attn_S (line 465) | def normalize_flash_attn_S( function get_dropout_fraction (line 529) | def get_dropout_fraction( function test_flash_attn_qkvpacked (line 586) | def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi... function test_flash_attn_varlen_qkvpacked (line 733) | def test_flash_attn_varlen_qkvpacked( function test_flash_attn_output (line 903) | def test_flash_attn_output( function test_flash_attn_varlen_output (line 1172) | def test_flash_attn_varlen_output( function test_flash_attn_causal (line 1482) | def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dty... function test_flash_attn_varlen_causal (line 1593) | def test_flash_attn_varlen_causal( function test_flash_attn_splitkv (line 1765) | def test_flash_attn_splitkv( function test_flash_attn_kvcache (line 1907) | def test_flash_attn_kvcache( function _generate_block_kvcache (line 2143) | def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, n... function test_flash_attn_race_condition (line 2199) | def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, cau... function test_flash_attn_bwd_overflow (line 2247) | def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): function test_flash_attn_bwd_transpose (line 2303) | def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): function test_flash_attn_bwd_varlen_overflow (line 2355) | def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): function test_flash_attn_deterministic (line 2413) | def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, cau... function test_flash_attn_varlen_deterministic (line 2471) | def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk,... FILE: tests/test_flash_attn_ck.py function is_bwd_hdim_supported (line 30) | def is_bwd_hdim_supported(d): function ck_randval_to_dropout_mask (line 34) | def ck_randval_to_dropout_mask(randval, p): function pad_rearrange_dropout_mask_hts_to_bhss (line 41) | def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen... function test_flash_attn_qkvpacked (line 73) | def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi... function test_flash_attn_varlen_qkvpacked (line 171) | def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local... function test_flash_attn_output (line 305) | def test_flash_attn_output( function test_flash_attn_varlen_output (line 522) | def test_flash_attn_varlen_output( function test_flash_attn_causal (line 780) | def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dty... function test_flash_attn_varlen_causal (line 880) | def test_flash_attn_varlen_causal( function test_flash_attn_kvcache (line 1053) | def test_flash_attn_kvcache( function test_flash_attn_race_condition (line 1314) | def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, cau... function test_flash_attn_bwd_overflow (line 1360) | def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): function test_flash_attn_bwd_transpose (line 1417) | def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): function test_flash_attn_bwd_varlen_overflow (line 1467) | def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): function test_flash_attn_deterministic (line 1515) | def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, cau... function test_flash_attn_varlen_deterministic (line 1563) | def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk,... FILE: tests/test_flash_attn_triton_amd.py function _get_block_size_n_triton (line 22) | def _get_block_size_n_triton(device, head_dim, is_dropout, is_causal): function attn_bias_from_alibi_slopes (line 44) | def attn_bias_from_alibi_slopes( function generate_random_padding_mask (line 73) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r... function generate_qkv (line 89) | def generate_qkv( function construct_local_mask (line 197) | def construct_local_mask( function attention_ref (line 232) | def attention_ref( function attention_kvpacked_ref (line 322) | def attention_kvpacked_ref( function attention_qkvpacked_ref (line 355) | def attention_qkvpacked_ref( function generate_sparsity_mask (line 384) | def generate_sparsity_mask(seqlen, sparsity=0.3): function attention_blocksparse_ref (line 397) | def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, drop... function convert_flash_attn_S_to_softmax (line 426) | def convert_flash_attn_S_to_softmax( function normalize_flash_attn_S (line 480) | def normalize_flash_attn_S( function get_dropout_fraction (line 544) | def get_dropout_fraction( function test_flash_attn_qkvpacked (line 601) | def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi... function test_flash_attn_varlen_qkvpacked (line 748) | def test_flash_attn_varlen_qkvpacked( function test_flash_attn_output (line 918) | def test_flash_attn_output( function test_flash_attn_varlen_output (line 1191) | def test_flash_attn_varlen_output( function test_flash_attn_causal (line 1504) | def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dty... function test_flash_attn_varlen_causal (line 1619) | def test_flash_attn_varlen_causal( function test_flash_attn_splitkv (line 1792) | def test_flash_attn_splitkv( function test_flash_attn_kvcache (line 1937) | def test_flash_attn_kvcache( function _generate_block_kvcache (line 2173) | def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, n... function test_flash_attn_race_condition (line 2230) | def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, cau... function test_flash_attn_bwd_overflow (line 2279) | def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): function test_flash_attn_bwd_transpose (line 2336) | def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): function test_flash_attn_bwd_varlen_overflow (line 2389) | def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): function test_flash_attn_deterministic (line 2448) | def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, cau... function test_flash_attn_varlen_deterministic (line 2507) | def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk,... FILE: tests/test_rotary.py function generate_cos_sin (line 18) | def generate_cos_sin(seqlen, rotary_dim, device, dtype): function generate_seqlen_offsets (line 26) | def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, dev... function index_cos_sin (line 35) | def index_cos_sin(cos, sin, seqlen_offsets, seqlen): function test_rotary_emb_func (line 60) | def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_o... function test_rotary_emb_qkv (line 113) | def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_typ... function test_rotary_emb_kv (line 181) | def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type... function test_rotary_emb_varlen_func (line 229) | def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, s... function test_compilation_count (line 281) | def test_compilation_count(): FILE: tests/test_util.py function generate_random_padding_mask (line 8) | def generate_random_padding_mask(max_seqlen, batch_size, device, mode="r... function generate_qkv (line 31) | def generate_qkv( function construct_local_mask (line 150) | def construct_local_mask( function attention_ref (line 185) | def attention_ref( FILE: tools/sass_diff.py class Line (line 35) | class Line: function _normalize_instr (line 44) | def _normalize_instr(text: str) -> str: function parse_sass (line 68) | def parse_sass(path: str) -> list[Line]: class DiffBlock (line 95) | class DiffBlock: function diff_sass (line 101) | def diff_sass(a_lines: list[Line], b_lines: list[Line]) -> list[DiffBlock]: function _fmt (line 120) | def _fmt(line: Line, prefix: str, color: str, use_color: bool, show_norm... function print_diff (line 128) | def print_diff(blocks: list[DiffBlock], context: int = 3, function _get_opcode (line 174) | def _get_opcode(raw: str) -> str | None: function print_summary (line 183) | def print_summary(a_all: list[Line], b_all: list[Line], blocks: list[Dif... function main (line 221) | def main(): FILE: training/run.py function dictconfig_filter_key (line 23) | def dictconfig_filter_key(d: DictConfig, fn: Callable) -> DictConfig: function main (line 34) | def main(config: DictConfig): FILE: training/src/callbacks/causality_monitor.py class CausalityMonitor (line 9) | class CausalityMonitor(Callback): method __init__ (line 26) | def __init__(self, seq_len: int = 10, input_dim: int = 0): method on_train_epoch_end (line 32) | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lig... FILE: training/src/callbacks/ema.py class EMACallback (line 16) | class EMACallback(Callback): method __init__ (line 19) | def __init__(self, decay: float, use_num_updates: bool = True): method on_train_start (line 30) | def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightni... method on_train_batch_end (line 40) | def on_train_batch_end( method on_validation_start (line 51) | def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.Li... method on_validation_end (line 57) | def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.Ligh... method on_test_start (line 61) | def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin... method on_test_end (line 66) | def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningM... method on_save_checkpoint (line 70) | def on_save_checkpoint( method on_load_checkpoint (line 75) | def on_load_checkpoint( FILE: training/src/callbacks/flop_count.py class FlopCount (line 14) | class FlopCount(Callback): method __init__ (line 17) | def __init__(self, profilers: List[str] = ['fvcore', 'deepspeed'], method on_fit_start (line 34) | def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -... FILE: training/src/callbacks/gpu_affinity.py function l2_promote (line 10) | def l2_promote(): function set_affinity (line 21) | def set_affinity(trainer): class GpuAffinity (line 34) | class GpuAffinity(Callback): method setup (line 39) | def setup(self, trainer: Trainer, pl_module: LightningModule, stage=No... FILE: training/src/callbacks/loss_scale_monitor.py class LossScaleMonitor (line 9) | class LossScaleMonitor(Callback): method on_before_optimizer_step (line 17) | def on_before_optimizer_step(self, trainer: Trainer, *args: Any, **kwa... FILE: training/src/callbacks/model_checkpoint.py class ModelCheckpointMine (line 8) | class ModelCheckpointMine(pl.callbacks.model_checkpoint.ModelCheckpoint): method __init__ (line 10) | def __init__(self, *args, fault_tolerant=False, **kwargs): method on_exception (line 14) | def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> N... FILE: training/src/callbacks/norm_monitor.py class NormMonitor (line 22) | class NormMonitor(Callback): method __init__ (line 26) | def __init__(self, layer_norm_only: bool = False): method on_before_optimizer_step (line 33) | def on_before_optimizer_step(self, trainer: Trainer, pl_module, *args:... FILE: training/src/callbacks/params_log.py class ParamsLog (line 8) | class ParamsLog(Callback): method __init__ (line 11) | def __init__(self, total_params_log: bool = True, trainable_params_log... method on_fit_start (line 23) | def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -... FILE: training/src/callbacks/speed_monitor.py class SpeedMonitor (line 12) | class SpeedMonitor(Callback): method __init__ (line 15) | def __init__(self, intra_step_time: bool = True, inter_step_time: bool... method on_train_start (line 27) | def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightni... method on_train_epoch_start (line 30) | def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.L... method on_validation_epoch_start (line 35) | def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: ... method on_test_epoch_start (line 38) | def on_test_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Li... method on_train_batch_start (line 42) | def on_train_batch_start( method on_train_batch_end (line 64) | def on_train_batch_end( method on_train_epoch_end (line 89) | def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lig... FILE: training/src/callbacks/wandb_callbacks.py function get_wandb_logger (line 16) | def get_wandb_logger(trainer: Trainer) -> WandbLogger: class WatchModel (line 37) | class WatchModel(Callback): method __init__ (line 40) | def __init__(self, log: str = "gradients", log_freq: int = 100): method on_train_start (line 45) | def on_train_start(self, trainer, pl_module): class UploadCodeAsArtifact (line 50) | class UploadCodeAsArtifact(Callback): method __init__ (line 53) | def __init__(self, code_dir: str, use_git: bool = True): method on_train_start (line 65) | def on_train_start(self, trainer, pl_module): class UploadCheckpointsAsArtifact (line 97) | class UploadCheckpointsAsArtifact(Callback): method __init__ (line 100) | def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: b... method on_keyboard_interrupt (line 105) | def on_keyboard_interrupt(self, trainer, pl_module): method on_train_end (line 109) | def on_train_end(self, trainer, pl_module): class LogConfusionMatrix (line 124) | class LogConfusionMatrix(Callback): method __init__ (line 129) | def __init__(self): method on_sanity_check_start (line 134) | def on_sanity_check_start(self, trainer, pl_module) -> None: method on_sanity_check_end (line 137) | def on_sanity_check_end(self, trainer, pl_module): method on_validation_batch_end (line 141) | def on_validation_batch_end( method on_validation_epoch_end (line 149) | def on_validation_epoch_end(self, trainer, pl_module): class LogF1PrecRecHeatmap (line 182) | class LogF1PrecRecHeatmap(Callback): method __init__ (line 187) | def __init__(self, class_names: List[str] = None): method on_sanity_check_start (line 192) | def on_sanity_check_start(self, trainer, pl_module): method on_sanity_check_end (line 195) | def on_sanity_check_end(self, trainer, pl_module): method on_validation_batch_end (line 199) | def on_validation_batch_end( method on_validation_epoch_end (line 207) | def on_validation_epoch_end(self, trainer, pl_module): class LogImagePredictions (line 245) | class LogImagePredictions(Callback): method __init__ (line 251) | def __init__(self, num_samples: int = 8): method on_sanity_check_start (line 256) | def on_sanity_check_start(self, trainer, pl_module): method on_sanity_check_end (line 259) | def on_sanity_check_end(self, trainer, pl_module): method on_validation_epoch_end (line 263) | def on_validation_epoch_end(self, trainer, pl_module): FILE: training/src/datamodules/datasets/detokenizer.py function wikitext_detokenize (line 10) | def wikitext_detokenize(string: str) -> str: FILE: training/src/datamodules/datasets/lm_dataset.py class LMDataset (line 10) | class LMDataset(torch.utils.data.Dataset): method __init__ (line 12) | def __init__(self, tokens, seq_len, drop_last=True): method __len__ (line 25) | def __len__(self): method __getitem__ (line 28) | def __getitem__(self, idx): FILE: training/src/datamodules/fault_tolerant_sampler.py class RandomFaultTolerantSampler (line 9) | class RandomFaultTolerantSampler(RandomSampler): method __init__ (line 11) | def __init__(self, *args, generator=None, **kwargs): method state_dict (line 26) | def state_dict(self): method load_state_dict (line 29) | def load_state_dict(self, state_dict): method __iter__ (line 43) | def __iter__(self) -> Iterator[int]: class FaultTolerantDistributedSampler (line 64) | class FaultTolerantDistributedSampler(DistributedSampler): method __init__ (line 66) | def __init__(self, *args, **kwargs): method state_dict (line 72) | def state_dict(self): method load_state_dict (line 75) | def load_state_dict(self, state_dict): method __iter__ (line 86) | def __iter__(self): FILE: training/src/datamodules/imagenet.py class DictDataset (line 17) | class DictDataset(Dataset): method __init__ (line 19) | def __init__(self, dataset_dict, length=None): method __getitem__ (line 28) | def __getitem__(self, index): method __len__ (line 31) | def __len__(self): function imagenet_normalization (line 36) | def imagenet_normalization(): class ImagenetDataModule (line 40) | class ImagenetDataModule(LightningDataModule): method __init__ (line 63) | def __init__( method num_classes (line 116) | def num_classes(self) -> int: method _verify_splits (line 123) | def _verify_splits(self, data_dir: str, split: str) -> None: method prepare_data (line 132) | def prepare_data(self) -> None: method setup (line 139) | def setup(self, stage: Optional[str] = None) -> None: method train_transform (line 164) | def train_transform(self) -> Callable: method val_transform (line 188) | def val_transform(self) -> Callable: method train_dataloader (line 212) | def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: method val_dataloader (line 224) | def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoade... method test_dataloader (line 250) | def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoad... method _data_loader (line 256) | def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: boo... class Imagenet21kPDataModule (line 273) | class Imagenet21kPDataModule(ImagenetDataModule): method num_classes (line 278) | def num_classes(self) -> int: FILE: training/src/datamodules/language_modeling_hf.py class SHMArray (line 29) | class SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/us... method __new__ (line 31) | def __new__(cls, input_array, shm=None): method __array_finalize__ (line 36) | def __array_finalize__(self, obj): class LMDataModule (line 41) | class LMDataModule(LightningDataModule): method __init__ (line 42) | def __init__(self, dataset_name, tokenizer_name, dataset_config_name=N... method prepare_data (line 80) | def prepare_data(self): method setup (line 86) | def setup(self, stage=None): method process_dataset (line 97) | def process_dataset(self): method _save_to_cache (line 232) | def _save_to_cache(self, concat_ids, tokenizer, cache_dir): method _load_from_cache (line 240) | def _load_from_cache(self, cache_dir): method _cache_dir_name (line 250) | def _cache_dir_name(self): method train_dataloader (line 253) | def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: method val_dataloader (line 272) | def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoade... method test_dataloader (line 276) | def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoad... method _data_loader (line 280) | def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: boo... method load_state_dict (line 293) | def load_state_dict(self, checkpoint): FILE: training/src/datamodules/timm_mixup.py class TimmMixup (line 7) | class TimmMixup(Mixup): method __call__ (line 10) | def __call__(self, x, target): FILE: training/src/distributed/ddp_comm_hooks.py function fp16_compress_hook (line 9) | def fp16_compress_hook( FILE: training/src/eval.py function remove_prefix (line 22) | def remove_prefix(text: str, prefix: str): function load_checkpoint (line 28) | def load_checkpoint(path, device='cpu'): function evaluate (line 47) | def evaluate(config: DictConfig) -> None: FILE: training/src/metrics/accuracy.py class AccuracyMine (line 7) | class AccuracyMine(Accuracy): method update (line 10) | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore FILE: training/src/metrics/num_tokens.py class NumTokens (line 9) | class NumTokens(Metric): method __init__ (line 22) | def __init__(self, **kwargs: Dict[str, Any]): method update (line 27) | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor]... method compute (line 30) | def compute(self) -> Tensor: method reset (line 33) | def reset(self): method _forward_reduce_state_update (line 39) | def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: FILE: training/src/metrics/perplexity.py class Perplexity (line 21) | class Perplexity(Metric): method __init__ (line 43) | def __init__(self, **kwargs: Dict[str, Any]): method update (line 51) | def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor]... method compute (line 65) | def compute(self) -> Tensor: FILE: training/src/models/modules/seq_common.py function pooling (line 15) | def pooling(x, pooling_mode='CLS', key_padding_mask=None, batch_first=Tr... class ClassificationHeadLinear (line 49) | class ClassificationHeadLinear(nn.Module): method __init__ (line 52) | def __init__(self, d_model, num_classes, pooling_mode='MEAN', method forward (line 60) | def forward(self, hidden_states, key_padding_mask=None, **kwargs): class ClassificationHead (line 71) | class ClassificationHead(nn.Module): method __init__ (line 74) | def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling... method forward (line 84) | def forward(self, hidden_states, key_padding_mask=None, **kwargs): class ClassificationHeadDual (line 99) | class ClassificationHeadDual(nn.Module): method __init__ (line 102) | def __init__(self, d_model, d_inner, num_classes, dropout=0.0, pooling... method forward (line 114) | def forward(self, hidden_states1, hidden_states2, class LMHead (line 134) | class LMHead(nn.Module): method __init__ (line 136) | def __init__(self, d_model, num_classes, batch_first=True, bias=True): method forward (line 140) | def forward(self, hidden_states, **kwargs): function sinusoidal_init_ (line 148) | def sinusoidal_init_(tensor): class PositionalEncoding (line 161) | class PositionalEncoding(nn.Module): method __init__ (line 178) | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=Fal... method forward (line 192) | def forward(self, x): class Mlp (line 207) | class Mlp(nn.Module): method __init__ (line 210) | def __init__(self, in_features, hidden_features=None, out_features=Non... method forward (line 227) | def forward(self, x): class MlpBig (line 236) | class MlpBig(nn.Module): method __init__ (line 239) | def __init__(self, in_features, hidden_features=None, out_features=Non... method forward (line 259) | def forward(self, x): class GluMlp (line 262) | class GluMlp(nn.Module): method __init__ (line 266) | def __init__(self, in_features, hidden_features=None, out_features=Non... method init_weights (line 276) | def init_weights(self): method forward (line 282) | def forward(self, x): class GatedMlp (line 292) | class GatedMlp(nn.Module): method __init__ (line 295) | def __init__(self, in_features, hidden_features=None, out_features=Non... method forward (line 311) | def forward(self, x): class ConvMlp (line 321) | class ConvMlp(nn.Module): method __init__ (line 324) | def __init__( method forward (line 335) | def forward(self, x): FILE: training/src/optim/param_grouping.py function group_parameters_for_optimizer (line 15) | def group_parameters_for_optimizer(model, optimizer_cfg, bias_weight_dec... FILE: training/src/optim/timm_lr_scheduler.py class TimmCosineLRScheduler (line 8) | class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler.... method __init__ (line 13) | def __init__(self, *args, **kwargs): method step (line 18) | def step(self, epoch=None): FILE: training/src/tasks/seq.py class SequenceModel (line 20) | class SequenceModel(LightningModule): method __init__ (line 22) | def __init__(self, cfg, model_cfg=None): method instantiate_datamodule (line 38) | def instantiate_datamodule(self): method instantiate_model (line 47) | def instantiate_model(self): method instantiate_loss (line 58) | def instantiate_loss(self): method instantiate_metrics (line 66) | def instantiate_metrics(self): method warmstart (line 79) | def warmstart(self): method forward (line 90) | def forward(self, *args, **kwargs): method step (line 93) | def step(self, batch: Any, is_train=True): method shared_step (line 103) | def shared_step(self, batch: Any, batch_idx: int, phase='train'): method training_step (line 117) | def training_step(self, batch: Any, batch_idx: int): method validation_step (line 120) | def validation_step(self, batch: Any, batch_idx: int): method test_step (line 123) | def test_step(self, batch: Any, batch_idx: int): method configure_optimizers (line 126) | def configure_optimizers(self): method optimizer_zero_grad (line 151) | def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_i... method on_save_checkpoint (line 159) | def on_save_checkpoint(self, checkpoint): class SequenceLMModel (line 169) | class SequenceLMModel(SequenceModel): method step (line 171) | def step(self, batch: Any, is_train=True): method shared_step (line 179) | def shared_step(self, batch: Any, batch_idx: int, phase='train'): FILE: training/src/train.py function last_modification_time (line 20) | def last_modification_time(path): function train (line 32) | def train(config: DictConfig) -> Optional[float]: FILE: training/src/utils/checkpoint.py function load_checkpoint (line 8) | def load_checkpoint(path, device='cpu'): function blockdiag_to_dense_mlp_bert (line 32) | def blockdiag_to_dense_mlp_bert(state_dict): function interpolate_pos_embedding (line 41) | def interpolate_pos_embedding(state_dict, out_seqlen, pos_embedding_name... function remove_model_prefix (line 68) | def remove_model_prefix(state_dict): FILE: training/src/utils/ddp_zero1.py function get_zero_optimizer_state_dict_local (line 24) | def get_zero_optimizer_state_dict_local(optimizer, global_rank): class DDPStrategyZero1 (line 62) | class DDPStrategyZero1(DDPStrategy): method optimizer_state (line 69) | def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]: method save_checkpoint (line 77) | def save_checkpoint( method load_checkpoint (line 96) | def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: FILE: training/src/utils/ddp_zero2.py class DistAdamNativeMixedPrecisionPlugin (line 26) | class DistAdamNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): method optimizer_step (line 28) | def optimizer_step( # type: ignore[override] method clip_grad_by_norm (line 64) | def clip_grad_by_norm(self, optimizer: DistributedFusedAdam, clip_val:... class DDPStrategyZero2 (line 73) | class DDPStrategyZero2(DDPStrategy): method __init__ (line 80) | def __init__( method precision_plugin (line 92) | def precision_plugin(self) -> PrecisionPlugin: method precision_plugin (line 96) | def precision_plugin(self, precision_plugin: Optional[PrecisionPlugin]... method optimizer_state (line 106) | def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]: method save_checkpoint (line 114) | def save_checkpoint( method load_checkpoint (line 133) | def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: FILE: training/src/utils/distributed.py function init_distributed (line 23) | def init_distributed(cuda): function barrier (line 39) | def barrier(): function get_rank (line 47) | def get_rank(): function get_world_size (line 58) | def get_world_size(): function all_reduce_item (line 70) | def all_reduce_item(value, op='sum'): function sync_workers (line 105) | def sync_workers(): FILE: training/src/utils/ema.py function to_float_maybe (line 13) | def to_float_maybe(x): class ExponentialMovingAverage (line 19) | class ExponentialMovingAverage: method __init__ (line 29) | def __init__( method _get_parameters (line 50) | def _get_parameters( method update (line 76) | def update( method copy_to (line 106) | def copy_to( method store (line 123) | def store( method restore (line 141) | def restore( method average_parameters (line 168) | def average_parameters( method to (line 195) | def to(self, device=None, dtype=None) -> None: method state_dict (line 216) | def state_dict(self) -> dict: method load_state_dict (line 228) | def load_state_dict(self, state_dict: dict) -> None: FILE: training/src/utils/flops.py function profile_deepspeed (line 20) | def profile_deepspeed(model, input_size=(3, 224, 224), input_dtype=torch... function profile_fvcore (line 35) | def profile_fvcore(model, input_size=(3, 224, 224), input_dtype=torch.fl... FILE: training/src/utils/gpu_affinity.py function systemGetDriverVersion (line 12) | def systemGetDriverVersion(): function deviceGetCount (line 16) | def deviceGetCount(): class device (line 20) | class device: method __init__ (line 24) | def __init__(self, device_idx): method getName (line 28) | def getName(self): method getCpuAffinity (line 31) | def getCpuAffinity(self): function set_socket_affinity (line 45) | def set_socket_affinity(gpu_id): function set_single_affinity (line 51) | def set_single_affinity(gpu_id): function set_single_unique_affinity (line 57) | def set_single_unique_affinity(gpu_id, nproc_per_node): function set_socket_unique_affinity (line 80) | def set_socket_unique_affinity(gpu_id, nproc_per_node, mode): function get_thread_siblings_list (line 113) | def get_thread_siblings_list(): function set_affinity (line 127) | def set_affinity(gpu_id, nproc_per_node, mode='socket'): FILE: training/src/utils/utils.py class LoggingContext (line 13) | class LoggingContext: method __init__ (line 14) | def __init__(self, logger, level=None, handler=None, close=True): method __enter__ (line 20) | def __enter__(self): method __exit__ (line 27) | def __exit__(self, et, ev, tb): function get_logger (line 37) | def get_logger(name=__name__) -> logging.Logger: function extras (line 50) | def extras(config: DictConfig) -> None: function print_config (line 89) | def print_config( function finish (line 131) | def finish( FILE: training/tests/datamodules/test_language_modeling_hf.py function div_up (line 19) | def div_up(x: int, y: int) -> int: function num_cpu_cores (line 24) | def num_cpu_cores(): class TestLMDataModule (line 32) | class TestLMDataModule: method test_wikitext2 (line 34) | def test_wikitext2(self): method test_wikitext103 (line 64) | def test_wikitext103(self): method test_openwebtext (line 94) | def test_openwebtext(self): method test_lambada (line 125) | def test_lambada(self): method test_the_pile (line 156) | def test_the_pile(self): method test_pg19 (line 188) | def test_pg19(self):