SYMBOL INDEX (210 symbols across 9 files) FILE: lwm/data.py class DatasetFactory (line 16) | class DatasetFactory(object): method get_default_config (line 20) | def get_default_config(updates=None): method load_dataset (line 35) | def load_dataset(cls, config, tokenizer, **kwargs): method __init__ (line 51) | def __init__(self): class TextProcessor (line 55) | class TextProcessor(object): method get_default_config (line 58) | def get_default_config(updates=None): method __init__ (line 70) | def __init__(self, config, tokenizer): method __call__ (line 77) | def __call__(self, example, has_aux=False, add_bos_token=True, add_eos... class VisionTextProcessor (line 126) | class VisionTextProcessor(object): method get_default_config (line 128) | def get_default_config(updates=None): method __init__ (line 144) | def __init__(self, config, tokenizer): method __call__ (line 153) | def __call__(self, example, has_aux=False, add_bos_token=True, add_eos... class HuggingfaceDataset (line 242) | class HuggingfaceDataset(object): method get_default_config (line 248) | def get_default_config(updates=None): method __init__ (line 262) | def __init__(self, config, tokenizer, text_processor): method __iter__ (line 272) | def __iter__(self): method get_state_dict (line 305) | def get_state_dict(self): method load_state_dict (line 308) | def load_state_dict(self, state_dict): method seq_length (line 313) | def seq_length(self): method tokenizer (line 317) | def tokenizer(self): method text_processor (line 321) | def text_processor(self): method dataset (line 325) | def dataset(self): method vocab_size (line 329) | def vocab_size(self): class JsonDataset (line 333) | class JsonDataset(object): method get_default_config (line 339) | def get_default_config(updates=None): method __init__ (line 360) | def __init__(self, config, tokenizer, text_processor, node_info): method parse_json (line 370) | def parse_json(self, line): method json_iterator (line 380) | def json_iterator(self): method batched (line 398) | def batched(self, iterator, batch_size): method parallel_example_iterator (line 408) | def parallel_example_iterator(self): method __iter__ (line 434) | def __iter__(self): method _make_callback (line 510) | def _make_callback(self, v): method get_state_dict (line 513) | def get_state_dict(self): method load_state_dict (line 521) | def load_state_dict(self, state_dict): method seq_length (line 529) | def seq_length(self): method tokenizer (line 533) | def tokenizer(self): method text_processor (line 537) | def text_processor(self): method vocab_size (line 541) | def vocab_size(self): class JsonVisionDataset (line 545) | class JsonVisionDataset(object): method get_default_config (line 547) | def get_default_config(updates=None): method __init__ (line 568) | def __init__(self, config, tokenizer, text_processor, node_info): method parse_json (line 578) | def parse_json(self, line): method json_iterator (line 588) | def json_iterator(self): method batched (line 606) | def batched(self, iterator, batch_size): method parallel_example_iterator (line 616) | def parallel_example_iterator(self): method __iter__ (line 642) | def __iter__(self): method _iter_pad (line 651) | def _iter_pad(self): method _iter_no_pad (line 736) | def _iter_no_pad(self): method _make_callback (line 810) | def _make_callback(self, v): method get_state_dict (line 813) | def get_state_dict(self): method load_state_dict (line 821) | def load_state_dict(self, state_dict): method seq_length (line 829) | def seq_length(self): method tokenizer (line 833) | def tokenizer(self): method text_processor (line 837) | def text_processor(self): method vocab_size (line 841) | def vocab_size(self): FILE: lwm/llama.py class LLaMAConfig (line 133) | class LLaMAConfig(PretrainedConfig): method __init__ (line 136) | def __init__( method get_default_config (line 193) | def get_default_config(cls, updates=None): method get_jax_mesh (line 202) | def get_jax_mesh(axis_dims): method get_ranks_and_size (line 206) | def get_ranks_and_size(mesh): method get_partition_rules (line 222) | def get_partition_rules(scan_layers=False, scan_axis=0): method get_weight_decay_exclusions (line 286) | def get_weight_decay_exclusions(): method get_frozen_param_exclusions (line 290) | def get_frozen_param_exclusions(freeze_base): method rng_keys (line 297) | def rng_keys(): method load_config (line 301) | def load_config(cls, path): class RMSNorm (line 320) | class RMSNorm(nn.Module): method setup (line 326) | def setup(self) -> None: method _norm (line 334) | def _norm(self, x: jnp.ndarray) -> jnp.ndarray: method __call__ (line 337) | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: function precompute_freqs_cis (line 344) | def precompute_freqs_cis(dim: int, max_position_embedding: int, theta: f... function apply_rotary_emb (line 353) | def apply_rotary_emb( class FlaxLLaMAAttention (line 378) | class FlaxLLaMAAttention(nn.Module): method setup (line 384) | def setup(self): method _split_heads (line 434) | def _split_heads(self, hidden_states): method _merge_heads (line 437) | def _merge_heads(self, hidden_states): method _concatenate_to_cache (line 441) | def _concatenate_to_cache(self, key, value, query, attention_mask): method __call__ (line 494) | def __call__( class FlaxLLaMAMLP (line 623) | class FlaxLLaMAMLP(nn.Module): method setup (line 629) | def setup(self) -> None: method __call__ (line 658) | def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.... class FlaxLLaMABlock (line 664) | class FlaxLLaMABlock(nn.Module): method setup (line 670) | def setup(self) -> None: method __call__ (line 704) | def __call__( class FlaxLLaMAPreTrainedModel (line 747) | class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel): method __init__ (line 757) | def __init__( method init_weights (line 769) | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, pa... method init_cache (line 806) | def init_cache(self, batch_size, max_length): method __call__ (line 827) | def __call__( class FlaxLLaMABlockCollection (line 898) | class FlaxLLaMABlockCollection(nn.Module): method __call__ (line 905) | def __call__( class FlaxLLaMAModule (line 982) | class FlaxLLaMAModule(nn.Module): method setup (line 988) | def setup(self): method __call__ (line 1002) | def __call__( class FlaxLLaMAForCausalLMModule (line 1049) | class FlaxLLaMAForCausalLMModule(nn.Module): method setup (line 1055) | def setup(self): method __call__ (line 1066) | def __call__( class FlaxLLaMAForCausalLM (line 1110) | class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel): method prepare_inputs_for_generation (line 1113) | def prepare_inputs_for_generation( method update_inputs_for_generation (line 1134) | def update_inputs_for_generation(self, model_outputs, model_kwargs): FILE: lwm/train.py function main (line 59) | def main(argv): FILE: lwm/vision_chat.py class Sampler (line 40) | class Sampler: method __init__ (line 41) | def __init__(self): method block_size (line 52) | def block_size(self): method data_dim (line 56) | def data_dim(self): method _process_frame (line 59) | def _process_frame(self, image, size): method _read_process_vision (line 76) | def _read_process_vision(self, path, max_n_frames): method construct_input (line 110) | def construct_input(self, prompts, max_n_frames): method _load_model (line 148) | def _load_model(self): method _forward_generate (line 197) | def _forward_generate(self): method __call__ (line 222) | def __call__(self, prompts, max_n_frames): function main (line 236) | def main(argv): FILE: lwm/vision_generation.py function main (line 44) | def main(argv): FILE: lwm/vision_llama.py class VideoLLaMAConfig (line 27) | class VideoLLaMAConfig(LLaMAConfig): method __init__ (line 30) | def __init__(self, vision_vocab_size=8448, tie_vision_embeddings=False... method get_partition_rules (line 37) | def get_partition_rules(scan_layers=False, scan_axis=0): method load_config (line 107) | def load_config(cls, path): class FlaxVideoLLaMAPreTrainedModel (line 121) | class FlaxVideoLLaMAPreTrainedModel(FlaxPreTrainedModel): method __init__ (line 131) | def __init__( method init_cache (line 143) | def init_cache(self, batch_size, max_length): method init_weights (line 156) | def init_weights(self, rng, input_shape, params=None): method __call__ (line 179) | def __call__( class FlaxVideoLLaMAModule (line 255) | class FlaxVideoLLaMAModule(nn.Module): method setup (line 261) | def setup(self): method __call__ (line 283) | def __call__( class FlaxVideoLLaMAForCausalLMModule (line 346) | class FlaxVideoLLaMAForCausalLMModule(nn.Module): method setup (line 352) | def setup(self): method __call__ (line 371) | def __call__( class FlaxVideoLLaMAForCausalLM (line 444) | class FlaxVideoLLaMAForCausalLM(FlaxVideoLLaMAPreTrainedModel): method prepare_inputs_for_generation (line 447) | def prepare_inputs_for_generation( method update_inputs_for_generation (line 468) | def update_inputs_for_generation(self, model_outputs, model_kwargs): method _sample_vision (line 476) | def _sample_vision( method generate_vision (line 583) | def generate_vision( FILE: lwm/vqgan.py class VQGAN (line 14) | class VQGAN: method __init__ (line 15) | def __init__(self, vqgan_checkpoint, replicate=False): method _wrap_fn (line 26) | def _wrap_fn(self, fn): method _encode (line 33) | def _encode(self): method _decode (line 43) | def _decode(self): method encode (line 52) | def encode(self, pixel_values): method decode (line 55) | def decode(self, encoding): class VQGANConfig (line 59) | class VQGANConfig(PretrainedConfig): method __init__ (line 62) | def __init__( method get_default_config (line 93) | def get_default_config(cls, updates=None): method load_config (line 101) | def load_config(cls, path): class VQGANModel (line 105) | class VQGANModel(nn.Module): method setup (line 108) | def setup(self): method encode (line 117) | def encode(self, pixel_values): method decode (line 130) | def decode(self, encoding, is_codebook_indices=True): method __call__ (line 143) | def __call__(self, pixel_values): class Encoder (line 149) | class Encoder(nn.Module): method __call__ (line 153) | def __call__(self, pixel_values): class Decoder (line 167) | class Decoder(nn.Module): method __call__ (line 171) | def __call__(self, hidden_states): class VectorQuantizer (line 187) | class VectorQuantizer(nn.Module): method __call__ (line 192) | def __call__(self, z, encoding_indices=None): class DownsamplingBlock (line 224) | class DownsamplingBlock(nn.Module): method __call__ (line 229) | def __call__(self, hidden_states): class ResnetBlock (line 242) | class ResnetBlock(nn.Module): method __call__ (line 248) | def __call__(self, hidden_states): class AttnBlock (line 266) | class AttnBlock(nn.Module): method __call__ (line 268) | def __call__(self, hidden_states): class Downsample (line 286) | class Downsample(nn.Module): method __call__ (line 290) | def __call__(self, hidden_states): class Upsample (line 306) | class Upsample(nn.Module): method __call__ (line 310) | def __call__(self, hidden_states): class UpsamplingBlock (line 322) | class UpsamplingBlock(nn.Module): method __call__ (line 327) | def __call__(self, hidden_states): class MidBlock (line 340) | class MidBlock(nn.Module): method __call__ (line 346) | def __call__(self, hidden_states): FILE: scripts/eval_needle.py class LLMNeedleHaystackTester (line 47) | class LLMNeedleHaystackTester: method __init__ (line 64) | def __init__(self, method generate_random_number (line 109) | def generate_random_number(self, num_digits): method logistic (line 114) | def logistic(self, x, L=100, x0=50, k=.1): method read_context_files (line 121) | def read_context_files(self, n): method encode_and_trim (line 135) | def encode_and_trim(self, context, context_length): method create_contexts (line 141) | def create_contexts(self, needle_rnd_number, insert_needle, random_cit... method insert_needle (line 162) | def insert_needle(self, needle, context, depth_percent, context_length): method generate_context (line 199) | def generate_context(self, needle, trim_context, context_length, depth... method compute_max_input_length (line 203) | def compute_max_input_length(self, context_length, buffer=1024): method run_test (line 209) | def run_test(self): method print_start_test_summary (line 295) | def print_start_test_summary(self): method start_test (line 303) | def start_test(self): class Sampler (line 310) | class Sampler: method __init__ (line 311) | def __init__(self): method block_size (line 319) | def block_size(self): method data_dim (line 324) | def data_dim(self): method _load_model (line 327) | def _load_model(self): method _forward_generate (line 375) | def _forward_generate(self): method __call__ (line 402) | def __call__(self, prompts, max_input_length): function main (line 427) | def main(argv): FILE: scripts/eval_needle_multi.py class LLMNeedleHaystackTester (line 50) | class LLMNeedleHaystackTester: method __init__ (line 67) | def __init__(self, method generate_random_number (line 114) | def generate_random_number(self, num_digits): method logistic (line 119) | def logistic(self, x, L=100, x0=50, k=.1): method read_context_files (line 126) | def read_context_files(self, n): method encode_and_trim (line 137) | def encode_and_trim(self, context, context_length): method create_contexts (line 143) | def create_contexts(self, needles_info, random_cities_retrieve, contex... method insert_needle (line 166) | def insert_needle(self, needle, context, depth_percent, context_length): method generate_context (line 203) | def generate_context(self, needle, trim_context, context_length, depth... method compute_max_input_length (line 207) | def compute_max_input_length(self, context_length, buffer=1024): method run_test (line 214) | def run_test(self): method print_start_test_summary (line 304) | def print_start_test_summary(self): method start_test (line 312) | def start_test(self): class Sampler (line 319) | class Sampler: method __init__ (line 320) | def __init__(self): method block_size (line 328) | def block_size(self): method data_dim (line 333) | def data_dim(self): method _load_model (line 336) | def _load_model(self): method _forward_generate (line 384) | def _forward_generate(self): method __call__ (line 411) | def __call__(self, prompts, max_input_length): function main (line 436) | def main(argv):