SYMBOL INDEX (182 symbols across 16 files) FILE: app/gradio/app.py function infer (line 12) | def infer(prompt): FILE: app/gradio/backend.py class ServiceError (line 10) | class ServiceError(Exception): method __init__ (line 11) | def __init__(self, status_code): function get_images_from_backend (line 15) | def get_images_from_backend(prompt, backend_url): function get_model_version (line 27) | def get_model_version(url): FILE: app/streamlit/backend.py class ServiceError (line 10) | class ServiceError(Exception): method __init__ (line 11) | def __init__(self, status_code): function get_images_from_backend (line 15) | def get_images_from_backend(prompt, backend_url): function get_model_version (line 27) | def get_model_version(url): FILE: src/dalle_mini/data.py class Dataset (line 16) | class Dataset: method __post_init__ (line 45) | def __post_init__(self): method preprocess (line 129) | def preprocess(self, tokenizer, config): method dataloader (line 303) | def dataloader(self, split, batch_size, epoch=None): method length (line 370) | def length(self): function shift_tokens_right (line 388) | def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int): function blank_caption_function (line 398) | def blank_caption_function(example, text_column, blank_caption_prob, rng... function normalize_function (line 408) | def normalize_function(example, text_column, text_normalizer): function filter_function (line 413) | def filter_function( function preprocess_function (line 430) | def preprocess_function( FILE: src/dalle_mini/model/configuration.py class DalleBartConfig (line 26) | class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig): method __init__ (line 34) | def __init__( FILE: src/dalle_mini/model/modeling.py function smelu (line 56) | def smelu(beta: Any = 1.0): function deepnet_init (line 82) | def deepnet_init(init_std, gain=1): class RMSNorm (line 117) | class RMSNorm(nn.Module): method __call__ (line 131) | def __call__(self, x): method _compute_rms_sq (line 150) | def _compute_rms_sq(self, x, axes): method _normalize (line 155) | def _normalize( function norm (line 189) | def norm(type, *args, **kwargs): function dot_product_attention_weights (line 198) | def dot_product_attention_weights( class FlaxBartAttention (line 276) | class FlaxBartAttention(FlaxBartAttention): method setup (line 288) | def setup(self) -> None: method __call__ (line 361) | def __call__( class GLU (line 491) | class GLU(nn.Module): method __call__ (line 501) | def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.... class FFN (line 559) | class FFN(nn.Module): method __call__ (line 569) | def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.... class FlaxBartEncoderLayer (line 616) | class FlaxBartEncoderLayer(nn.Module): method __call__ (line 629) | def __call__( class FlaxBartDecoderLayer (line 720) | class FlaxBartDecoderLayer(nn.Module): method __call__ (line 733) | def __call__( class FlaxBartEncoderLayerCollection (line 875) | class FlaxBartEncoderLayerCollection(nn.Module): method __call__ (line 885) | def __call__( class FlaxBartDecoderLayerCollection (line 981) | class FlaxBartDecoderLayerCollection(nn.Module): method __call__ (line 991) | def __call__( class FlaxBartEncoder (line 1114) | class FlaxBartEncoder(nn.Module): method setup (line 1126) | def setup(self): method __call__ (line 1158) | def __call__( class FlaxBartDecoder (line 1204) | class FlaxBartDecoder(nn.Module): method setup (line 1216) | def setup(self): method __call__ (line 1249) | def __call__( class FlaxBartModule (line 1302) | class FlaxBartModule(FlaxBartModule): method setup (line 1309) | def setup(self): class FlaxBartForConditionalGenerationModule (line 1329) | class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGener... method setup (line 1337) | def setup(self): method __call__ (line 1347) | def __call__( class SampleState (line 1399) | class SampleState: class FlaxSampleOutput (line 1410) | class FlaxSampleOutput(ModelOutput): class DalleBart (line 1423) | class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGenerati... method num_params (line 1439) | def num_params(self, params=None): method unscan (line 1447) | def unscan(self, params): method decode (line 1466) | def decode( method prepare_inputs_for_generation (line 1598) | def prepare_inputs_for_generation( method generate (line 1633) | def generate( method _sample (line 1811) | def _sample( FILE: src/dalle_mini/model/partitions.py function _match (line 15) | def _match(qs, ks): function _replacement_rules (line 26) | def _replacement_rules(rules): function _get_partition_rules (line 36) | def _get_partition_rules(): function set_partitions (line 58) | def set_partitions(in_dict, use_scan): FILE: src/dalle_mini/model/processor.py class DalleBartProcessorBase (line 13) | class DalleBartProcessorBase: method __init__ (line 14) | def __init__( method __call__ (line 33) | def __call__(self, text: List[str] = None): method from_pretrained (line 53) | def from_pretrained(cls, *args, **kwargs): class DalleBartProcessor (line 59) | class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase): FILE: src/dalle_mini/model/text.py class HashtagProcessor (line 21) | class HashtagProcessor: method __init__ (line 24) | def __init__(self): method __call__ (line 38) | def __call__(self, s): method _split (line 43) | def _split(self, s): function replace_person_token (line 86) | def replace_person_token(t): function fix_html (line 96) | def fix_html(t): function replace_punctuation_with_commas (line 101) | def replace_punctuation_with_commas(t): function simplify_quotes (line 105) | def simplify_quotes(t): function merge_quotes (line 109) | def merge_quotes(t): function remove_comma_numbers (line 113) | def remove_comma_numbers(t): function pre_process_dot_numbers (line 120) | def pre_process_dot_numbers(t): function post_process_dot_numbers (line 124) | def post_process_dot_numbers(t): function pre_process_quotes (line 128) | def pre_process_quotes(t): function post_process_quotes (line 135) | def post_process_quotes(t): function pre_process_dates (line 139) | def pre_process_dates(t): function post_process_dates (line 143) | def post_process_dates(t): function merge_commas (line 147) | def merge_commas(t): function add_space_after_commas (line 151) | def add_space_after_commas(t): function handle_special_chars (line 155) | def handle_special_chars(t): function expand_hashtags (line 163) | def expand_hashtags(t, hashtag_processor): function ignore_chars (line 171) | def ignore_chars(t): function remove_extra_spaces (line 176) | def remove_extra_spaces(t): function remove_repeating_chars (line 181) | def remove_repeating_chars(t): function remove_urls (line 186) | def remove_urls(t): function remove_html_tags (line 190) | def remove_html_tags(t): function remove_first_last_commas (line 194) | def remove_first_last_commas(t): function remove_wiki_ref (line 201) | def remove_wiki_ref(t): class TextNormalizer (line 206) | class TextNormalizer: method __init__ (line 209) | def __init__(self): method __call__ (line 212) | def __call__(self, t): FILE: src/dalle_mini/model/tokenizer.py class DalleBartTokenizer (line 7) | class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizerFast): FILE: src/dalle_mini/model/utils.py class PretrainedFromWandbMixin (line 8) | class PretrainedFromWandbMixin: method from_pretrained (line 10) | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, *... FILE: tools/train/scalable_shampoo/distributed_shampoo.py class TrainingMetrics (line 58) | class TrainingMetrics: class ParameterStats (line 64) | class ParameterStats(NamedTuple): class GlobalShardedParameterStats (line 80) | class GlobalShardedParameterStats: class LocalShardedParameterStats (line 89) | class LocalShardedParameterStats: function init_training_metrics (line 102) | def init_training_metrics(num_statistics): function init_training_metrics_shapes (line 111) | def init_training_metrics_shapes(num_statistics): function init_training_metrics_pspec (line 120) | def init_training_metrics_pspec(): class ShardedShampooStats (line 124) | class ShardedShampooStats(NamedTuple): class ShampooState (line 131) | class ShampooState(NamedTuple): class InitFnState (line 136) | class InitFnState(NamedTuple): class GraftingType (line 142) | class GraftingType(enum.IntEnum): class PreconditionerType (line 151) | class PreconditionerType(enum.IntEnum): function power_iteration (line 159) | def power_iteration( function mat_power (line 218) | def mat_power( function _pth_root_difference (line 246) | def _pth_root_difference(w, alpha, beta, p): function matrix_inverse_pth_root (line 267) | def matrix_inverse_pth_root( function merge_small_dims (line 422) | def merge_small_dims(shape_to_merge, max_dim): function pad_square_matrix (line 453) | def pad_square_matrix(mat, max_size): function make_sliced_padding (line 484) | def make_sliced_padding( function pad_block_symmetric_matrix (line 530) | def pad_block_symmetric_matrix( function pad_vector (line 588) | def pad_vector(vec, max_size): function efficient_cond (line 607) | def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs): class BlockPartitioner (line 623) | class BlockPartitioner: method __init__ (line 626) | def __init__(self, param, block_size): method split_sizes (line 645) | def split_sizes(self): method partition (line 648) | def partition(self, tensor): method merge_partitions (line 660) | def merge_partitions(self, partitions): function gram_weighted_update (line 677) | def gram_weighted_update(old_stats, g, axis, w1, w2, precision=None): class Preconditioner (line 704) | class Preconditioner: method __init__ (line 707) | def __init__( method updated_statistics_from_grad (line 734) | def updated_statistics_from_grad( method should_precondition_dims (line 777) | def should_precondition_dims(self): method shapes_for_preconditioners (line 786) | def shapes_for_preconditioners(self): method exponent_for_preconditioner (line 799) | def exponent_for_preconditioner(self): method preconditioned_grad (line 805) | def preconditioned_grad(self, grad, preconditioners): function _convert_to_parameter_stats (line 841) | def _convert_to_parameter_stats(global_stats, local_stat, convert_statis... function _convert_from_parameter_stats (line 864) | def _convert_from_parameter_stats(parameter_stats, local_stats): function _add_error_into_local_stats (line 876) | def _add_error_into_local_stats(local_stats, errors, inverse_failure_thr... function batch (line 907) | def batch(x, num_devices): function unbatch (line 914) | def unbatch(batched_values): function distributed_shampoo (line 929) | def distributed_shampoo( FILE: tools/train/scalable_shampoo/quantization_utils.py class QuantizedValue (line 27) | class QuantizedValue: method from_float_value (line 40) | def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=Fa... method quantize (line 58) | def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False): method to_float (line 107) | def to_float(self): FILE: tools/train/scalable_shampoo/sm3.py class SM3State (line 37) | class SM3State(NamedTuple): class ParameterStats (line 43) | class ParameterStats(NamedTuple): function sm3 (line 50) | def sm3( FILE: tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py class SlicedSymmetricMatrix (line 29) | class SlicedSymmetricMatrix: function product_with_transpose (line 42) | def product_with_transpose( function sliced_transposed_product (line 62) | def sliced_transposed_product( function sliced_transposed_product_concat (line 130) | def sliced_transposed_product_concat( function materialize_matrix (line 156) | def materialize_matrix(symmetric_matrix): function materialize_matrix_from_concat (line 193) | def materialize_matrix_from_concat( function update_sliced_rows (line 225) | def update_sliced_rows( function num_blocks_from_total_blocks (line 258) | def num_blocks_from_total_blocks(total_blocks): function find_num_blocks (line 280) | def find_num_blocks(block_rows_concat): function slice_symmetric_matrix (line 303) | def slice_symmetric_matrix( function slice_symmetric_matrix_concat (line 334) | def slice_symmetric_matrix_concat( function sliced_matrix_diag (line 348) | def sliced_matrix_diag(mat): function diag_as_concat (line 365) | def diag_as_concat(diag, block_size): function row_abs_maxes (line 382) | def row_abs_maxes(mat): function times_vector (line 420) | def times_vector(mat, vec): FILE: tools/train/train.py class ModelArguments (line 73) | class ModelArguments: method __post_init__ (line 123) | def __post_init__(self): method get_metadata (line 134) | def get_metadata(self): method get_opt_state (line 144) | def get_opt_state(self): class DataTrainingArguments (line 178) | class DataTrainingArguments: method __post_init__ (line 292) | def __post_init__(self): class TrainingArguments (line 298) | class TrainingArguments: method __post_init__ (line 510) | def __post_init__(self): function split_params (line 569) | def split_params(data): function unsplit_params (line 587) | def unsplit_params(data): function trainable_params (line 595) | def trainable_params(data, embeddings_only): function init_embeddings (line 619) | def init_embeddings(model, params): function main (line 637) | def main():