SYMBOL INDEX (96 symbols across 4 files) FILE: checkpoint.py function copy_to_shm (line 43) | def copy_to_shm(file: str): function copy_from_shm (line 60) | def copy_from_shm(file: str): function fast_unpickle (line 71) | def fast_unpickle(path: str) -> Any: function fast_pickle (line 77) | def fast_pickle(obj: Any, path: str) -> None: function load_tensors (line 83) | def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=N... function path_tuple_to_string (line 110) | def path_tuple_to_string(path: tuple) -> str: function get_load_path_str (line 122) | def get_load_path_str( function replace_with_load_state (line 144) | def replace_with_load_state( function restore (line 180) | def restore( FILE: model.py class QuantizedWeight8bit (line 38) | class QuantizedWeight8bit: method shape (line 43) | def shape(self): class TrainingState (line 54) | class TrainingState(NamedTuple): function _match (line 60) | def _match(qs, ks): function with_sharding_constraint (line 71) | def with_sharding_constraint(x, constraint): function cast_bfloat16 (line 78) | def cast_bfloat16(x): function ffn_size (line 85) | def ffn_size(emb_size, widening_factor): function apply_rules (line 92) | def apply_rules(rules): class KVMemory (line 178) | class KVMemory(NamedTuple): function init_layer_memories (line 184) | def init_layer_memories( class Memory (line 203) | class Memory(NamedTuple): class Router (line 208) | class Router(hk.Module): method __init__ (line 209) | def __init__( method compute_routing_prob (line 225) | def compute_routing_prob( method _compute_routing_prob (line 231) | def _compute_routing_prob( method _router_weights (line 251) | def _router_weights( class MoELayer (line 272) | class MoELayer(hk.Module): method __init__ (line 273) | def __init__( method _inference_call (line 294) | def _inference_call(self, inputs: jax.Array, padding_mask: Optional[ja... method __call__ (line 399) | def __call__(self, inputs: jax.Array, padding_mask: jax.Array): class MHAOutput (line 403) | class MHAOutput(NamedTuple): class DecoderOutput (line 410) | class DecoderOutput(NamedTuple): class TransformerOutput (line 415) | class TransformerOutput(NamedTuple): class TransformerConfig (line 421) | class TransformerConfig: method __post_init__ (line 445) | def __post_init__(self): method partition_rules (line 451) | def partition_rules(self): method make (line 454) | def make(self, mesh=None) -> "Transformer": method get_memory_sharding (line 476) | def get_memory_sharding(self): function hk_rms_norm (line 489) | def hk_rms_norm( function make_attention_mask (line 499) | def make_attention_mask( class Linear (line 525) | class Linear(hk.Linear): method __init__ (line 526) | def __init__( method __call__ (line 544) | def __call__( class RMSNorm (line 587) | class RMSNorm(hk.RMSNorm): method __init__ (line 589) | def __init__( method __call__ (line 600) | def __call__(self, inputs: jax.Array): function rotate_half (line 627) | def rotate_half( class RotaryEmbedding (line 635) | class RotaryEmbedding(hk.Module): method __init__ (line 644) | def __init__( method __call__ (line 655) | def __call__( class MultiHeadAttention (line 694) | class MultiHeadAttention(hk.Module): method __init__ (line 695) | def __init__( method __call__ (line 720) | def __call__( method _linear_projection (line 894) | def _linear_projection( class MHABlock (line 915) | class MHABlock(hk.Module): method __call__ (line 927) | def __call__( class DenseBlock (line 964) | class DenseBlock(hk.Module): method __call__ (line 973) | def __call__( class DecoderLayer (line 1011) | class DecoderLayer(hk.Module): method __call__ (line 1030) | def __call__( class LanguageModelOutput (line 1105) | class LanguageModelOutput(NamedTuple): class InOutEmbed (line 1110) | class InOutEmbed(hk.Embed): method __init__ (line 1113) | def __init__( method embeddings (line 1128) | def embeddings(self): method decode (line 1139) | def decode( class LanguageModelConfig (line 1147) | class LanguageModelConfig: method initialize (line 1167) | def initialize(self): method make (line 1179) | def make(self, *args, **kwargs): method partition_rules (line 1193) | def partition_rules(self): function layer_norm (line 1197) | def layer_norm(x, model): class LanguageModel (line 1202) | class LanguageModel(hk.Module): method __call__ (line 1211) | def __call__( method init_memory (line 1281) | def init_memory(self, batch_size: int, seq_len: int, dtype=jnp.bfloat16): method prefill_memory (line 1284) | def prefill_memory(self, prompts, memory): class Transformer (line 1292) | class Transformer(hk.Module): method init_memory (line 1313) | def init_memory(self, batch_size: int, sequence_len: int, dtype=jnp.bf... method __call__ (line 1326) | def __call__( FILE: run.py function main (line 24) | def main(): FILE: runners.py class SampleSettings (line 50) | class SampleSettings(NamedTuple): class SampleOutput (line 58) | class SampleOutput(NamedTuple): function insert_slice (line 65) | def insert_slice(memory: Memory, slice, length, i): function pad_to_size (line 77) | def pad_to_size(x, size): function top_p_filter (line 84) | def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array: function sample_token (line 100) | def sample_token( class ModelRunner (line 137) | class ModelRunner: method make_forward_fn (line 150) | def make_forward_fn(self, mesh: Any): method initialize (line 159) | def initialize( method init (line 193) | def init(self, rng: jax.Array, data) -> TrainingState: method get_state_sharding (line 199) | def get_state_sharding(self, init_data): method load_or_init (line 212) | def load_or_init( class Request (line 253) | class Request: class InferenceRunner (line 262) | class InferenceRunner: method get_pad_bucket (line 271) | def get_pad_bucket(self, size): method initialize (line 275) | def initialize(self): method run (line 442) | def run(self): function make_mesh (line 580) | def make_mesh( function sample_from_model (line 596) | def sample_from_model(server, prompt, max_len, temperature):