SYMBOL INDEX (465 symbols across 55 files) FILE: csrc/selective_scan/selective_scan.cpp function set_ssm_params_fwd (line 59) | void set_ssm_params_fwd(SSMParamsBase ¶ms, function set_ssm_params_bwd (line 143) | void set_ssm_params_bwd(SSMParamsBwd ¶ms, function selective_scan_fwd (line 226) | std::vector function selective_scan_bwd (line 338) | std::vector function PYBIND11_MODULE (line 494) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { FILE: csrc/selective_scan/selective_scan.h type SSMScanParamsBase (line 9) | struct SSMScanParamsBase { type SSMParamsBase (line 26) | struct SSMParamsBase { function SSMParamsBase (line 71) | struct SSMParamsBwd: public SSMParamsBase { FILE: csrc/selective_scan/selective_scan_common.h function custom_max (line 18) | constexpr size_t custom_max(std::initializer_list ilist) function T (line 24) | T constexpr_min(T a, T b) { function custom_max (line 29) | constexpr size_t custom_max(std::initializer_list ilist) function T (line 35) | T constexpr_min(T a, T b) { type BytesToType (line 61) | struct BytesToType type BytesToType (line 66) | struct BytesToType type BytesToType (line 71) | struct BytesToType type BytesToType (line 76) | struct BytesToType type BytesToType (line 81) | struct BytesToType function __device__ (line 90) | static inline __device__ void to_float(const scalar_t (&src)[N], float (... function __device__ (line 98) | static inline __device__ void to_float(const at::Half (&src)[N], float (... function __device__ (line 110) | static inline __device__ void to_float(const at::BFloat16 (&src)[N], flo... function complex_t (line 124) | complex_t cexp2f(complex_t z) { function complex_t (line 131) | complex_t cexpf(complex_t z) { function float (line 141) | struct SSMScanOp { function complex_t (line 148) | struct SSMScanOp { function __device__ (line 166) | __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_pre... function __device__ (line 169) | __device__ scan_t operator()(scan_t block_aggregate) { function typename (line 186) | typename Ktraits::BlockLoadVecT(smem_load_vec).Load( function typename (line 210) | typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( function typename (line 225) | typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( function typename (line 248) | typename Ktraits::BlockStoreVecT(smem_store_vec).Store( FILE: evals/lm_harness_eval.py class MambaEvalWrapper (line 15) | class MambaEvalWrapper(HFLM): method __init__ (line 19) | def __init__(self, pretrained="state-spaces/mamba-2.8b", max_length=20... method batch_size (line 31) | def batch_size(self): method _model_generate (line 34) | def _model_generate(self, context, max_length, stop, **generation_kwar... FILE: mamba_ssm/distributed/distributed_utils.py function all_gather_raw (line 18) | def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op... function reduce_scatter_raw (line 30) | def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, asyn... function all_reduce_raw (line 43) | def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op... class AllGatherFunc (line 49) | class AllGatherFunc(torch.autograd.Function): method forward (line 53) | def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: method backward (line 59) | def backward(ctx, grad_output: Tensor): class ReduceScatterFunc (line 68) | class ReduceScatterFunc(torch.autograd.Function): method forward (line 72) | def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: method backward (line 78) | def backward(ctx, grad_output: Tensor): class AllReduceFunc (line 87) | class AllReduceFunc(torch.autograd.Function): method forward (line 91) | def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: method backward (line 97) | def backward(ctx, grad_output: Tensor): function sync_shared_params (line 105) | def sync_shared_params(model: torch.nn.Module, process_group: ProcessGro... function allreduce_sequence_parallel_grad (line 120) | def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_gro... function get_dim_for_local_rank (line 135) | def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, m... FILE: mamba_ssm/distributed/tensor_parallel.py class ParallelLinearFunc (line 23) | class ParallelLinearFunc(torch.autograd.Function): method forward (line 26) | def forward(ctx, x, weight, bias, process_group=None, sequence_paralle... method backward (line 62) | def backward(ctx, grad_output): function parallel_linear_func (line 101) | def parallel_linear_func( class ColumnParallelLinear (line 111) | class ColumnParallelLinear(nn.Linear): method __init__ (line 112) | def __init__( method forward (line 138) | def forward(self, x): class RowParallelLinear (line 151) | class RowParallelLinear(nn.Linear): method __init__ (line 152) | def __init__( method forward (line 184) | def forward(self, x): class VocabParallelEmbedding (line 194) | class VocabParallelEmbedding(nn.Embedding): method __init__ (line 195) | def __init__(self, num_embeddings, *args, process_group=None, padding_... method forward (line 210) | def forward(self, input: Tensor) -> Tensor: class ColumnParallelEmbedding (line 226) | class ColumnParallelEmbedding(nn.Embedding): method __init__ (line 227) | def __init__(self, num_embeddings, embedding_dim, *args, process_group... class ParallelEmbeddings (line 241) | class ParallelEmbeddings(nn.Module): method __init__ (line 242) | def __init__( method forward (line 273) | def forward(self, input_ids, position_ids=None, combine_batch_seqlen_d... FILE: mamba_ssm/models/config_mamba.py class MambaConfig (line 5) | class MambaConfig: FILE: mamba_ssm/models/mixer_seq_simple.py function create_block (line 29) | def create_block( function _init_weights (line 86) | def _init_weights( class MixerModel (line 118) | class MixerModel(nn.Module): method __init__ (line 119) | def __init__( method allocate_inference_cache (line 184) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method forward (line 190) | def forward(self, input_ids, inference_params=None, **mixer_kwargs): class MambaLMHeadModel (line 215) | class MambaLMHeadModel(nn.Module, GenerationMixin): method __init__ (line 217) | def __init__( method tie_weights (line 267) | def tie_weights(self): method allocate_inference_cache (line 271) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method forward (line 274) | def forward(self, input_ids, position_ids=None, inference_params=None,... method from_pretrained (line 287) | def from_pretrained(cls, pretrained_model_name, device=None, dtype=Non... method save_pretrained (line 294) | def save_pretrained(self, save_directory): FILE: mamba_ssm/modules/block.py class Block (line 10) | class Block(nn.Module): method __init__ (line 11) | def __init__( method forward (line 42) | def forward( method allocate_inference_cache (line 90) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... FILE: mamba_ssm/modules/mamba2.py class Mamba2 (line 37) | class Mamba2(nn.Module, PyTorchModelHubMixin): method __init__ (line 38) | def __init__( method forward (line 154) | def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, infer... method step (line 278) | def step(self, hidden_states, conv_state, ssm_state): method allocate_inference_cache (line 345) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method _get_states_from_cache (line 357) | def _get_states_from_cache(self, inference_params, batch_size, initial... FILE: mamba_ssm/modules/mamba2_simple.py class Mamba2Simple (line 24) | class Mamba2Simple(nn.Module): method __init__ (line 25) | def __init__( method forward (line 124) | def forward(self, u, seq_idx=None): FILE: mamba_ssm/modules/mamba3.py class Mamba3 (line 20) | class Mamba3(nn.Module): method __init__ (line 21) | def __init__( method forward (line 127) | def forward(self, u, seq_idx=None, cu_seqlens=None, inference_params=N... method _preprocess (line 249) | def _preprocess(self, A_proj, dd_dt, B, C, x, z, trap_proj, angle_proj): method _postprocess (line 272) | def _postprocess(self, y, outpj, z, zpj, headdim): method step (line 282) | def step(self, u, angle_state, ssm_state, k_state, v_state, **kwargs): method allocate_inference_cache (line 409) | def allocate_inference_cache(self, batch_size, max_seqlen, device=None... method _get_states_from_cache (line 451) | def _get_states_from_cache(self, inference_params, batch_size, initial... FILE: mamba_ssm/modules/mamba_simple.py class Mamba (line 31) | class Mamba(nn.Module): method __init__ (line 32) | def __init__( method forward (line 119) | def forward(self, hidden_states, inference_params=None): method step (line 208) | def step(self, hidden_states, conv_state, ssm_state): method allocate_inference_cache (line 255) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method _get_states_from_cache (line 268) | def _get_states_from_cache(self, inference_params, batch_size, initial... FILE: mamba_ssm/modules/mha.py function _update_kv_cache (line 26) | def _update_kv_cache(kv, inference_params, layer_idx): class MHA (line 44) | class MHA(nn.Module): method __init__ (line 47) | def __init__( method allocate_inference_cache (line 110) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): method _update_kv_cache (line 124) | def _update_kv_cache(self, kv, inference_params): method _apply_rotary_update_kvcache_attention (line 129) | def _apply_rotary_update_kvcache_attention(self, q, kv, inference_para... method _update_kvcache_attention (line 167) | def _update_kvcache_attention(self, q, kv, inference_params): method forward (line 201) | def forward(self, x, inference_params=None): FILE: mamba_ssm/modules/mlp.py class GatedMLP (line 6) | class GatedMLP(nn.Module): method __init__ (line 7) | def __init__( method forward (line 29) | def forward(self, x): FILE: mamba_ssm/modules/ssd_minimal.py function segsum_unstable (line 14) | def segsum_unstable(x): function segsum (line 23) | def segsum(x): function ssd_minimal_discrete (line 34) | def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): function test_correctness (line 82) | def test_correctness(): FILE: mamba_ssm/ops/cute/mamba3/mamba3_step_fn.py function transpose_view (line 22) | def transpose_view(a: cute.Tensor) -> cute.Tensor: function select (line 28) | def select(a: cute.Tensor, mode: List[int]) -> cute.Tensor: function get_gmem_tiled_copy (line 33) | def get_gmem_tiled_copy(dtype: Type[cutlass.Numeric], major_mode_size: i... class Mamba3Step (line 48) | class Mamba3Step(): method __init__ (line 49) | def __init__(self, tile_D: int, dstate: int, mimo: int = 1, num_warps:... method _setup_smem_layouts (line 59) | def _setup_smem_layouts(self): method _setup_gmem_tiled_copy (line 66) | def _setup_gmem_tiled_copy(self, ): method __call__ (line 86) | def __call__( method kernel (line 227) | def kernel( function mamba3_step_fn (line 566) | def mamba3_step_fn( function selective_state_update_fused_ref_v2 (line 741) | def selective_state_update_fused_ref_v2( function _bytes_of (line 816) | def _bytes_of(t): FILE: mamba_ssm/ops/selective_scan_interface.py class SelectiveScanFn (line 23) | class SelectiveScanFn(torch.autograd.Function): method forward (line 26) | def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, d... method backward (line 59) | def backward(ctx, dout, *args): function rms_norm_forward (line 86) | def rms_norm_forward( function selective_scan_fn (line 106) | def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None... function selective_scan_ref (line 115) | def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=Non... class MambaInnerFn (line 184) | class MambaInnerFn(torch.autograd.Function): method forward (line 188) | def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_... method backward (line 282) | def backward(ctx, dout): function mamba_inner_fn (line 373) | def mamba_inner_fn( function mamba_inner_ref (line 384) | def mamba_inner_ref( FILE: mamba_ssm/ops/tilelang/mamba3/mamba3_mimo.py class _Mamba3Function (line 24) | class _Mamba3Function(torch.autograd.Function): method forward (line 28) | def forward( method backward (line 88) | def backward(ctx, dout, *args) -> tuple: function mamba3_mimo (line 154) | def mamba3_mimo( FILE: mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd.py function mamba_mimo_bwd_fwd (line 41) | def mamba_mimo_bwd_fwd( function mamba_mimo_bwd_bwd (line 500) | def mamba_mimo_bwd_bwd( function mamba_mimo_bwd_combined (line 1146) | def mamba_mimo_bwd_combined( FILE: mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py function mamba_mimo_fwd (line 38) | def mamba_mimo_fwd( function mamba_mimo_forward (line 413) | def mamba_mimo_forward(q, k, v, FILE: mamba_ssm/ops/triton/angle_cumsum.py class AngleDtFn (line 12) | class AngleDtFn(torch.autograd.Function): method forward (line 14) | def forward(ctx, method backward (line 27) | def backward(ctx, grad_out: torch.Tensor): function angle_dt (line 37) | def angle_dt(angle: torch.Tensor, function cumsum_kernel (line 45) | def cumsum_kernel( function angle_dt_fwd_kernel (line 87) | def angle_dt_fwd_kernel( function angle_dt_bwd_kernel (line 194) | def angle_dt_bwd_kernel( function apply_angle_dt_fwd (line 307) | def apply_angle_dt_fwd( function apply_angle_dt_bwd (line 395) | def apply_angle_dt_bwd( function apply_cumsum (line 504) | def apply_cumsum( function apply_angle_dt_reference (line 541) | def apply_angle_dt_reference( function test_correctness (line 561) | def test_correctness(): function test_cumsum_correctness (line 587) | def test_cumsum_correctness(): function test_backward_correctness (line 611) | def test_backward_correctness(): function benchmark_angle_dt (line 647) | def benchmark_angle_dt(): function benchmark_angle_dt_backward (line 706) | def benchmark_angle_dt_backward(): FILE: mamba_ssm/ops/triton/k_activations.py function _swiglu_fwd_kernel (line 23) | def _swiglu_fwd_kernel( function _swiglu_fwd (line 46) | def _swiglu_fwd(xy, out=None): function _swiglu_bwd_kernel (line 78) | def _swiglu_bwd_kernel( function _swiglu_bwd (line 119) | def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None): class SwiGLU (line 158) | class SwiGLU(torch.autograd.Function): method forward (line 161) | def forward(ctx, xy): method backward (line 166) | def backward(ctx, dout): FILE: mamba_ssm/ops/triton/layer_norm.py function layer_norm_ref (line 22) | def layer_norm_ref( function rms_norm_ref (line 77) | def rms_norm_ref( function config_prune (line 130) | def config_prune(configs): function _layer_norm_fwd_1pass_kernel (line 181) | def _layer_norm_fwd_1pass_kernel( function _layer_norm_fwd (line 291) | def _layer_norm_fwd( function _layer_norm_bwd_kernel (line 436) | def _layer_norm_bwd_kernel( function _layer_norm_bwd (line 589) | def _layer_norm_bwd( class LayerNormFn (line 728) | class LayerNormFn(torch.autograd.Function): method forward (line 730) | def forward( method backward (line 828) | def backward(ctx, dy, *args): function layer_norm_fn (line 888) | def layer_norm_fn( function rms_norm_fn (line 922) | def rms_norm_fn( class RMSNorm (line 955) | class RMSNorm(torch.nn.Module): method __init__ (line 957) | def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, ... method reset_parameters (line 969) | def reset_parameters(self): method forward (line 972) | def forward(self, x, residual=None, prenorm=False, residual_in_fp32=Fa... class LayerNormLinearFn (line 985) | class LayerNormLinearFn(torch.autograd.Function): method forward (line 988) | def forward( method backward (line 1047) | def backward(ctx, dout, *args): function layer_norm_linear_fn (line 1092) | def layer_norm_linear_fn( FILE: mamba_ssm/ops/triton/layernorm_gated.py function rms_norm_ref (line 18) | def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, nor... function _layer_norm_fwd_1pass_kernel (line 45) | def _layer_norm_fwd_1pass_kernel( function _layer_norm_fwd (line 108) | def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=N... function _layer_norm_bwd_kernel (line 155) | def _layer_norm_bwd_kernel( function _layer_norm_bwd (line 271) | def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_... class LayerNormFn (line 338) | class LayerNormFn(torch.autograd.Function): method forward (line 341) | def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, n... method backward (line 369) | def backward(ctx, dy): function layernorm_fn (line 380) | def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, nor... function rmsnorm_fn (line 384) | def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_... class LayerNorm (line 388) | class LayerNorm(torch.nn.Module): method __init__ (line 390) | def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before... method reset_parameters (line 404) | def reset_parameters(self): method forward (line 408) | def forward(self, x, z=None): class RMSNorm (line 415) | class RMSNorm(torch.nn.Module): method __init__ (line 417) | def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before... method reset_parameters (line 430) | def reset_parameters(self): method forward (line 433) | def forward(self, x, z=None): FILE: mamba_ssm/ops/triton/mamba3/angle_dt.py function angle_dt_fwd_kernel (line 24) | def angle_dt_fwd_kernel( function angle_dt_fwd (line 125) | def angle_dt_fwd( function angle_dt_bwd_kernel (line 232) | def angle_dt_bwd_kernel( function angle_dt_bwd (line 345) | def angle_dt_bwd( FILE: mamba_ssm/ops/triton/mamba3/mamba3_mimo_rotary_step.py function rotary_qk_inference_kernel (line 16) | def rotary_qk_inference_kernel( function apply_rotary_qk_inference_fwd (line 151) | def apply_rotary_qk_inference_fwd( function apply_rotary_qk_inference_reference (line 239) | def apply_rotary_qk_inference_reference( function test_correctness_qk_inference (line 327) | def test_correctness_qk_inference(): FILE: mamba_ssm/ops/triton/mamba3/mamba3_mimo_utils.py function bwd_dadt_cumsum_fused_kernel (line 37) | def bwd_dadt_cumsum_fused_kernel( function bwd_segsum_dadt_kernel (line 134) | def bwd_segsum_dadt_kernel( function bwd_dtrap_ddt_kernel (line 225) | def bwd_dtrap_ddt_kernel( function dacs_segsum_kernel (line 349) | def dacs_segsum_kernel( function bwd_dadt_fused_triton (line 407) | def bwd_dadt_fused_triton( function bwd_dtrap_ddt_triton (line 450) | def bwd_dtrap_ddt_triton( function compute_dacs_segsum_triton (line 478) | def compute_dacs_segsum_triton( function bwd_segsum_ddt_from_dSSdA_ref (line 508) | def bwd_segsum_ddt_from_dSSdA_ref( function bwd_ddt_from_ddA_cs_rev_ref (line 528) | def bwd_ddt_from_ddA_cs_rev_ref( function bwd_ddt_from_ddA_cs_ref (line 545) | def bwd_ddt_from_ddA_cs_ref( function compute_dtrap_ddt_ref (line 560) | def compute_dtrap_ddt_ref(dfactor: torch.Tensor, function compute_dacs_segsum_ref (line 582) | def compute_dacs_segsum_ref(da: torch.Tensor, # (B, H, S) function test_bwd_ddt_fused_correctness (line 606) | def test_bwd_ddt_fused_correctness(): function test_dtrap_ddt_correctness (line 651) | def test_dtrap_ddt_correctness(): function test_dacs_segsum_correctness (line 696) | def test_dacs_segsum_correctness(): function benchmark_bwd_ddt (line 729) | def benchmark_bwd_ddt(): function benchmark_dacs_segsum (line 792) | def benchmark_dacs_segsum(): function benchmark_dtrap_ddt (line 829) | def benchmark_dtrap_ddt(): FILE: mamba_ssm/ops/triton/mamba3/mamba3_siso_bwd.py function mamba3_siso_bwd_kernel_dzdo (line 32) | def mamba3_siso_bwd_kernel_dzdo( function compute_dzdo (line 114) | def compute_dzdo( function mamba3_siso_bwd_kernel_dqkv (line 202) | def mamba3_siso_bwd_kernel_dqkv( function compute_dqkv (line 614) | def compute_dqkv( function mamba3_siso_bwd_kernel_rotary_bias_angles (line 811) | def mamba3_siso_bwd_kernel_rotary_bias_angles( function mamba3_siso_bwd_kernel_dk_state_post (line 1055) | def mamba3_siso_bwd_kernel_dk_state_post( function compute_dqktheta (line 1159) | def compute_dqktheta( function apply_dk_state_post (line 1338) | def apply_dk_state_post( function mamba3_siso_bwd_kernel_ddt_dtrap_dinput_states (line 1418) | def mamba3_siso_bwd_kernel_ddt_dtrap_dinput_states( function compute_ddt_dtrap_dinput_states (line 1611) | def compute_ddt_dtrap_dinput_states( function _alloc_fn (line 1771) | def _alloc_fn(size: int, alignment: int, stream: Optional[int]): FILE: mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py function _triton_alloc_fn (line 21) | def _triton_alloc_fn(size: int, alignment: int, stream: Optional[int]): class Mamba3Output (line 34) | class Mamba3Output: class _Mamba3Function (line 50) | class _Mamba3Function(torch.autograd.Function): method forward (line 54) | def forward( method backward (line 153) | def backward( function mamba3_siso_combined (line 291) | def mamba3_siso_combined( FILE: mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py function mamba3_siso_fwd_kernel (line 29) | def mamba3_siso_fwd_kernel( function _alloc_fn (line 434) | def _alloc_fn(size: int, alignment: int, stream: Optional[int]): function mamba3_siso_fwd (line 439) | def mamba3_siso_fwd( FILE: mamba_ssm/ops/triton/mamba3/mamba3_siso_step.py function mamba3_siso_step_kernel (line 27) | def mamba3_siso_step_kernel( function _alloc_fn (line 228) | def _alloc_fn(size: int, alignment: int, stream: Optional[int]): function mamba3_siso_step (line 233) | def mamba3_siso_step( FILE: mamba_ssm/ops/triton/mamba3/utils.py function cos_approx (line 14) | def cos_approx(x): function sin_approx (line 34) | def sin_approx(x): function tanh_approx (line 53) | def tanh_approx(x): function sech2_approx (line 72) | def sech2_approx(x): function sigmoid_approx (line 92) | def sigmoid_approx(x): function silu (line 117) | def silu(x): FILE: mamba_ssm/ops/triton/selective_state_update.py function _selective_scan_update_kernel (line 24) | def _selective_scan_update_kernel( function selective_state_update (line 135) | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bia... function selective_state_update_ref (line 224) | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt... FILE: mamba_ssm/ops/triton/softplus.py function softplus (line 10) | def softplus(dt): function softplus (line 14) | def softplus(dt): FILE: mamba_ssm/ops/triton/ssd_bmm.py function init_to_zero (line 18) | def init_to_zero(names): function _bmm_chunk_fwd_kernel (line 37) | def _bmm_chunk_fwd_kernel( function _bmm_chunk_bwd_kernel (line 111) | def _bmm_chunk_bwd_kernel( function _bmm_chunk_fwd (line 165) | def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_... function _bmm_chunk_bwd (line 213) | def _bmm_chunk_bwd(a, dout, residual=None, out=None): FILE: mamba_ssm/ops/triton/ssd_chunk_scan.py function init_to_zero (line 28) | def init_to_zero(names): function _chunk_scan_fwd_kernel (line 49) | def _chunk_scan_fwd_kernel( function _chunk_scan_fwd_kernel_wip (line 195) | def _chunk_scan_fwd_kernel_wip( function _chunk_scan_bwd_dz_kernel (line 349) | def _chunk_scan_bwd_dz_kernel( function _chunk_scan_bwd_dstates_kernel (line 449) | def _chunk_scan_bwd_dstates_kernel( function _chunk_scan_bwd_dc_kernel (line 530) | def _chunk_scan_bwd_dc_kernel( function _chunk_scan_bwd_dx_kernel (line 641) | def _chunk_scan_bwd_dx_kernel( function _chunk_scan_bwd_dcb_kernel (line 774) | def _chunk_scan_bwd_dcb_kernel( function _chunk_scan_bwd_ddAcs_unstable_kernel (line 889) | def _chunk_scan_bwd_ddAcs_unstable_kernel( function _chunk_scan_bwd_ddAcs_stable_kernel_old (line 979) | def _chunk_scan_bwd_ddAcs_stable_kernel_old( function _chunk_scan_bwd_ddAcs_stable_kernel (line 1098) | def _chunk_scan_bwd_ddAcs_stable_kernel( function _chunk_scan_bwd_ddAcs_prev_kernel (line 1198) | def _chunk_scan_bwd_ddAcs_prev_kernel( function _chunk_scan_fwd (line 1259) | def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq... function _chunk_scan_fwd_wip (line 1311) | def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=No... function _chunk_scan_bwd_dz (line 1363) | def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=No... function _chunk_scan_bwd_dstates (line 1423) | def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None): function _chunk_scan_bwd_dC (line 1451) | def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=Non... function _chunk_scan_bwd_dcb (line 1514) | def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, n... function _chunk_scan_bwd_dx (line 1565) | def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): function _chunk_scan_bwd_ddAcs_unstable (line 1618) | def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtra... function _chunk_scan_bwd_ddAcs_stable_old (line 1667) | def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb): function _chunk_scan_bwd_ddAcs_stable (line 1700) | def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb): function _chunk_scan_bwd_ddAcs_prev (line 1732) | def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=... class ChunkScanFn (line 1764) | class ChunkScanFn(torch.autograd.Function): method forward (line 1767) | def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): method backward (line 1798) | def backward(ctx, dout): function chunk_scan (line 1828) | def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): function chunk_scan_ref (line 1846) | def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): FILE: mamba_ssm/ops/triton/ssd_chunk_state.py function init_to_zero (line 24) | def init_to_zero(names): function _chunk_cumsum_fwd_kernel (line 40) | def _chunk_cumsum_fwd_kernel( function _chunk_cumsum_bwd_kernel (line 102) | def _chunk_cumsum_bwd_kernel( function _chunk_state_fwd_kernel (line 192) | def _chunk_state_fwd_kernel( function _chunk_state_bwd_dx_kernel (line 286) | def _chunk_state_bwd_dx_kernel( function _chunk_state_bwd_db_kernel (line 398) | def _chunk_state_bwd_db_kernel( function _chunk_state_bwd_ddAcs_stable_kernel (line 528) | def _chunk_state_bwd_ddAcs_stable_kernel( function _chunk_state_varlen_kernel (line 640) | def _chunk_state_varlen_kernel( function _chunk_cumsum_fwd (line 718) | def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False... function _chunk_cumsum_bwd (line 744) | def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=Fal... function _chunk_state_fwd (line 812) | def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, sta... function _chunk_state_bwd_dx (line 845) | def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None): function _chunk_state_bwd_db (line 902) | def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None,... function _chunk_state_bwd_ddAcs_stable (line 972) | def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=... function chunk_state_varlen (line 1018) | def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states): class ChunkStateFn (line 1047) | class ChunkStateFn(torch.autograd.Function): method forward (line 1050) | def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True): method backward (line 1067) | def backward(ctx, dstates): function chunk_state (line 1081) | def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True): function chunk_state_ref (line 1094) | def chunk_state_ref(B, x, dt, dA_cumsum): FILE: mamba_ssm/ops/triton/ssd_combined.py function init_to_zero (line 55) | def init_to_zero(names): function ensure_stride (line 59) | def ensure_stride(inp): function _chunk_scan_chunk_state_bwd_dx_kernel (line 93) | def _chunk_scan_chunk_state_bwd_dx_kernel( function _chunk_scan_chunk_state_bwd_dx (line 262) | def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstate... function _mamba_chunk_scan_combined_fwd (line 343) | def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z... function _mamba_chunk_scan_combined_bwd (line 396) | def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size... function selective_scan_bwd (line 516) | def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None): class MambaChunkScanCombinedFn (line 593) | class MambaChunkScanCombinedFn(torch.autograd.Function): method forward (line 596) | def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=N... method backward (line 616) | def backward(ctx, dout, *args): function mamba_chunk_scan_combined (line 624) | def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None... function mamba_chunk_scan (line 646) | def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias... function ssd_chunk_scan_combined_ref (line 683) | def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=No... function ssd_selective_scan (line 724) | def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_... function mamba_conv1d_scan_ref (line 775) | def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_... class MambaSplitConv1dScanCombinedFn (line 816) | class MambaSplitConv1dScanCombinedFn(torch.autograd.Function): method forward (line 820) | def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, ch... method backward (line 898) | def backward(ctx, dout, *args): function mamba_split_conv1d_scan_combined (line 978) | def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias,... function mamba_split_conv1d_scan_ref (line 1000) | def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_b... FILE: mamba_ssm/ops/triton/ssd_state_passing.py function _state_passing_fwd_kernel (line 30) | def _state_passing_fwd_kernel( function _state_passing_bwd_kernel (line 102) | def _state_passing_bwd_kernel( function _state_passing_fwd (line 196) | def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq... function _state_passing_bwd (line 227) | def _state_passing_bwd( class StatePassingFn (line 286) | class StatePassingFn(torch.autograd.Function): method forward (line 289) | def forward(ctx, states, dA_chunk_cumsum, initial_states=None): method backward (line 300) | def backward(ctx, dout, dfinal_states): function state_passing (line 314) | def state_passing(states, dA_chunk_cumsum, initial_states=None): function state_passing_ref (line 327) | def state_passing_ref(states, dA_chunk_cumsum, initial_states=None): FILE: mamba_ssm/utils/determinism.py function use_deterministic_mode (line 21) | def use_deterministic_mode(): function set_deterministic_mode (line 30) | def set_deterministic_mode(value): function _estimate_config_cost (line 35) | def _estimate_config_cost(cfg): function _filter_configs_by_block_sizes (line 44) | def _filter_configs_by_block_sizes(configs): function autotune_configs (line 59) | def autotune_configs(configs): function alloc_tile_workspace (line 80) | def alloc_tile_workspace(base_shape, tile_dim, dtype, device, determinis... function finalize_tile_workspace (line 91) | def finalize_tile_workspace(tensor, deterministic): FILE: mamba_ssm/utils/generation.py class InferenceParams (line 18) | class InferenceParams: method reset (line 29) | def reset(self, max_seqlen, max_batch_size): function modify_logits_for_min_p_filtering (line 37) | def modify_logits_for_min_p_filtering(logits, min_p): function modify_logits_for_top_k_filtering (line 45) | def modify_logits_for_top_k_filtering(logits, top_k): function modify_logits_for_top_p_filtering (line 53) | def modify_logits_for_top_p_filtering(logits, top_p): function modify_logit_for_repetition_penalty (line 69) | def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repe... function sample (line 83) | def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0): function decode (line 121) | def decode( class GenerationMixin (line 246) | class GenerationMixin: method allocate_inference_cache (line 247) | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None,... method generate (line 250) | def generate( class DecodingCGCache (line 271) | class DecodingCGCache: function update_graph_cache (line 283) | def update_graph_cache( function capture_graph (line 342) | def capture_graph( FILE: mamba_ssm/utils/hf.py function load_config_hf (line 9) | def load_config_hf(model_name): function load_state_dict_hf (line 14) | def load_state_dict_hf(model_name, device=None, dtype=None): FILE: mamba_ssm/utils/torch.py function custom_amp_decorator (line 5) | def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): FILE: setup.py function get_platform (line 47) | def get_platform(): function get_cuda_bare_metal_version (line 62) | def get_cuda_bare_metal_version(cuda_dir): function get_hip_version (line 73) | def get_hip_version(rocm_dir): function get_torch_hip_version (line 94) | def get_torch_hip_version(): function check_if_hip_home_none (line 102) | def check_if_hip_home_none(global_option: str) -> None: function check_if_cuda_home_none (line 113) | def check_if_cuda_home_none(global_option: str) -> None: function append_nvcc_threads (line 125) | def append_nvcc_threads(nvcc_extra_args): function get_package_version (line 268) | def get_package_version(): function get_wheel_url (line 279) | def get_wheel_url(): class CachedWheelsCommand (line 328) | class CachedWheelsCommand(_bdist_wheel): method run (line 336) | def run(self): FILE: tests/benchmark_determinism_kernels.py function _reset_peak_memory (line 18) | def _reset_peak_memory() -> None: function _peak_memory_mb (line 25) | def _peak_memory_mb(fn, *, warmup: int = 3) -> float: function make_tensors (line 35) | def make_tensors(*, batch: int, seqlen: int, nheads: int, headdim: int, ... function get_benchmarks (line 57) | def get_benchmarks(t: dict[str, torch.Tensor], *, ngroups: int): function _run_one (line 84) | def _run_one(fn, *, deterministic: bool, warmup: int, rep: int): function main (line 91) | def main() -> None: FILE: tests/ops/cute/test_mamba3_mimo_step.py function _require_cuda_and_kernel_deps (line 46) | def _require_cuda_and_kernel_deps() -> None: function _mamba3_cls (line 53) | def _mamba3_cls(): function _kernel_deps (line 60) | def _kernel_deps() -> None: class InferenceParams (line 67) | class InferenceParams: method reset (line 78) | def reset(self, max_seqlen, max_batch_size): class RunOutputs (line 87) | class RunOutputs: function _case_config (line 96) | def _case_config(*, is_outproj_norm: bool) -> dict: function _diff_stats (line 113) | def _diff_stats(actual: Tensor, expected: Tensor) -> str: function _assert_close (line 118) | def _assert_close( function _run_case (line 141) | def _run_case(*, is_outproj_norm: bool) -> RunOutputs: function test_step_matches_forward_fp32 (line 224) | def test_step_matches_forward_fp32(is_outproj_norm: bool) -> None: function run_step_benchmark (line 253) | def run_step_benchmark(*, is_outproj_norm: bool) -> None: FILE: tests/ops/test_selective_scan.py function test_selective_scan (line 38) | def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_... function test_mamba_inner_fn (line 160) | def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wty... FILE: tests/ops/tilelang/test_mamba3_mimo.py function _require_cuda_and_kernel_deps (line 54) | def _require_cuda_and_kernel_deps() -> None: function mods (line 62) | def mods() -> SimpleNamespace: function max_rel_err (line 77) | def max_rel_err(ours: Tensor, ref: Tensor, eps: float = 1e-5) -> float: function assert_stable_rel (line 85) | def assert_stable_rel( function build_inputs (line 114) | def build_inputs( function make_smoke_inputs (line 181) | def make_smoke_inputs( function grads_to_dA (line 292) | def grads_to_dA(grad_dA_cs: Tensor, grad_dA_cs_rev: Tensor, chunk_size: ... function mamba3_MIMO_step_ref (line 307) | def mamba3_MIMO_step_ref( function apply_angle_dt_reference (line 445) | def apply_angle_dt_reference( function mamba3_MIMO_chunk_ref (line 455) | def mamba3_MIMO_chunk_ref( function run_ref_backward_fp32 (line 638) | def run_ref_backward_fp32( function test_mamba3_MIMO_chunk_ref_matches_step_ref (line 751) | def test_mamba3_MIMO_chunk_ref_matches_step_ref() -> None: function test_fused_chunk_linear_attn_fwd_relative_error_lt_10pct (line 833) | def test_fused_chunk_linear_attn_fwd_relative_error_lt_10pct( function test_fused_chunk_linear_attn_fwd_return_state_relative_error_lt_10pct (line 898) | def test_fused_chunk_linear_attn_fwd_return_state_relative_error_lt_10pct( function test_fused_chunk_linear_attn_fwd_prereduce_relative_error_lt_10pct (line 977) | def test_fused_chunk_linear_attn_fwd_prereduce_relative_error_lt_10pct( function test_mamba_mimo_bwd_combined_relative_errors_lt_10pct (line 1044) | def test_mamba_mimo_bwd_combined_relative_errors_lt_10pct( function test_mamba_mimo_bwd_combined_prereduce_relative_errors_lt_10pct (line 1124) | def test_mamba_mimo_bwd_combined_prereduce_relative_errors_lt_10pct( function test_mamba_mimo_smoke_forward_backward (line 1211) | def test_mamba_mimo_smoke_forward_backward(mods: SimpleNamespace) -> None: FILE: tests/ops/triton/test_layernorm_gated.py function test_layer_norm_gated (line 29) | def test_layer_norm_gated(d, dtype, wtype, has_bias, has_z, is_rms_norm,... FILE: tests/ops/triton/test_mamba3_siso.py function _segsum (line 21) | def _segsum(x: torch.Tensor) -> torch.Tensor: function mamba3_siso_step_ref (line 33) | def mamba3_siso_step_ref( function mamba3_siso_fwd_ref (line 148) | def mamba3_siso_fwd_ref( function detach_clone (line 346) | def detach_clone(*args): function relative_error (line 351) | def relative_error( function create_mamba3_siso_inputs (line 406) | def create_mamba3_siso_inputs( function test_mamba3_siso_step (line 487) | def test_mamba3_siso_step(nheads_qk=4, has_Z=True, has_D=True): function test_mamba3_siso_combined_batched (line 561) | def test_mamba3_siso_combined_batched(nheads_qk=4, has_Z=True, has_D=Tru... function test_mamba3_siso_combined_varlen (line 675) | def test_mamba3_siso_combined_varlen(nheads_qk=4, has_Z=True, has_D=True... function test_mamba3_siso_step_ref_vs_fwd_ref (line 843) | def test_mamba3_siso_step_ref_vs_fwd_ref(nheads_qk=4, has_Z=True, has_D=... FILE: tests/ops/triton/test_selective_state_update.py function test_selective_state_update (line 22) | def test_selective_state_update(dim, dstate, has_z, itype): function test_selective_state_update_with_heads (line 66) | def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, ... function test_selective_state_update_with_batch_indices (line 112) | def test_selective_state_update_with_batch_indices(dim, dstate, has_z, i... function test_selective_state_update_with_heads_with_batch_indices (line 161) | def test_selective_state_update_with_heads_with_batch_indices(dim, dstat... FILE: tests/ops/triton/test_ssd.py function detach_clone (line 20) | def detach_clone(*args): function test_chunk_state_varlen (line 30) | def test_chunk_state_varlen(chunk_size, ngroups, dtype): FILE: tests/test_determinism.py function _set_deterministic (line 9) | def _set_deterministic(enabled: bool) -> None: function _set_seeds (line 15) | def _set_seeds(seed: int) -> None: function _max_abs_diff (line 20) | def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float: function _make_inputs (line 24) | def _make_inputs( function _run_case_outputs (line 85) | def _run_case_outputs( function _kernel_is_reproducible (line 161) | def _kernel_is_reproducible(case: str, headdim: int, dstate: int, d_has_... function _kernel_close_to_default (line 173) | def _kernel_close_to_default(case: str, headdim: int, dstate: int, d_has... function test_kernel_reproducible (line 186) | def test_kernel_reproducible(case: str, headdim: int, dstate: int): function test_combined_kernel_reproducible (line 194) | def test_combined_kernel_reproducible(case: str, d_has_hdim: bool, headd... function test_kernel_close_to_default (line 202) | def test_kernel_close_to_default(case: str, headdim: int, dstate: int): function test_combined_kernel_close_to_default (line 210) | def test_combined_kernel_close_to_default(case: str, d_has_hdim: bool, h... function test_default_mode_is_not_reproducible (line 215) | def test_default_mode_is_not_reproducible(): function test_mamba2_fwd_bwd_deterministic_reproducible (line 265) | def test_mamba2_fwd_bwd_deterministic_reproducible(): function test_mamba2_fwd_bwd_deterministic_close_to_default (line 308) | def test_mamba2_fwd_bwd_deterministic_close_to_default(): FILE: tests/test_generation.py function test_generation (line 13) | def test_generation(): function test_generation_varlen (line 46) | def test_generation_varlen(): function test_generation_varlen_with_padding (line 115) | def test_generation_varlen_with_padding():