SYMBOL INDEX (69 symbols across 2 files)

FILE: block_recurrent_transformer_pytorch/block_recurrent_transformer_pytorch.py

function exists (line 22) | def exists(val):
function default (line 25) | def default(val, d):
function is_empty (line 28) | def is_empty(t: torch.Tensor):
function cast_tuple (line 31) | def cast_tuple(t, length = 1):
function all_unique (line 34) | def all_unique(arr):
function eval_decorator (line 37) | def eval_decorator(fn):
function once (line 46) | def once(fn):
function compact (line 59) | def compact(arr):
function and_reduce (line 62) | def and_reduce(arr: List[torch.Tensor]):
function safe_cat (line 70) | def safe_cat(*args, dim = 1):
function divisible_by (line 78) | def divisible_by(numer, denom):
function l2norm (line 81) | def l2norm(t):
function pack_one (line 84) | def pack_one(t, pattern):
function unpack_one (line 87) | def unpack_one(t, ps, pattern):
function pad_at_dim (line 90) | def pad_at_dim(t, pad, dim = -1, value = 0.):
class LayerNorm (line 97) | class LayerNorm(nn.Module):
  method __init__ (line 98) | def __init__(self, dim):
  method forward (line 103) | def forward(self, x):
function log (line 108) | def log(t, eps = 1e-20):
function gumbel_noise (line 111) | def gumbel_noise(t):
function gumbel_sample (line 115) | def gumbel_sample(t, temperature = 1., dim = -1):
function top_k (line 118) | def top_k(logits, thres = 0.9):
class RotaryEmbedding (line 129) | class RotaryEmbedding(nn.Module):
  method __init__ (line 130) | def __init__(
  method device (line 151) | def device(self):
  method forward (line 154) | def forward(self):
function rotate_half (line 174) | def rotate_half(x):
function apply_rotary_pos_emb (line 179) | def apply_rotary_pos_emb(t, pos, scale = 1.):
class MemoryManager (line 196) | class MemoryManager(nn.Module):
  method __init__ (line 197) | def __init__(
  method forward (line 240) | def forward(
class StateContainer (line 307) | class StateContainer(nn.Module):
  method __init__ (line 308) | def __init__(
  method set_next_read_state (line 356) | def set_next_read_state(
  method read (line 365) | def read(self, x):
  method write (line 400) | def write(
  method forward (line 449) | def forward(self, x):
class Attend (line 454) | class Attend(nn.Module):
  method __init__ (line 455) | def __init__(
  method get_mask (line 484) | def get_mask(self, n, device):
  method flash_attn (line 492) | def flash_attn(self, q, k, v, mask = None):
  method forward (line 536) | def forward(self, q, k, v, mask = None, use_flash_attn = None):
class GEGLU (line 585) | class GEGLU(nn.Module):
  method forward (line 586) | def forward(self, x):
function FeedForward (line 590) | def FeedForward(dim, mult = 4):
class Attention (line 601) | class Attention(nn.Module):
  method __init__ (line 602) | def __init__(
  method forward (line 622) | def forward(
class AttentionBlock (line 652) | class AttentionBlock(nn.Module):
  method __init__ (line 653) | def __init__(
  method device (line 703) | def device(self):
  method forward (line 706) | def forward(
class BlockRecurrentTransformer (line 791) | class BlockRecurrentTransformer(nn.Module):
  method __init__ (line 792) | def __init__(
  method device (line 909) | def device(self):
  method get_causal_attn_mask (line 912) | def get_causal_attn_mask(self, width):
  method generate (line 926) | def generate(
  method forward (line 974) | def forward(
class RecurrentTrainerWrapper (line 1108) | class RecurrentTrainerWrapper(nn.Module):
  method __init__ (line 1109) | def __init__(
  method generate (line 1124) | def generate(
  method forward (line 1171) | def forward(

FILE: train.py

function cycle (line 28) | def cycle(loader):
function decode_token (line 33) | def decode_token(token):
function decode_tokens (line 36) | def decode_tokens(tokens):
class TextSamplerDataset (line 77) | class TextSamplerDataset(Dataset):
  method __init__ (line 78) | def __init__(self, data, seq_len):
  method __getitem__ (line 83) | def __getitem__(self, index):
  method __len__ (line 88) | def __len__(self):
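
For orientation, below is a minimal usage sketch of the two top-level entry points indexed above: BlockRecurrentTransformer (line 791) and RecurrentTrainerWrapper (line 1108). The index records only symbol names and source lines, not signatures, so every constructor keyword shown here (num_tokens, dim, depth, max_seq_len, block_width, num_state_vectors) is an assumption for illustration; verify the actual __init__ parameters in block_recurrent_transformer_pytorch.py before relying on them.

import torch
from block_recurrent_transformer_pytorch import (
    BlockRecurrentTransformer,
    RecurrentTrainerWrapper,
)

# Hypothetical hyperparameters; kwarg names are assumed, not taken from the index.
model = BlockRecurrentTransformer(
    num_tokens = 256,         # byte-level vocabulary size (assumed kwarg)
    dim = 512,                # model width (assumed kwarg)
    depth = 6,                # number of AttentionBlock layers (assumed kwarg)
    max_seq_len = 1024,       # (assumed kwarg)
    block_width = 512,        # width of each recurrent block (assumed kwarg)
    num_state_vectors = 512,  # size of the recurrent state (assumed kwarg)
)

# RecurrentTrainerWrapper.forward (line 1171) is assumed to split a long
# sequence into blocks, carry state/memories across them, and return a
# scalar training loss.
trainer = RecurrentTrainerWrapper(model)

seq = torch.randint(0, 256, (1, 2048))  # (batch, sequence) of token ids
loss = trainer(seq)
loss.backward()

For sampling, both model.generate (line 926) and trainer.generate (line 1124) are exposed; a call shaped like trainer.generate(prime, length = 512) is plausible given those entries, but its parameter names are likewise unverified here.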