SYMBOL INDEX (104 symbols across 17 files) FILE: configs.py function fetch_model_params (line 12) | def fetch_model_params(model): FILE: data/create_tfrecords.py function wikitext_detokenizer (line 45) | def wikitext_detokenizer(string): function _int64_feature (line 79) | def _int64_feature(value): function write_to_file (line 86) | def write_to_file(writer, data): function get_tokenizer (line 97) | def get_tokenizer(args): function split_list (line 104) | def split_list(l, n): function archive_to_tokens (line 109) | def archive_to_tokens(f, encoder, args, prefix=[]): function write_files (line 123) | def write_files(files, files_per, output_dir, out_name, start_no, write_... function get_files (line 150) | def get_files(input_dir, filetypes=None): function read_checkpoint (line 163) | def read_checkpoint(checkpoint_path, resume_from_checkpoint=True): function create_tfrecords (line 175) | def create_tfrecords(params, write_remainder=True, write_every_n_files=1... function create_tfrecords_mp (line 245) | def create_tfrecords_mp(files, args): FILE: data/encoders.py function fetch_encoder (line 4) | def fetch_encoder(params): function encode (line 24) | def encode(encoder, text): FILE: encoders.py function fetch_encoder (line 4) | def fetch_encoder(params): function encode (line 24) | def encode(encoder, text, gpt=True): FILE: export.py function export_model (line 3) | def export_model(estimator, export_dir, params, FILE: inputs.py function _get_number_of_documents (line 14) | def _get_number_of_documents(filename): function _get_number_of_documents_by_iteration (line 21) | def _get_number_of_documents_by_iteration(filename): function _get_skip_index (line 32) | def _get_skip_index(all_files, n_batches): function _parse_function (line 55) | def _parse_function(example_proto): function autoregressive_sample_text (line 63) | def autoregressive_sample_text(params, x): function sequential_input (line 74) | def sequential_input(params, global_step=None, eval=False): function pred_input (line 139) | def pred_input(params, logger, enc=None, function handle_pred_output (line 163) | def handle_pred_output(predictions, logger, enc, params, out_name="test"): function generic_text (line 188) | def generic_text(params, eval=False, sample_text_fn=None, **kwargs): function text_dataset (line 224) | def text_dataset(files, params, stitch, datatype, batch=True, sample_tex... function autoregressive_sample_text_random_documents (line 297) | def autoregressive_sample_text_random_documents(params, x): function mlm_sample_text (line 316) | def mlm_sample_text(params, x, random_documents=False): FILE: main.py function parse_args (line 21) | def parse_args(): function main (line 51) | def main(args): FILE: model_fns.py function model_fn (line 15) | def model_fn(features, labels, mode, params): FILE: models/activations.py function _arcsinh (line 20) | def _arcsinh(x): function _var (line 24) | def _var(x, init): function _pos_var (line 29) | def _pos_var(x, val): function _rrelu (line 33) | def _rrelu(x): function _elish (line 38) | def _elish(x): function get_activation_fn (line 79) | def get_activation_fn(params): FILE: models/gpt2/gpt2.py function block (line 12) | def block(params, scope, layer_num, bias, sequence_dim, memory_length_di... function model (line 99) | def model(mtf_features, other_features, params, mesh, variable_dtype, co... FILE: models/layers.py function exists (line 15) | def exists(x): function identity (line 19) | def identity(x, *args, **kwargs): function is_incremental_inference (line 23) | def is_incremental_inference(context): function norm (line 27) | def norm(x, axis, epsilon=1e-8): function rezero (line 33) | def rezero(x, scope, dtype): function scale_norm (line 39) | def scale_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5,... function layer_norm (line 54) | def layer_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5,... function linear_attention (line 76) | def linear_attention(q, k, v): function causal_linear_attention (line 91) | def causal_linear_attention(q, k, v, eps = 1e-6): function linear (line 111) | def linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=No... function memory_key_values (line 127) | def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_d... function attn (line 156) | def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, me... function mlp (line 277) | def mlp(x, scope, n_state, *, variable_dtype, params): function mlp_glu (line 288) | def mlp_glu(x, scope, n_state, *, variable_dtype, params): function axial_positional_emb (line 303) | def axial_positional_emb(embd_dim, mesh, params, variable_dtype): function rotary_positional_emb (line 330) | def rotary_positional_emb(mesh, sequence_dim, params, variable_dtype): function rotate_half (line 347) | def rotate_half(x): function apply_rotary_emb (line 355) | def apply_rotary_emb(x, cos, sin): FILE: models/utils.py function entmax_backward (line 6) | def entmax_backward(explicit_inputs, all_inputs, forward_operations, out... function entmax_forward (line 21) | def entmax_forward(x, alpha=1.3, dim=None, n_iter=50): function entmax (line 55) | def entmax(x, alpha=1.3, dim=None, n_iter=50): function entmax_cross_entropy_with_logits (line 65) | def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=... function sample_categorical (line 90) | def sample_categorical(x, dim=None): function biasmask_attn_weights (line 99) | def biasmask_attn_weights(mesh, nd, ns, variable_dtype): function parse_inputs (line 113) | def parse_inputs(mtf_features, other_features): FILE: optimizers.py function clip_by_global_norm (line 9) | def clip_by_global_norm(grads, clip_norm): function get_optimizer (line 16) | def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None): class AdamWeightDecayOptimizer (line 95) | class AdamWeightDecayOptimizer(mtf.optimize.Optimizer): method __init__ (line 98) | def __init__(self, method apply_grad (line 116) | def apply_grad(self, grad, var): method _do_use_weight_decay (line 168) | def _do_use_weight_decay(self, param_name): FILE: run_experiment.py function get_open_port (line 44) | def get_open_port(lo=8000, hi=8100): function train_thread (line 51) | def train_thread(args, tpu, id, q): function get_json (line 111) | def get_json(uri, params=None, timeout=15): function get_tag_sets (line 117) | def get_tag_sets(base_uri): function get_scalar_data (line 126) | def get_scalar_data(base_uri, run, tag): function get_run_data (line 132) | def get_run_data(port): function main (line 159) | def main(_run): function goodbye (line 247) | def goodbye(id): FILE: sample.py function sample_autoregressive (line 8) | def sample_autoregressive(partial_sequences, FILE: tasks.py function lambada_create_tokens_data (line 22) | def lambada_create_tokens_data(params, path): function lambada_read_or_create_tokens_data (line 34) | def lambada_read_or_create_tokens_data(params, path): function bin_pack (line 42) | def bin_pack(params, tokens_data): function lambada_init (line 61) | def lambada_init(params): function lambada_get_task_info (line 77) | def lambada_get_task_info(params): function lambada_input (line 84) | def lambada_input(params): FILE: utils.py function setup_logging (line 15) | def setup_logging(args): function get_batch_size (line 29) | def get_batch_size(params): function add_mode_to_params (line 33) | def add_mode_to_params(params, mode): function simd_mesh_setup (line 45) | def simd_mesh_setup(params, mesh_shape, layout_rules): function remove_batch_from_layout (line 67) | def remove_batch_from_layout(layout): function yes_or_no (line 85) | def yes_or_no(question): function remove_gs_or_filepath (line 94) | def remove_gs_or_filepath(path): function save_config (line 102) | def save_config(params_dict, logdir): function expand_attention_types_params (line 132) | def expand_attention_types_params(params_list): function get_n_trainable_vars (line 140) | def get_n_trainable_vars(graph): function print_dim_names (line 157) | def print_dim_names(graph): function get_graph_info (line 177) | def get_graph_info(graph): function loss_denominator (line 189) | def loss_denominator(targets, num_microbatches): function check_dataset (line 206) | def check_dataset(input_fn, params, global_step=None): function auto_layout (line 224) | def auto_layout(graph, mesh_shape, logits, loss): function auto_layout_and_mesh_shape (line 229) | def auto_layout_and_mesh_shape(graph, num_cores, logits, loss): function create_host_call (line 236) | def create_host_call(model_dir): function natural_sort (line 289) | def natural_sort(l):