SYMBOL INDEX (456 symbols across 77 files) FILE: common/batch.py class BatchBase (line 14) | class BatchBase(Pipelineable, abc.ABC): method as_dict (line 16) | def as_dict(self) -> Dict: method to (line 19) | def to(self, device: torch.device, non_blocking: bool = False): method record_stream (line 25) | def record_stream(self, stream: torch.cuda.streams.Stream) -> None: method pin_memory (line 29) | def pin_memory(self): method __repr__ (line 35) | def __repr__(self) -> str: method batch_size (line 42) | def batch_size(self) -> int: class DataclassBatch (line 53) | class DataclassBatch(BatchBase): method feature_names (line 55) | def feature_names(cls): method as_dict (line 58) | def as_dict(self): method from_schema (line 66) | def from_schema(name: str, schema): method from_fields (line 75) | def from_fields(name: str, fields: dict): class DictionaryBatch (line 83) | class DictionaryBatch(BatchBase, dict): method as_dict (line 84) | def as_dict(self) -> Dict: FILE: common/checkpointing/snapshot.py class Snapshot (line 15) | class Snapshot: method __init__ (line 22) | def __init__(self, save_dir: str, state: Dict[str, Any]) -> None: method step (line 28) | def step(self): method step (line 32) | def step(self, step: int) -> None: method walltime (line 36) | def walltime(self): method walltime (line 40) | def walltime(self, walltime: float) -> None: method save (line 43) | def save(self, global_step: int) -> "PendingSnapshot": method restore (line 60) | def restore(self, checkpoint: str) -> None: method get_torch_snapshot (line 80) | def get_torch_snapshot( method load_snapshot_to_weight (line 97) | def load_snapshot_to_weight( function _eval_subdir (line 123) | def _eval_subdir(checkpoint_path: str) -> str: function _eval_done_path (line 127) | def _eval_done_path(checkpoint_path: str, eval_partition: str) -> str: function is_done_eval (line 131) | def is_done_eval(checkpoint_path: str, eval_partition: str): function mark_done_eval (line 135) | def mark_done_eval(checkpoint_path: str, eval_partition: str): function step_from_checkpoint (line 139) | def step_from_checkpoint(checkpoint: str) -> int: function checkpoints_iterator (line 143) | def checkpoints_iterator(save_dir: str, seconds_to_sleep: int = 30, time... function get_checkpoint (line 177) | def get_checkpoint( function get_checkpoints (line 211) | def get_checkpoints(save_dir: str) -> List[str]: function wait_for_evaluators (line 229) | def wait_for_evaluators( FILE: common/device.py function maybe_setup_tensorflow (line 7) | def maybe_setup_tensorflow(): function setup_and_get_device (line 16) | def setup_and_get_device(tf_ok: bool = True) -> torch.device: FILE: common/filesystem/test_infer_fs.py function test_infer_fs (line 8) | def test_infer_fs(): FILE: common/filesystem/util.py function infer_fs (line 10) | def infer_fs(path: str): function is_local_fs (line 20) | def is_local_fs(fs): function is_gcs_fs (line 24) | def is_gcs_fs(fs): FILE: common/log_weights.py function weights_to_log (line 11) | def weights_to_log( function log_ebc_norms (line 47) | def log_ebc_norms( FILE: common/modules/embedding/config.py class DataType (line 10) | class DataType(str, Enum): class EmbeddingSnapshot (line 15) | class EmbeddingSnapshot(base_config.BaseConfig): class EmbeddingBagConfig (line 26) | class EmbeddingBagConfig(base_config.BaseConfig): class LargeEmbeddingsConfig (line 42) | class LargeEmbeddingsConfig(base_config.BaseConfig): class Mode (line 54) | class Mode(str, Enum): FILE: common/modules/embedding/embedding.py class LargeEmbeddings (line 13) | class LargeEmbeddings(nn.Module): method __init__ (line 14) | def __init__( method forward (line 51) | def forward( FILE: common/run_training.py function is_distributed_worker (line 13) | def is_distributed_worker(): function maybe_run_training (line 19) | def maybe_run_training( FILE: common/test_device.py function test_device (line 10) | def test_device(): FILE: common/testing_utils.py function mock_pg (line 21) | def mock_pg(): FILE: common/utils.py function _read_file (line 14) | def _read_file(f): function setup_configuration (line 19) | def setup_configuration( FILE: common/wandb.py class WandbConfig (line 8) | class WandbConfig(base_config.BaseConfig): FILE: core/config/base_config.py class BaseConfig (line 10) | class BaseConfig(pydantic.BaseModel): class Config (line 30) | class Config: method _field_data_map (line 37) | def _field_data_map(cls, field_data_name): method _one_of_check (line 47) | def _one_of_check(cls, values): method _at_most_one_of_check (line 56) | def _at_most_one_of_check(cls, values): method pretty_print (line 64) | def pretty_print(self) -> str: FILE: core/config/base_config_test.py class BaseConfigTest (line 8) | class BaseConfigTest(TestCase): method test_extra_forbidden (line 9) | def test_extra_forbidden(self): method test_one_of (line 17) | def test_one_of(self): method test_at_most_one_of (line 29) | def test_at_most_one_of(self): FILE: core/config/config_load.py function load_config_from_yaml (line 10) | def load_config_from_yaml(config_type: Type[BaseConfig], yaml_path: str): FILE: core/config/test_config_load.py class _PointlessConfig (line 10) | class _PointlessConfig(BaseConfig): function test_load_config_from_yaml (line 15) | def test_load_config_from_yaml(tmp_path): FILE: core/config/training.py class RuntimeConfig (line 11) | class RuntimeConfig(base_config.BaseConfig): class TrainingConfig (line 19) | class TrainingConfig(base_config.BaseConfig): FILE: core/custom_training_loop.py function get_new_iterator (line 29) | def get_new_iterator(iterable: Iterable): function _get_step_fn (line 48) | def _get_step_fn(pipeline, data_iterator, training: bool): function _run_evaluation (line 64) | def _run_evaluation( function train (line 92) | def train( function log_eval_results (line 259) | def log_eval_results( function only_evaluate (line 274) | def only_evaluate( FILE: core/debug_training_loop.py function train (line 20) | def train( FILE: core/loss_type.py class LossType (line 5) | class LossType(str, Enum): FILE: core/losses.py function _maybe_warn (line 11) | def _maybe_warn(reduction: str): function build_loss (line 23) | def build_loss( function get_global_loss_detached (line 36) | def get_global_loss_detached(local_loss, reduction="mean"): function build_multi_task_loss (line 62) | def build_multi_task_loss( FILE: core/metric_mixin.py class MetricMixin (line 36) | class MetricMixin: method transform (line 38) | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict: method update (line 41) | def update(self, outputs: Dict[str, torch.Tensor]): class TaskMixin (line 50) | class TaskMixin: method __init__ (line 51) | def __init__(self, task_idx: int = -1, **kwargs): class StratifyMixin (line 56) | class StratifyMixin: method __init__ (line 57) | def __init__( method maybe_apply_stratification (line 65) | def maybe_apply_stratification( function prepend_transform (line 86) | def prepend_transform(base_metric: torchmetrics.Metric, transform: Calla... FILE: core/metrics.py function probs_and_labels (line 14) | def probs_and_labels( class Count (line 29) | class Count(StratifyMixin, TaskMixin, MetricMixin, tm.SumMetric): method transform (line 30) | def transform(self, outputs): class Ctr (line 38) | class Ctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): method transform (line 39) | def transform(self, outputs): class Pctr (line 47) | class Pctr(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): method transform (line 48) | def transform(self, outputs): class Precision (line 56) | class Precision(StratifyMixin, TaskMixin, MetricMixin, tm.Precision): method transform (line 57) | def transform(self, outputs): class Recall (line 62) | class Recall(StratifyMixin, TaskMixin, MetricMixin, tm.Recall): method transform (line 63) | def transform(self, outputs): class TorchMetricsRocauc (line 68) | class TorchMetricsRocauc(StratifyMixin, TaskMixin, MetricMixin, tm.AUROC): method transform (line 69) | def transform(self, outputs): class Auc (line 74) | class Auc(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): method __init__ (line 80) | def __init__(self, num_samples, **kwargs): method transform (line 84) | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: class PosRanks (line 95) | class PosRanks(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): method __init__ (line 102) | def __init__(self, **kwargs): method transform (line 105) | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: class ReciprocalRank (line 113) | class ReciprocalRank(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): method __init__ (line 120) | def __init__(self, **kwargs): method transform (line 123) | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: class HitAtK (line 131) | class HitAtK(StratifyMixin, TaskMixin, MetricMixin, tm.MeanMetric): method __init__ (line 139) | def __init__(self, k: int, **kwargs): method transform (line 143) | def transform(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, Any]: FILE: core/test_metrics.py class MockStratifierConfig (line 11) | class MockStratifierConfig: class Count (line 17) | class Count(MetricMixin, SumMetric): method transform (line 18) | def transform(self, outputs): function test_count_metric (line 25) | def test_count_metric(): function test_collections (line 38) | def test_collections(): function test_task_dependent_ctr (line 53) | def test_task_dependent_ctr(): function test_stratified_ctr (line 71) | def test_stratified_ctr(): function test_auc (line 116) | def test_auc(): function test_pos_rank (line 133) | def test_pos_rank(): function test_reciprocal_rank (line 149) | def test_reciprocal_rank(): function test_hit_k (line 165) | def test_hit_k(): FILE: core/test_train_pipeline.py class MockDataclassBatch (line 13) | class MockDataclassBatch(DataclassBatch): class MockModule (line 18) | class MockModule(torch.nn.Module): method __init__ (line 19) | def __init__(self) -> None: method forward (line 24) | def forward(self, batch: MockDataclassBatch) -> Tuple[torch.Tensor, to... function create_batch (line 30) | def create_batch(bsz: int): function test_sparse_pipeline (line 37) | def test_sparse_pipeline(): function test_amp (line 67) | def test_amp(): FILE: core/train_pipeline.py class TrainPipeline (line 41) | class TrainPipeline(abc.ABC, Generic[In, Out]): method progress (line 43) | def progress(self, dataloader_iter: Iterator[In]) -> Out: function _to_device (line 47) | def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In: function _wait_for_batch (line 54) | def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Strea... class TrainPipelineBase (line 73) | class TrainPipelineBase(TrainPipeline[In, Out]): method __init__ (line 81) | def __init__( method _connect (line 96) | def _connect(self, dataloader_iter: Iterator[In]) -> None: method progress (line 103) | def progress(self, dataloader_iter: Iterator[In]) -> Out: class Tracer (line 141) | class Tracer(torch.fx.Tracer): method __init__ (line 148) | def __init__(self, leaf_modules: Optional[List[str]] = None) -> None: method is_leaf_module (line 152) | def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: st... class TrainPipelineContext (line 159) | class TrainPipelineContext: class ArgInfo (line 168) | class ArgInfo: class PipelinedForward (line 179) | class PipelinedForward: method __init__ (line 180) | def __init__( method __call__ (line 195) | def __call__(self, *input, **kwargs) -> Awaitable: method name (line 232) | def name(self) -> str: method args (line 236) | def args(self) -> List[ArgInfo]: function _start_data_dist (line 240) | def _start_data_dist( function _get_node_args_helper (line 282) | def _get_node_args_helper( function _get_node_args (line 332) | def _get_node_args( function _get_unsharded_module_names_helper (line 349) | def _get_unsharded_module_names_helper( function _get_unsharded_module_names (line 376) | def _get_unsharded_module_names(model: torch.nn.Module) -> List[str]: function _rewrite_model (line 390) | def _rewrite_model( # noqa C901 class TrainPipelineSparseDist (line 442) | class TrainPipelineSparseDist(TrainPipeline[In, Out]): method __init__ (line 462) | def __init__( method _connect (line 506) | def _connect(self, dataloader_iter: Iterator[In]) -> None: method progress (line 525) | def progress(self, dataloader_iter: Iterator[In]) -> Out: method _sync_pipeline (line 618) | def _sync_pipeline(self) -> None: FILE: machines/environment.py function on_kf (line 11) | def on_kf(): function has_readers (line 15) | def has_readers(): function get_task_type (line 22) | def get_task_type(): function is_chief (line 28) | def is_chief() -> bool: function is_reader (line 32) | def is_reader() -> bool: function is_dispatcher (line 36) | def is_dispatcher() -> bool: function get_task_index (line 40) | def get_task_index(): function get_reader_port (line 48) | def get_reader_port(): function get_dds (line 54) | def get_dds(): function get_dds_dispatcher_address (line 64) | def get_dds_dispatcher_address(): function get_dds_worker_address (line 75) | def get_dds_worker_address(): function get_num_readers (line 87) | def get_num_readers(): function get_flight_server_addresses (line 96) | def get_flight_server_addresses(): function get_dds_journaling_dir (line 107) | def get_dds_journaling_dir(): FILE: machines/get_env.py function main (line 10) | def main(argv): FILE: machines/is_venv.py function is_venv (line 11) | def is_venv(): function _main (line 16) | def _main(): FILE: machines/list_ops.py function main (line 32) | def main(argv): FILE: metrics/aggregation.py function update_mean (line 10) | def update_mean( function stable_mean_dist_reduce_fn (line 38) | def stable_mean_dist_reduce_fn(state: torch.Tensor) -> torch.Tensor: class StableMean (line 56) | class StableMean(torchmetrics.Metric): method __init__ (line 65) | def __init__(self, **kwargs): method update (line 77) | def update(self, value: torch.Tensor, weight: Union[float, torch.Tenso... method compute (line 93) | def compute(self) -> torch.Tensor: FILE: metrics/auroc.py function _compute_helper (line 13) | def _compute_helper( class AUROCWithMWU (line 53) | class AUROCWithMWU(torchmetrics.Metric): method __init__ (line 63) | def __init__(self, label_threshold: float = 0.5, raise_missing_class: ... method update (line 81) | def update( method compute (line 101) | def compute(self) -> torch.Tensor: FILE: metrics/rce.py function _smooth (line 14) | def _smooth( function _binary_cross_entropy_with_clipping (line 27) | def _binary_cross_entropy_with_clipping( class RCE (line 54) | class RCE(torchmetrics.Metric): method __init__ (line 125) | def __init__( method update (line 152) | def update( method compute (line 169) | def compute(self) -> torch.Tensor: method reset (line 183) | def reset(self): method forward (line 191) | def forward(self, *args, **kwargs): class NRCE (line 205) | class NRCE(RCE): method __init__ (line 226) | def __init__( method update (line 246) | def update( method reset (line 275) | def reset(self): FILE: ml_logging/absl_logging.py function setup_absl_logging (line 16) | def setup_absl_logging(): FILE: ml_logging/test_torch_logging.py class Testtlogging (line 6) | class Testtlogging(unittest.TestCase): method test_warn_once (line 7) | def test_warn_once(self): FILE: ml_logging/torch_logging.py function rank_specific (line 20) | def rank_specific(logger): FILE: model.py class ModelAndLoss (line 12) | class ModelAndLoss(torch.nn.Module): method __init__ (line 15) | def __init__( method forward (line 29) | def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] function maybe_shard_model (line 53) | def maybe_shard_model( function log_sharded_tensor_content (line 76) | def log_sharded_tensor_content(weight_name: str, table_name: str, weight... FILE: optimizers/config.py class PiecewiseConstant (line 10) | class PiecewiseConstant(base_config.BaseConfig): class LinearRampToConstant (line 15) | class LinearRampToConstant(base_config.BaseConfig): class LinearRampToCosine (line 22) | class LinearRampToCosine(base_config.BaseConfig): class LearningRate (line 33) | class LearningRate(base_config.BaseConfig): class OptimizerAlgorithmConfig (line 40) | class OptimizerAlgorithmConfig(base_config.BaseConfig): class AdamConfig (line 47) | class AdamConfig(OptimizerAlgorithmConfig): class SgdConfig (line 54) | class SgdConfig(OptimizerAlgorithmConfig): class AdagradConfig (line 59) | class AdagradConfig(OptimizerAlgorithmConfig): class OptimizerConfig (line 64) | class OptimizerConfig(base_config.BaseConfig): function get_optimizer_algorithm_config (line 74) | def get_optimizer_algorithm_config(optimizer_config: OptimizerConfig): FILE: optimizers/optimizer.py function compute_lr (line 16) | def compute_lr(lr_config, step): class LRShim (line 48) | class LRShim(_LRScheduler): method __init__ (line 55) | def __init__( method get_lr (line 74) | def get_lr(self): method _get_closed_form_lr (line 82) | def _get_closed_form_lr(self): function get_optimizer_class (line 86) | def get_optimizer_class(optimizer_config: OptimizerConfig): function build_optimizer (line 95) | def build_optimizer( FILE: projects/home/recap/config.py class TrainingConfig (line 11) | class TrainingConfig(config_mod.BaseConfig): class RecapConfig (line 34) | class RecapConfig(config_mod.BaseConfig): class JobMode (line 49) | class JobMode(str, Enum): FILE: projects/home/recap/data/config.py class ExplicitDateInputs (line 10) | class ExplicitDateInputs(base_config.BaseConfig): class ExplicitDatetimeInputs (line 21) | class ExplicitDatetimeInputs(base_config.BaseConfig): class DdsCompressionOption (line 32) | class DdsCompressionOption(str, Enum): class DatasetConfig (line 38) | class DatasetConfig(base_config.BaseConfig): class TruncateAndSlice (line 82) | class TruncateAndSlice(base_config.BaseConfig): class DataType (line 99) | class DataType(str, Enum): class DownCast (line 109) | class DownCast(base_config.BaseConfig): class TaskData (line 116) | class TaskData(base_config.BaseConfig): class SegDenseSchema (line 127) | class SegDenseSchema(base_config.BaseConfig): class RectifyLabels (line 142) | class RectifyLabels(base_config.BaseConfig): class ExtractFeaturesRow (line 157) | class ExtractFeaturesRow(base_config.BaseConfig): class ExtractFeatures (line 172) | class ExtractFeatures(base_config.BaseConfig): class DownsampleNegatives (line 179) | class DownsampleNegatives(base_config.BaseConfig): class Preprocess (line 194) | class Preprocess(base_config.BaseConfig): class Sampler (line 208) | class Sampler(base_config.BaseConfig): class RecapDataConfig (line 221) | class RecapDataConfig(DatasetConfig): method _validate_evaluation_tasks (line 241) | def _validate_evaluation_tasks(cls, values): FILE: projects/home/recap/data/dataset.py class RecapBatch (line 22) | class RecapBatch(DataclassBatch): method __post_init__ (line 35) | def __post_init__(self): function to_batch (line 43) | def to_batch(x, sparse_feature_names: Optional[List[str]] = None) -> Rec... function _chain (line 91) | def _chain(param, f1, f2): function _add_weights (line 103) | def _add_weights(inputs, tasks: Dict[str, TaskData]): function get_datetimes (line 130) | def get_datetimes(explicit_datetime_inputs): function get_explicit_datetime_inputs_files (line 143) | def get_explicit_datetime_inputs_files(explicit_datetime_inputs): function _map_output_for_inference (line 183) | def _map_output_for_inference( function _map_output_for_train_eval (line 198) | def _map_output_for_train_eval( function _add_weights_based_on_sampling_rates (line 216) | def _add_weights_based_on_sampling_rates(inputs, tasks: Dict[str, TaskDa... class RecapDataset (line 242) | class RecapDataset(torch.utils.data.IterableDataset): method __init__ (line 243) | def __init__( method _init_tensor_spec (line 304) | def _init_tensor_spec(self): method _create_tf_dataset (line 315) | def _create_tf_dataset(self): method _create_base_tf_dataset (line 373) | def _create_base_tf_dataset(self, batch_size: int): method _gen (line 469) | def _gen(self): method to_dataloader (line 473) | def to_dataloader(self) -> Dict[str, torch.Tensor]: method __iter__ (line 476) | def __iter__(self): FILE: projects/home/recap/data/generate_random_data.py function _generate_random_example (line 17) | def _generate_random_example( function _float_feature (line 35) | def _float_feature(value): function _int64_feature (line 39) | def _int64_feature(value): function _serialize_example (line 43) | def _serialize_example(x: Dict[str, tf.Tensor]) -> bytes: function generate_data (line 53) | def generate_data(data_path: str, config: recap_config_mod.RecapConfig): function _generate_data_main (line 70) | def _generate_data_main(unused_argv): FILE: projects/home/recap/data/preprocessors.py class TruncateAndSlice (line 11) | class TruncateAndSlice(tf.keras.Model): method __init__ (line 14) | def __init__(self, truncate_and_slice_config): method call (line 34) | def call(self, inputs, training=None, mask=None): class DownCast (line 53) | class DownCast(tf.keras.Model): method __init__ (line 59) | def __init__(self, downcast_config): method call (line 67) | def call(self, inputs, training=None, mask=None): class RectifyLabels (line 80) | class RectifyLabels(tf.keras.Model): method __init__ (line 83) | def __init__(self, rectify_label_config): method call (line 88) | def call(self, inputs, training=None, mask=None): class ExtractFeatures (line 104) | class ExtractFeatures(tf.keras.Model): method __init__ (line 107) | def __init__(self, extract_features_config): method call (line 111) | def call(self, inputs, training=None, mask=None): class DownsampleNegatives (line 119) | class DownsampleNegatives(tf.keras.Model): method __init__ (line 130) | def __init__(self, downsample_negatives_config): method call (line 134) | def call(self, inputs, training=None, mask=None): function build_preprocess (line 170) | def build_preprocess(preprocess_config, mode=config_mod.JobMode.TRAIN): FILE: projects/home/recap/data/tfe_parsing.py function create_tf_example_schema (line 14) | def create_tf_example_schema( function make_mantissa_mask (line 60) | def make_mantissa_mask(mask_length: int) -> tf.Tensor: function mask_mantissa (line 65) | def mask_mantissa(tensor: tf.Tensor, mask_length: int) -> tf.Tensor: function parse_tf_example (line 71) | def parse_tf_example( function get_seg_dense_parse_fn (line 108) | def get_seg_dense_parse_fn(data_config: recap_data_config.RecapDataConfig): FILE: projects/home/recap/data/util.py function keyed_tensor_from_tensors_dict (line 8) | def keyed_tensor_from_tensors_dict( function _compute_jagged_tensor_from_tensor (line 30) | def _compute_jagged_tensor_from_tensor(tensor: torch.Tensor) -> Tuple[to... function jagged_tensor_from_tensor (line 41) | def jagged_tensor_from_tensor(tensor: torch.Tensor) -> "torchrec.JaggedT... function keyed_jagged_tensor_from_tensors_dict (line 55) | def keyed_jagged_tensor_from_tensors_dict( function _tf_to_numpy (line 93) | def _tf_to_numpy(tf_tensor: tf.Tensor) -> np.ndarray: function _dense_tf_to_torch (line 97) | def _dense_tf_to_torch(tensor: tf.Tensor, pin_memory: bool) -> torch.Ten... function sparse_or_dense_tf_to_torch (line 109) | def sparse_or_dense_tf_to_torch( FILE: projects/home/recap/embedding/config.py class EmbeddingSnapshot (line 8) | class EmbeddingSnapshot(base_config.BaseConfig): class EmbeddingBagConfig (line 20) | class EmbeddingBagConfig(base_config.BaseConfig): class EmbeddingOptimizerConfig (line 32) | class EmbeddingOptimizerConfig(base_config.BaseConfig): class LargeEmbeddingsConfig (line 41) | class LargeEmbeddingsConfig(base_config.BaseConfig): class StratifierConfig (line 54) | class StratifierConfig(base_config.BaseConfig): class SmallEmbeddingBagConfig (line 60) | class SmallEmbeddingBagConfig(base_config.BaseConfig): class SmallEmbeddingBagConfig (line 69) | class SmallEmbeddingBagConfig(base_config.BaseConfig): class SmallEmbeddingsConfig (line 78) | class SmallEmbeddingsConfig(base_config.BaseConfig): FILE: projects/home/recap/main.py function run (line 36) | def run(unused_argv: str, data_service_dispatcher: Optional[str] = None): FILE: projects/home/recap/model/config.py class DropoutConfig (line 12) | class DropoutConfig(base_config.BaseConfig): class LayerNormConfig (line 20) | class LayerNormConfig(base_config.BaseConfig): class BatchNormConfig (line 31) | class BatchNormConfig(base_config.BaseConfig): class DenseLayerConfig (line 42) | class DenseLayerConfig(base_config.BaseConfig): class MlpConfig (line 47) | class MlpConfig(base_config.BaseConfig): class BatchNormConfig (line 54) | class BatchNormConfig(base_config.BaseConfig): class DoubleNormLogConfig (line 63) | class DoubleNormLogConfig(base_config.BaseConfig): class Log1pAbsConfig (line 71) | class Log1pAbsConfig(base_config.BaseConfig): class ClipLog1pAbsConfig (line 75) | class ClipLog1pAbsConfig(base_config.BaseConfig): class ZScoreLogConfig (line 81) | class ZScoreLogConfig(base_config.BaseConfig): class FeaturizationConfig (line 101) | class FeaturizationConfig(base_config.BaseConfig): class DropoutConfig (line 113) | class DropoutConfig(base_config.BaseConfig): class MlpConfig (line 121) | class MlpConfig(base_config.BaseConfig): class DcnConfig (line 134) | class DcnConfig(base_config.BaseConfig): class MaskBlockConfig (line 150) | class MaskBlockConfig(base_config.BaseConfig): class MaskNetConfig (line 161) | class MaskNetConfig(base_config.BaseConfig): class PositionDebiasConfig (line 167) | class PositionDebiasConfig(base_config.BaseConfig): class AffineMap (line 185) | class AffineMap(base_config.BaseConfig): class DLRMConfig (line 192) | class DLRMConfig(base_config.BaseConfig): class TaskModel (line 200) | class TaskModel(base_config.BaseConfig): class MultiTaskType (line 215) | class MultiTaskType(str, enum.Enum): class ModelConfig (line 221) | class ModelConfig(base_config.BaseConfig): method _validate_mtl (line 249) | def _validate_mtl(cls, values): FILE: projects/home/recap/model/entrypoint.py function sanitize (line 20) | def sanitize(task_name): function unsanitize (line 24) | def unsanitize(sanitized_task_name): function _build_single_task_model (line 28) | def _build_single_task_model(task: model_config_mod.TaskModel, input_sha... class MultiTaskRankingModel (line 40) | class MultiTaskRankingModel(torch.nn.Module): method __init__ (line 43) | def __init__( method forward (line 159) | def forward( function create_ranking_model (line 264) | def create_ranking_model( FILE: projects/home/recap/model/feature_transform.py function log_transform (line 13) | def log_transform(x: torch.Tensor) -> torch.Tensor: class BatchNorm (line 18) | class BatchNorm(torch.nn.Module): method __init__ (line 19) | def __init__(self, num_features: int, config: BatchNormConfig): method forward (line 23) | def forward(self, x: torch.Tensor) -> torch.Tensor: class LayerNorm (line 27) | class LayerNorm(torch.nn.Module): method __init__ (line 28) | def __init__(self, normalized_shape: Union[int, Sequence[int]], config... method forward (line 40) | def forward(self, x: torch.Tensor) -> torch.Tensor: class Log1pAbs (line 44) | class Log1pAbs(torch.nn.Module): method __init__ (line 45) | def __init__(self): method forward (line 48) | def forward(self, x: torch.Tensor) -> torch.Tensor: class InputNonFinite (line 52) | class InputNonFinite(torch.nn.Module): method __init__ (line 53) | def __init__(self, fill_value: float = 0): method forward (line 60) | def forward(self, x: torch.Tensor) -> torch.Tensor: class Clamp (line 64) | class Clamp(torch.nn.Module): method __init__ (line 65) | def __init__(self, min_value: float, max_value: float): method forward (line 76) | def forward(self, x: torch.Tensor) -> torch.Tensor: class DoubleNormLog (line 80) | class DoubleNormLog(torch.nn.Module): method __init__ (line 83) | def __init__( method forward (line 108) | def forward( function build_features_preprocessor (line 118) | def build_features_preprocessor( FILE: projects/home/recap/model/mask_net.py function _init_weights (line 8) | def _init_weights(module): class MaskBlock (line 14) | class MaskBlock(torch.nn.Module): method __init__ (line 15) | def __init__( method forward (line 44) | def forward(self, net: torch.Tensor, mask_input: torch.Tensor): class MaskNet (line 51) | class MaskNet(torch.nn.Module): method __init__ (line 52) | def __init__(self, mask_net_config: config.MaskNetConfig, in_features:... method forward (line 79) | def forward(self, inputs: torch.Tensor): FILE: projects/home/recap/model/mlp.py function _init_weights (line 9) | def _init_weights(module): class Mlp (line 15) | class Mlp(torch.nn.Module): method __init__ (line 16) | def __init__(self, in_features: int, mlp_config: MlpConfig): method forward (line 44) | def forward(self, x: torch.Tensor) -> torch.Tensor: method shared_size (line 53) | def shared_size(self): method out_features (line 57) | def out_features(self): FILE: projects/home/recap/model/model_and_loss.py class ModelAndLoss (line 7) | class ModelAndLoss(torch.nn.Module): method __init__ (line 8) | def __init__( method forward (line 25) | def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] FILE: projects/home/recap/model/numeric_calibration.py class NumericCalibration (line 4) | class NumericCalibration(torch.nn.Module): method __init__ (line 5) | def __init__( method forward (line 18) | def forward(self, probs: torch.Tensor): FILE: projects/home/recap/optimizer/config.py class RecapAdamConfig (line 11) | class RecapAdamConfig(base_config.BaseConfig): class MultiTaskLearningRates (line 17) | class MultiTaskLearningRates(base_config.BaseConfig): class RecapOptimizerConfig (line 27) | class RecapOptimizerConfig(base_config.BaseConfig): FILE: projects/home/recap/optimizer/optimizer.py class RecapLRShim (line 25) | class RecapLRShim(torch.optim.lr_scheduler._LRScheduler): method __init__ (line 33) | def __init__( method get_lr (line 59) | def get_lr(self): method _get_closed_form_lr (line 67) | def _get_closed_form_lr(self): function build_optimizer (line 78) | def build_optimizer( FILE: projects/twhin/config.py class TwhinConfig (line 9) | class TwhinConfig(base_config.BaseConfig): FILE: projects/twhin/data/config.py class TwhinDataConfig (line 6) | class TwhinDataConfig(base_config.BaseConfig): FILE: projects/twhin/data/data.py function create_dataset (line 6) | def create_dataset(data_config: TwhinDataConfig, model_config: TwhinMode... FILE: projects/twhin/data/edges.py class EdgeBatch (line 17) | class EdgeBatch(DataclassBatch): class EdgesDataset (line 24) | class EdgesDataset(Dataset): method __init__ (line 27) | def __init__( method pa_to_batch (line 58) | def pa_to_batch(self, batch: pa.RecordBatch): method _to_kjt (line 72) | def _to_kjt( method to_batches (line 149) | def to_batches(self): FILE: projects/twhin/data/test_data.py function test_create_dataset (line 5) | def test_create_dataset(): FILE: projects/twhin/data/test_edges.py function test_gen (line 25) | def test_gen(): FILE: projects/twhin/metrics.py function create_metrics (line 7) | def create_metrics( FILE: projects/twhin/models/config.py class TwhinEmbeddingsConfig (line 12) | class TwhinEmbeddingsConfig(LargeEmbeddingsConfig): method embedding_dims_match (line 14) | def embedding_dims_match(cls, tables): class Operator (line 23) | class Operator(str, enum.Enum): class Relation (line 27) | class Relation(pydantic.BaseModel): class TwhinModelConfig (line 44) | class TwhinModelConfig(base_config.BaseConfig): method valid_node_types (line 50) | def valid_node_types(cls, relation, values, **kwargs): FILE: projects/twhin/models/models.py class TwhinModel (line 16) | class TwhinModel(nn.Module): method __init__ (line 17) | def __init__(self, model_config: TwhinModelConfig, data_config: TwhinD... method forward (line 33) | def forward(self, batch: EdgeBatch): function apply_optimizers (line 100) | def apply_optimizers(model: TwhinModel, model_config: TwhinModelConfig): class TwhinModelAndLoss (line 118) | class TwhinModelAndLoss(torch.nn.Module): method __init__ (line 119) | def __init__( method forward (line 138) | def forward(self, batch: "RecapBatch"): # type: ignore[name-defined] FILE: projects/twhin/models/test_models.py function twhin_model_config (line 20) | def twhin_model_config() -> TwhinModelConfig: function twhin_data_config (line 54) | def twhin_data_config() -> TwhinDataConfig: function test_twhin_model (line 67) | def test_twhin_model(): function test_unequal_dims (line 86) | def test_unequal_dims(): FILE: projects/twhin/optimizer.py function _lr_from_config (line 17) | def _lr_from_config(optimizer_config): function build_optimizer (line 26) | def build_optimizer(model: TwhinModel, config: TwhinModelConfig): FILE: projects/twhin/run.py function run (line 36) | def run( function main (line 82) | def main(argv): FILE: projects/twhin/test_optimizer.py function test_twhin_optimizer (line 15) | def test_twhin_optimizer(): FILE: reader/dataset.py class _Reader (line 27) | class _Reader(pa.flight.FlightServerBase): method __init__ (line 30) | def __init__(self, location: str, ds: "Dataset"): method do_get (line 35) | def do_get(self, _, __): class Dataset (line 48) | class Dataset(torch.utils.data.IterableDataset): method __init__ (line 51) | def __init__(self, file_pattern: str, **dataset_kwargs) -> None: method _validate_columns (line 66) | def _validate_columns(self): method serve (line 72) | def serve(self): method _create_dataset (line 76) | def _create_dataset(self): method to_batches (line 84) | def to_batches(self): method pa_to_batch (line 102) | def pa_to_batch(self, batch: pa.RecordBatch) -> DataclassBatch: method dataloader (line 105) | def dataloader(self, remote: bool = False): function get_readers (line 119) | def get_readers(num_readers_per_worker: int): FILE: reader/dds.py function maybe_start_dataset_service (line 23) | def maybe_start_dataset_service(): function register_dataset (line 59) | def register_dataset( function distribute_from_dataset_id (line 78) | def distribute_from_dataset_id( function maybe_distribute_dataset (line 99) | def maybe_distribute_dataset(dataset: tf.data.Dataset) -> tf.data.Dataset: FILE: reader/test_dataset.py function create_dataset (line 14) | def create_dataset(tmpdir): function test_dataset (line 36) | def test_dataset(tmpdir): function test_distributed_dataset (line 48) | def test_distributed_dataset(tmpdir): FILE: reader/test_utils.py function test_rr (line 4) | def test_rr(): FILE: reader/utils.py function roundrobin (line 13) | def roundrobin(*iterables): function speed_check (line 37) | def speed_check(data_loader, max_steps: int, frequency: int, peek: Optio... function pa_to_torch (line 59) | def pa_to_torch(array: pa.array) -> torch.Tensor: function create_default_pa_to_batch (line 63) | def create_default_pa_to_batch(schema) -> DataclassBatch: FILE: tools/pq.py function _create_dataset (line 40) | def _create_dataset(path: str): class PqReader (line 46) | class PqReader: method __init__ (line 47) | def __init__( method __iter__ (line 55) | def __iter__(self): method _head (line 64) | def _head(self): method bytes_per_row (line 73) | def bytes_per_row(self) -> int: method schema (line 83) | def schema(self): method head (line 86) | def head(self): method distinct (line 90) | def distinct(self):