SYMBOL INDEX (261 symbols across 11 files) FILE: gpu_mem_track.py function get_mem_space (line 30) | def get_mem_space(x): class MemTracker (line 37) | class MemTracker(object): method __init__ (line 46) | def __init__(self, detail=True, path='', verbose=False, device=0): method get_tensors (line 54) | def get_tensors(self): method get_tensor_usage (line 67) | def get_tensor_usage(self): method get_allocate_usage (line 71) | def get_allocate_usage(self): method clear_cache (line 74) | def clear_cache(self): method print_all_gpu_tensor (line 78) | def print_all_gpu_tensor(self, file=None): method track (line 82) | def track(self): FILE: inference.gemma.infini.py function generate_text_with_stateful_segments (line 14) | def generate_text_with_stateful_segments( FILE: infini_gemma/configuration_infini_gemma.py class GemmaConfig (line 4) | class GemmaConfig(OriginalGemmaConfig): method __init__ (line 5) | def __init__( FILE: infini_gemma/modeling_infini_gemma.py function debug_print (line 60) | def debug_print(*args): class InfiniBaseModelOutputWithPast (line 85) | class InfiniBaseModelOutputWithPast(ModelOutput): class InfiniCausalLMOutputWithPast (line 126) | class InfiniCausalLMOutputWithPast(ModelOutput): function _get_unpad_data (line 163) | def _get_unpad_data(attention_mask): class GemmaRMSNorm (line 177) | class GemmaRMSNorm(nn.Module): method __init__ (line 178) | def __init__(self, dim: int, eps: float = 1e-6): method _norm (line 183) | def _norm(self, x): method forward (line 186) | def forward(self, x): class GemmaRotaryEmbedding (line 197) | class GemmaRotaryEmbedding(nn.Module): method __init__ (line 198) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi... method forward (line 207) | def forward(self, x, position_ids, seq_len=None): function rotate_half (line 242) | def rotate_half(x): function apply_rotary_pos_emb (line 250) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_di... class GemmaMLP (line 277) | class GemmaMLP(nn.Module): method __init__ (line 278) | def __init__(self, config): method forward (line 299) | def forward(self, x): function repeat_kv (line 304) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class GemmaAttention (line 318) | class GemmaAttention(nn.Module): method __init__ (line 322) | def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): method forward (line 371) | def forward( class GemmaFlashAttention2 (line 449) | class GemmaFlashAttention2(GemmaAttention): method __init__ (line 456) | def __init__(self, *args, **kwargs): method forward (line 465) | def forward( method _flash_attention_forward (line 562) | def _flash_attention_forward( method _upad_input (line 642) | def _upad_input( class GemmaSdpaAttention (line 689) | class GemmaSdpaAttention(GemmaAttention): method forward (line 697) | def forward( class GemmaInfiniAttention (line 783) | class GemmaInfiniAttention(GemmaAttention): method __init__ (line 784) | def __init__( method forward (line 796) | def forward( method _retrieve_from_memory (line 936) | def _retrieve_from_memory(self, query_states, memory, norm_term): method _update_memory (line 967) | def _update_memory(self, key_states, value_states, memory, norm_term): class GemmaDecoderLayer (line 1001) | class GemmaDecoderLayer(nn.Module): method __init__ (line 1002) | def __init__(self, config: GemmaConfig, layer_idx: int): method forward (line 1016) | def forward( class GemmaPreTrainedModel (line 1118) | class GemmaPreTrainedModel(PreTrainedModel): method _init_weights (line 1129) | def _init_weights(self, module): method _setup_cache (line 1140) | def _setup_cache( method _reset_cache (line 1162) | def _reset_cache(self): class GemmaModel (line 1246) | class GemmaModel(GemmaPreTrainedModel): method __init__ (line 1254) | def __init__(self, config: GemmaConfig): method get_input_embeddings (line 1274) | def get_input_embeddings(self): method set_input_embeddings (line 1277) | def set_input_embeddings(self, value): method forward (line 1282) | def forward( method _update_causal_mask (line 1441) | def _update_causal_mask( class GemmaForCausalLM (line 1521) | class GemmaForCausalLM(GemmaPreTrainedModel): method __init__ (line 1524) | def __init__(self, config): method get_input_embeddings (line 1533) | def get_input_embeddings(self): method set_input_embeddings (line 1536) | def set_input_embeddings(self, value): method get_output_embeddings (line 1539) | def get_output_embeddings(self): method set_output_embeddings (line 1542) | def set_output_embeddings(self, new_embeddings): method set_decoder (line 1545) | def set_decoder(self, decoder): method get_decoder (line 1548) | def get_decoder(self): method forward (line 1556) | def forward( method prepare_inputs_for_generation (line 1661) | def prepare_inputs_for_generation( method _reorder_cache (line 1769) | def _reorder_cache(past_key_values, beam_idx): FILE: infini_llama/modeling_infini_llama.py function debug_print (line 59) | def debug_print(*args): class InfiniBaseModelOutputWithPast (line 84) | class InfiniBaseModelOutputWithPast(ModelOutput): class InfiniCausalLMOutputWithPast (line 125) | class InfiniCausalLMOutputWithPast(ModelOutput): function _get_unpad_data (line 162) | def _get_unpad_data(attention_mask): class LlamaRMSNorm (line 176) | class LlamaRMSNorm(nn.Module): method __init__ (line 177) | def __init__(self, hidden_size, eps=1e-6): method forward (line 185) | def forward(self, hidden_states): class LlamaRotaryEmbedding (line 196) | class LlamaRotaryEmbedding(nn.Module): method __init__ (line 197) | def __init__( method sin_cached (line 235) | def sin_cached(self): method cos_cached (line 243) | def cos_cached(self): method forward (line 251) | def forward(self, x, position_ids): class LlamaLinearScalingRotaryEmbedding (line 275) | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): method forward (line 278) | def forward(self, x, position_ids): class LlamaDynamicNTKScalingRotaryEmbedding (line 285) | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): method forward (line 288) | def forward(self, x, position_ids): function rotate_half (line 312) | def rotate_half(x): function apply_rotary_pos_emb (line 320) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_di... class LlamaMLP (line 347) | class LlamaMLP(nn.Module): method __init__ (line 348) | def __init__(self, config): method forward (line 358) | def forward(self, x): function repeat_kv (line 393) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class LlamaAttention (line 407) | class LlamaAttention(nn.Module): method __init__ (line 410) | def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): method _init_rope (line 455) | def _init_rope(self): method forward (line 482) | def forward( class LlamaFlashAttention2 (line 603) | class LlamaFlashAttention2(LlamaAttention): method __init__ (line 610) | def __init__(self, *args, **kwargs): method forward (line 618) | def forward( method _flash_attention_forward (line 715) | def _flash_attention_forward( method _upad_input (line 795) | def _upad_input( class LlamaSdpaAttention (line 841) | class LlamaSdpaAttention(LlamaAttention): method forward (line 849) | def forward( class LlamaInfiniAttention (line 939) | class LlamaInfiniAttention(LlamaAttention): method __init__ (line 940) | def __init__( method forward (line 952) | def forward( method _retrieve_from_memory (line 1097) | def _retrieve_from_memory(self, query_states, memory, norm_term): method _update_memory (line 1133) | def _update_memory(self, key_states, value_states, memory, norm_term): class LlamaDecoderLayer (line 1167) | class LlamaDecoderLayer(nn.Module): method __init__ (line 1168) | def __init__(self, config: LlamaConfig, layer_idx: int): method forward (line 1182) | def forward( class LlamaPreTrainedModel (line 1284) | class LlamaPreTrainedModel(PreTrainedModel): method _init_weights (line 1297) | def _init_weights(self, module): method _setup_cache (line 1308) | def _setup_cache( method _reset_cache (line 1339) | def _reset_cache(self): class LlamaModel (line 1423) | class LlamaModel(LlamaPreTrainedModel): method __init__ (line 1431) | def __init__(self, config: LlamaConfig): method get_input_embeddings (line 1451) | def get_input_embeddings(self): method set_input_embeddings (line 1454) | def set_input_embeddings(self, value): method forward (line 1459) | def forward( method _update_causal_mask (line 1619) | def _update_causal_mask( class LlamaForCausalLM (line 1699) | class LlamaForCausalLM(LlamaPreTrainedModel): method __init__ (line 1702) | def __init__(self, config): method get_input_embeddings (line 1711) | def get_input_embeddings(self): method set_input_embeddings (line 1714) | def set_input_embeddings(self, value): method get_output_embeddings (line 1717) | def get_output_embeddings(self): method set_output_embeddings (line 1720) | def set_output_embeddings(self, new_embeddings): method set_decoder (line 1723) | def set_decoder(self, decoder): method get_decoder (line 1726) | def get_decoder(self): method forward (line 1734) | def forward( method prepare_inputs_for_generation (line 1839) | def prepare_inputs_for_generation( method _reorder_cache (line 1947) | def _reorder_cache(past_key_values, beam_idx): FILE: modeling_gemma.py function debug_print (line 56) | def debug_print(*args): function _get_unpad_data (line 80) | def _get_unpad_data(attention_mask): class GemmaRMSNorm (line 94) | class GemmaRMSNorm(nn.Module): method __init__ (line 95) | def __init__(self, dim: int, eps: float = 1e-6): method _norm (line 100) | def _norm(self, x): method forward (line 103) | def forward(self, x): class GemmaRotaryEmbedding (line 114) | class GemmaRotaryEmbedding(nn.Module): method __init__ (line 115) | def __init__(self, dim, max_position_embeddings=2048, base=10000, devi... method forward (line 124) | def forward(self, x, position_ids, seq_len=None): function rotate_half (line 159) | def rotate_half(x): function apply_rotary_pos_emb (line 167) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_di... class GemmaMLP (line 194) | class GemmaMLP(nn.Module): method __init__ (line 195) | def __init__(self, config): method forward (line 216) | def forward(self, x): function repeat_kv (line 221) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class GemmaAttention (line 235) | class GemmaAttention(nn.Module): method __init__ (line 239) | def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): method forward (line 288) | def forward( class GemmaFlashAttention2 (line 366) | class GemmaFlashAttention2(GemmaAttention): method __init__ (line 373) | def __init__(self, *args, **kwargs): method forward (line 382) | def forward( method _flash_attention_forward (line 479) | def _flash_attention_forward( method _upad_input (line 559) | def _upad_input( class GemmaSdpaAttention (line 606) | class GemmaSdpaAttention(GemmaAttention): method forward (line 614) | def forward( class GemmaInfiniAttention (line 700) | class GemmaInfiniAttention(GemmaAttention): method __init__ (line 701) | def __init__( method forward (line 713) | def forward( method _retrieve_from_memory (line 835) | def _retrieve_from_memory(self, query_states): method _update_memory (line 866) | def _update_memory(self, key_states, value_states): class GemmaDecoderLayer (line 900) | class GemmaDecoderLayer(nn.Module): method __init__ (line 901) | def __init__(self, config: GemmaConfig, layer_idx: int): method forward (line 915) | def forward( class GemmaPreTrainedModel (line 1002) | class GemmaPreTrainedModel(PreTrainedModel): method _init_weights (line 1013) | def _init_weights(self, module): method _setup_cache (line 1024) | def _setup_cache( method _reset_cache (line 1046) | def _reset_cache(self): class GemmaModel (line 1130) | class GemmaModel(GemmaPreTrainedModel): method __init__ (line 1138) | def __init__(self, config: GemmaConfig): method get_input_embeddings (line 1158) | def get_input_embeddings(self): method set_input_embeddings (line 1161) | def set_input_embeddings(self, value): method forward (line 1166) | def forward( method _update_causal_mask (line 1311) | def _update_causal_mask( class GemmaForCausalLM (line 1391) | class GemmaForCausalLM(GemmaPreTrainedModel): method __init__ (line 1394) | def __init__(self, config): method get_input_embeddings (line 1403) | def get_input_embeddings(self): method set_input_embeddings (line 1406) | def set_input_embeddings(self, value): method get_output_embeddings (line 1409) | def get_output_embeddings(self): method set_output_embeddings (line 1412) | def set_output_embeddings(self, new_embeddings): method set_decoder (line 1415) | def set_decoder(self, decoder): method get_decoder (line 1418) | def get_decoder(self): method forward (line 1426) | def forward( method prepare_inputs_for_generation (line 1521) | def prepare_inputs_for_generation( method _reorder_cache (line 1629) | def _reorder_cache(past_key_values, beam_idx): class GemmaForSequenceClassification (line 1657) | class GemmaForSequenceClassification(GemmaPreTrainedModel): method __init__ (line 1658) | def __init__(self, config): method get_input_embeddings (line 1667) | def get_input_embeddings(self): method set_input_embeddings (line 1670) | def set_input_embeddings(self, value): method forward (line 1674) | def forward( FILE: original_llama.py function _get_unpad_data (line 64) | def _get_unpad_data(attention_mask): class LlamaRMSNorm (line 76) | class LlamaRMSNorm(nn.Module): method __init__ (line 77) | def __init__(self, hidden_size, eps=1e-6): method forward (line 85) | def forward(self, hidden_states): class LlamaRotaryEmbedding (line 96) | class LlamaRotaryEmbedding(nn.Module): method __init__ (line 97) | def __init__( method sin_cached (line 135) | def sin_cached(self): method cos_cached (line 143) | def cos_cached(self): method forward (line 151) | def forward(self, x, position_ids): class LlamaLinearScalingRotaryEmbedding (line 175) | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): method forward (line 178) | def forward(self, x, position_ids): class LlamaDynamicNTKScalingRotaryEmbedding (line 185) | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): method forward (line 188) | def forward(self, x, position_ids): function rotate_half (line 211) | def rotate_half(x): function apply_rotary_pos_emb (line 218) | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_di... class LlamaMLP (line 245) | class LlamaMLP(nn.Module): method __init__ (line 246) | def __init__(self, config): method forward (line 256) | def forward(self, x): function repeat_kv (line 290) | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: class LlamaAttention (line 304) | class LlamaAttention(nn.Module): method __init__ (line 307) | def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): method _init_rope (line 352) | def _init_rope(self): method forward (line 379) | def forward( class LlamaFlashAttention2 (line 500) | class LlamaFlashAttention2(LlamaAttention): method __init__ (line 507) | def __init__(self, *args, **kwargs): method forward (line 515) | def forward( method _flash_attention_forward (line 612) | def _flash_attention_forward( method _upad_input (line 692) | def _upad_input( class LlamaSdpaAttention (line 738) | class LlamaSdpaAttention(LlamaAttention): method forward (line 746) | def forward( class LlamaDecoderLayer (line 843) | class LlamaDecoderLayer(nn.Module): method __init__ (line 844) | def __init__(self, config: LlamaConfig, layer_idx: int): method forward (line 858) | def forward( class LlamaPreTrainedModel (line 945) | class LlamaPreTrainedModel(PreTrainedModel): method _init_weights (line 955) | def _init_weights(self, module): method _setup_cache (line 966) | def _setup_cache( method _reset_cache (line 988) | def _reset_cache(self): class LlamaModel (line 1071) | class LlamaModel(LlamaPreTrainedModel): method __init__ (line 1079) | def __init__(self, config: LlamaConfig): method get_input_embeddings (line 1099) | def get_input_embeddings(self): method set_input_embeddings (line 1102) | def set_input_embeddings(self, value): method forward (line 1106) | def forward( method _update_causal_mask (line 1240) | def _update_causal_mask( class LlamaForCausalLM (line 1338) | class LlamaForCausalLM(LlamaPreTrainedModel): method __init__ (line 1341) | def __init__(self, config): method get_input_embeddings (line 1350) | def get_input_embeddings(self): method set_input_embeddings (line 1353) | def set_input_embeddings(self, value): method get_output_embeddings (line 1356) | def get_output_embeddings(self): method set_output_embeddings (line 1359) | def set_output_embeddings(self, new_embeddings): method set_decoder (line 1362) | def set_decoder(self, decoder): method get_decoder (line 1365) | def get_decoder(self): method forward (line 1372) | def forward( method prepare_inputs_for_generation (line 1478) | def prepare_inputs_for_generation( method _reorder_cache (line 1586) | def _reorder_cache(past_key_values, beam_idx): class LlamaForSequenceClassification (line 1613) | class LlamaForSequenceClassification(LlamaPreTrainedModel): method __init__ (line 1614) | def __init__(self, config): method get_input_embeddings (line 1623) | def get_input_embeddings(self): method set_input_embeddings (line 1626) | def set_input_embeddings(self, value): method forward (line 1630) | def forward( class LlamaForQuestionAnswering (line 1740) | class LlamaForQuestionAnswering(LlamaPreTrainedModel): method __init__ (line 1744) | def __init__(self, config): method get_input_embeddings (line 1752) | def get_input_embeddings(self): method set_input_embeddings (line 1755) | def set_input_embeddings(self, value): method forward (line 1759) | def forward( FILE: test_train.small.gemma.infini.py function tokenize_function (line 76) | def tokenize_function(examples): function group_texts (line 93) | def group_texts(examples): FILE: test_train.small.gemma.py function tokenize_function (line 78) | def tokenize_function(examples): function group_texts (line 94) | def group_texts(examples): FILE: train.gemma.infini.noclm.py function parse_args (line 75) | def parse_args(): function main (line 304) | def main(): FILE: train.llama.infini.noclm.py function parse_args (line 76) | def parse_args(): function main (line 305) | def main():