SYMBOL INDEX (157 symbols across 26 files) FILE: native_sparse_attention/infer/inference_func.py function compress_infer (line 27) | def compress_infer( function compressed_attention_infer (line 90) | def compressed_attention_infer( function topk_sparse_attention_infer (line 148) | def topk_sparse_attention_infer( function sliding_window_attention_infer (line 175) | def sliding_window_attention_infer( FILE: native_sparse_attention/infer/nsa_inference.py function nsa_infer (line 24) | def nsa_infer( FILE: native_sparse_attention/model/toy_llama.py class ToyLlamaConfig (line 22) | class ToyLlamaConfig: class InferenceConfig (line 47) | class InferenceConfig: class RMSNorm (line 53) | class RMSNorm(nn.Module): method __init__ (line 54) | def __init__(self, hidden_size: int, eps: float = 1e-6): method forward (line 59) | def forward(self, hidden_states: torch.Tensor): class FFN (line 67) | class FFN(nn.Module): method __init__ (line 68) | def __init__(self, hidden_size: int, intermediate_size: int): method forward (line 77) | def forward(self, x): class ToyLlamaLayer (line 82) | class ToyLlamaLayer(nn.Module): method __init__ (line 83) | def __init__( method forward (line 112) | def forward(self, x, cu_seqlens): method inference (line 118) | def inference(self, x, cu_seqlens, step, kv_cache): class ToyLlama (line 124) | class ToyLlama(nn.Module): method __init__ (line 125) | def __init__( method forward (line 163) | def forward( method inference (line 180) | def inference( method generate (line 212) | def generate( FILE: native_sparse_attention/model/toy_nsa_llama.py class ToyNSALlamaConfig (line 22) | class ToyNSALlamaConfig: class InferenceConfig (line 56) | class InferenceConfig: class RMSNorm (line 62) | class RMSNorm(nn.Module): method __init__ (line 63) | def __init__(self, hidden_size: int, eps: float = 1e-6): method forward (line 68) | def forward(self, hidden_states: torch.Tensor): class FFN (line 76) | class FFN(nn.Module): method __init__ (line 77) | def __init__(self, hidden_size: int, intermediate_size: int): method forward (line 86) | def forward(self, x): class ToyNSALlamaLayer (line 91) | class ToyNSALlamaLayer(nn.Module): method __init__ (line 92) | def __init__( method forward (line 145) | def forward(self, x, cu_seqlens): method inference (line 151) | def inference(self, x, cu_seqlens, step, kv_cache): class ToyNSALlama (line 157) | class ToyNSALlama(nn.Module): method __init__ (line 158) | def __init__( method forward (line 206) | def forward( method inference (line 223) | def inference( method generate (line 258) | def generate( FILE: native_sparse_attention/module/kv_cache.py class KVCache (line 21) | class KVCache: method __init__ (line 22) | def __init__( method reset (line 52) | def reset(self): method update_kv (line 56) | def update_kv( method _update_kv_prefill (line 78) | def _update_kv_prefill( method _update_kv_decode (line 104) | def _update_kv_decode( class NSACache (line 125) | class NSACache: method __init__ (line 144) | def __init__( method reset (line 223) | def reset(self): method prepare_compress (line 233) | def prepare_compress( method _prepare_compress_prefill (line 245) | def _prepare_compress_prefill( method _prepare_compress_decode (line 274) | def _prepare_compress_decode( method update_kv (line 304) | def update_kv( method _update_kv_prefill (line 338) | def _update_kv_prefill( method _update_kv_decode (line 401) | def _update_kv_decode( function _fill_kv_cache_kernel (line 453) | def _fill_kv_cache_kernel( function _fill_kv_cache (line 519) | def _fill_kv_cache( FILE: native_sparse_attention/module/native_sparse_attention.py class NativeSparseAttention (line 45) | class NativeSparseAttention(torch.nn.Module): method __init__ (line 65) | def __init__( method init_params (line 140) | def init_params(self): method forward (line 145) | def forward( method inference (line 241) | def inference( FILE: native_sparse_attention/module/rope.py class RopeConfig (line 28) | class RopeConfig: method __post_init__ (line 47) | def __post_init__(self): function rotate_half (line 53) | def rotate_half(x): class RotaryEmbedding (line 62) | class RotaryEmbedding(nn.Module): method __init__ (line 73) | def __init__( method _dynamic_frequency_update (line 94) | def _dynamic_frequency_update(self, position_ids, device): method generate_cos_sin (line 121) | def generate_cos_sin(self, x: torch.Tensor, position_ids): method generate_pos_embs (line 158) | def generate_pos_embs( method forward (line 198) | def forward(self, x, cu_seqlens, step=0, stride=1): FILE: native_sparse_attention/module/self_attention.py class SelfAttention (line 22) | class SelfAttention(torch.nn.Module): method __init__ (line 33) | def __init__( method init_params (line 70) | def init_params(self): method forward (line 74) | def forward( method inference (line 113) | def inference( FILE: native_sparse_attention/ops/torch/compress_key_value.py function avgpool_compress_torch (line 19) | def avgpool_compress_torch( function weightedpool_compress_torch (line 84) | def weightedpool_compress_torch( function linear_compress_torch (line 156) | def linear_compress_torch( FILE: native_sparse_attention/ops/torch/compressed_attention.py function transform_score (line 21) | def transform_score( function compressed_attention_torch (line 74) | def compressed_attention_torch( FILE: native_sparse_attention/ops/torch/compressed_attention_decode.py function transform_score (line 21) | def transform_score( function compressed_attention_decode (line 65) | def compressed_attention_decode( FILE: native_sparse_attention/ops/torch/topk_sparse_attention.py function topk_sparse_attention_torch (line 19) | def topk_sparse_attention_torch( FILE: native_sparse_attention/ops/triton/compressed_attention.py function forward_kernel (line 28) | def forward_kernel( function backward_sum_o_do (line 162) | def backward_sum_o_do( function backward_dkdv (line 206) | def backward_dkdv( function backward_dq (line 385) | def backward_dq( function _compressed_attention_fwd (line 538) | def _compressed_attention_fwd( function _compressed_attention_bwd (line 618) | def _compressed_attention_bwd( class CompressedAttention (line 783) | class CompressedAttention(torch.autograd.Function): method forward (line 785) | def forward( method backward (line 826) | def backward(ctx, do: torch.Tensor, *args) -> Any: function score_kernel (line 852) | def score_kernel( function _get_attention_score (line 954) | def _get_attention_score( function _transform_score_kernel (line 1030) | def _transform_score_kernel( function transform_score (line 1116) | def transform_score( function compressed_attention (line 1182) | def compressed_attention( FILE: native_sparse_attention/ops/triton/flash_attention.py function forward_kernel (line 27) | def forward_kernel( function backward_sum_o_do (line 169) | def backward_sum_o_do( function backward_dkdv (line 213) | def backward_dkdv( function backward_dq (line 400) | def backward_dq( function _flash_attention_fwd (line 563) | def _flash_attention_fwd( function _flash_attention_bwd (line 645) | def _flash_attention_bwd( class FlashAttention (line 831) | class FlashAttention(torch.autograd.Function): method forward (line 833) | def forward( method backward (line 870) | def backward(ctx, do: torch.Tensor, *args) -> Any: function flash_attention_varlen (line 895) | def flash_attention_varlen( FILE: native_sparse_attention/ops/triton/flash_attention_decode.py function decode_kernel (line 23) | def decode_kernel( function flash_attention_decode (line 142) | def flash_attention_decode( function torch_attention_decode (line 220) | def torch_attention_decode( FILE: native_sparse_attention/ops/triton/linear_compress.py function linear_compress_fwd_kernel (line 27) | def linear_compress_fwd_kernel( function linear_compress_bwd_kernel (line 140) | def linear_compress_bwd_kernel( class LinearCompress (line 309) | class LinearCompress(torch.autograd.Function): method forward (line 311) | def forward( method backward (line 420) | def backward(ctx, dy: torch.Tensor, *args) -> Any: function linear_compress (line 497) | def linear_compress( FILE: native_sparse_attention/ops/triton/topk_sparse_attention.py function forward_kernel (line 27) | def forward_kernel( function backward_sum_o_do (line 177) | def backward_sum_o_do( function count_kernel (line 221) | def count_kernel( function count_query (line 267) | def count_query( function pad_topk_idx_kernel (line 305) | def pad_topk_idx_kernel( function save_topk_idx_kernel (line 351) | def save_topk_idx_kernel( function reorder_topk_idx (line 404) | def reorder_topk_idx( function backward_dkdv (line 481) | def backward_dkdv( function backward_dq (line 659) | def backward_dq( function _topk_sparse_attention_fwd (line 828) | def _topk_sparse_attention_fwd( function _topk_sparse_attention_bwd (line 912) | def _topk_sparse_attention_bwd( class TopkSparseAttention (line 1112) | class TopkSparseAttention(torch.autograd.Function): method forward (line 1114) | def forward( method backward (line 1156) | def backward(ctx, do: torch.Tensor, *args) -> Any: function topk_sparse_attention (line 1182) | def topk_sparse_attention( FILE: native_sparse_attention/ops/triton/topk_sparse_attention_decode.py function forward_kernel (line 23) | def forward_kernel( function topk_sparse_attention_decode (line 151) | def topk_sparse_attention_decode( function torch_topk_sparse_attention_decode (line 240) | def torch_topk_sparse_attention_decode( function generate_topk_idx_example (line 301) | def generate_topk_idx_example( FILE: native_sparse_attention/ops/triton/utils.py function is_hopper_gpu (line 17) | def is_hopper_gpu(): function get_compressed_seqlens (line 25) | def get_compressed_seqlens( function get_num_warps_stages (line 40) | def get_num_warps_stages(head_dim, block_size, is_hopper_gpu): FILE: native_sparse_attention/ops/triton/weighted_pool.py function sliding_pool_fwd_kernel (line 24) | def sliding_pool_fwd_kernel( function sliding_pool_dxdw_kernel (line 89) | def sliding_pool_dxdw_kernel( class SlidingWindowWeightedPool (line 182) | class SlidingWindowWeightedPool(torch.autograd.Function): method forward (line 184) | def forward( method backward (line 245) | def backward(ctx, dy, _): function weightedpool_compress (line 301) | def weightedpool_compress( function avgpool_compress (line 333) | def avgpool_compress( FILE: test/test_compress_key_value.py function benchmark (line 81) | def benchmark(N, H, D, provider): FILE: test/test_compressed_attention.py function benchmark (line 171) | def benchmark(N, H, D, provider): function benchmark (line 267) | def benchmark(N, H, D, provider): FILE: test/test_flash_attention.py function benchmark (line 124) | def benchmark(N, H, D, provider): function benchmark (line 176) | def benchmark(N, H, D, provider): FILE: test/test_linear_compress.py function test_linear_compress (line 21) | def test_linear_compress( function benchmark_fwdbwd (line 220) | def benchmark_fwdbwd(N, H, D, provider): FILE: test/test_nsa_module.py function benchmark (line 121) | def benchmark(N, provider): function benchmark (line 153) | def benchmark(N, provider): FILE: test/test_topk_sparse_attention.py function generate_topk_idx_example (line 34) | def generate_topk_idx_example( function benchmark (line 174) | def benchmark(N, H, D, K, provider): function benchmark (line 245) | def benchmark(N, H, D, K, provider):