SYMBOL INDEX (535 symbols across 36 files) FILE: clu/asynclib.py class AsyncError (line 27) | class AsyncError(Exception): class Pool (line 31) | class Pool: method __init__ (line 50) | def __init__(self, thread_name_prefix: str = "", method _reraise (line 69) | def _reraise(self) -> None: method close (line 76) | def close(self) -> None: method join (line 81) | def join(self) -> None: method queue_length (line 98) | def queue_length(self) -> int: method has_errors (line 103) | def has_errors(self) -> bool: method clear_errors (line 107) | def clear_errors(self) -> List[Exception]: method __call__ (line 113) | def __call__(self, fn: Callable): # pylint: disable=g-bare-generic FILE: clu/asynclib_test.py class AsyncWriterTest (line 23) | class AsyncWriterTest(absltest.TestCase): method test_async_execution (line 25) | def test_async_execution(self): method test_reraise (line 39) | def test_reraise(self): method test_queue_length (line 65) | def test_queue_length(self, executor_mock): method test_flush (line 95) | def test_flush(self, executor_mock): FILE: clu/checkpoint.py function safe_normpath (line 77) | def safe_normpath(path: str) -> str: function load_state_dict (line 83) | def load_state_dict(base_directory) -> Dict[str, Any]: class CheckpointInfo (line 106) | class CheckpointInfo( method initialize (line 113) | def initialize(cls, base_directory, checkpoint_name: str) -> "Checkpoi... method from_path (line 118) | def from_path(cls, checkpoint: str) -> "CheckpointInfo": method increment (line 134) | def increment(self) -> "CheckpointInfo": method __str__ (line 138) | def __str__(self): class Checkpoint (line 143) | class Checkpoint: method __init__ (line 162) | def __init__(self, method get_latest_checkpoint_to_restore_from (line 194) | def get_latest_checkpoint_to_restore_from(self): method latest_checkpoint (line 207) | def latest_checkpoint(self) -> Optional[str]: method current_checkpoint (line 220) | def current_checkpoint(self) -> Optional[str]: method _flax_path (line 240) | def _flax_path(self, checkpoint: str) -> str: method _next_checkpoint (line 243) | def _next_checkpoint(self, checkpoint: Optional[str]) -> str: method _checkpoint_number (line 249) | def _checkpoint_number(self, checkpoint: Optional[str]) -> Optional[int]: method _delete_future_checkpoints (line 254) | def _delete_future_checkpoints(self): method save (line 273) | def save(self, state) -> str: method restore_or_initialize (line 330) | def restore_or_initialize(self, state: T) -> T: method restore_dict (line 350) | def restore_dict(self, checkpoint: Optional[str] = None) -> Dict[str, ... method _checkpoint_or_latest (line 371) | def _checkpoint_or_latest(self, checkpoint: Optional[str] = None) -> str: method load_state (line 378) | def load_state(self, method restore (line 410) | def restore(self, class MultihostCheckpoint (line 448) | class MultihostCheckpoint(Checkpoint): method __init__ (line 463) | def __init__(self, method get_latest_checkpoint_to_restore_from (line 503) | def get_latest_checkpoint_to_restore_from(self) -> Optional[str]: FILE: clu/checkpoint_test.py function _make_dataset (line 26) | def _make_dataset(): class TrainState (line 34) | class TrainState: class TrainStateExtended (line 39) | class TrainStateExtended: class NotTrainState (line 44) | class NotTrainState: function _checkpoint_number (line 48) | def _checkpoint_number(path): class CheckpointTest (line 54) | class CheckpointTest(tf.test.TestCase): method test_safe_normpath (line 56) | def test_safe_normpath(self): method test_initialize_mkdir (line 63) | def test_initialize_mkdir(self): method test_restores_flax_state (line 75) | def test_restores_flax_state(self): method test_load_state_dict (line 102) | def test_load_state_dict(self): method test_fails_when_restoring_subset (line 114) | def test_fails_when_restoring_subset(self): method test_fails_when_restoring_superset (line 125) | def test_fails_when_restoring_superset(self): method test_restores_tf_state (line 136) | def test_restores_tf_state(self): method test_restore_flax_alone (line 172) | def test_restore_flax_alone(self): method test_restore_dict (line 185) | def test_restore_dict(self): method test_ignores_incomplete_checkpoint (line 213) | def test_ignores_incomplete_checkpoint(self): method test_max_to_keep (line 249) | def test_max_to_keep(self): method test_checkpoint_name (line 262) | def test_checkpoint_name(self): method test_fails_if_not_registered (line 269) | def test_fails_if_not_registered(self): method test_overwrite (line 276) | def test_overwrite(self): class MultihostCheckpoint (line 312) | class MultihostCheckpoint(tf.test.TestCase): method test_initialize_mkdir (line 315) | def test_initialize_mkdir(self, process_index_mock): method test_synchronize_multiple_hosts (line 328) | def test_synchronize_multiple_hosts(self, process_index_mock): method test_preemption (line 360) | def test_preemption(self): FILE: clu/data/dataset_iterator.py class ArraySpec (line 52) | class ArraySpec: method __repr__ (line 57) | def __repr__(self): method __str__ (line 60) | def __str__(self): class DatasetIterator (line 77) | class DatasetIterator(collections.abc.Iterator): # pytype: disable=igno... method get_next (line 92) | def get_next(self) -> Element: method reset (line 99) | def reset(self): method element_spec (line 105) | def element_spec(self) -> ElementSpec: method save (line 109) | def save(self, filename: epath.Path): method restore (line 119) | def restore(self, filename: epath.Path): method load (line 129) | def load(self, filename: epath.Path): class TfDatasetIterator (line 134) | class TfDatasetIterator(DatasetIterator): method __init__ (line 137) | def __init__(self, dataset, *, checkpoint: bool): method get_next (line 174) | def get_next(self) -> Element: method __next__ (line 177) | def __next__(self) -> Element: method reset (line 180) | def reset(self): method element_spec (line 185) | def element_spec(self) -> ElementSpec: method save (line 202) | def save(self, filename: epath.Path): method restore (line 206) | def restore(self, filename: epath.Path): class PeekableDatasetIterator (line 211) | class PeekableDatasetIterator(DatasetIterator): method __init__ (line 230) | def __init__(self, it: DatasetIterator): method __next__ (line 238) | def __next__(self) -> Element: method reset (line 246) | def reset(self): method element_spec (line 254) | def element_spec(self) -> ElementSpec: method peek (line 257) | def peek(self) -> Element: method peek_async (line 270) | def peek_async(self) -> concurrent.futures.Future[Element]: method save (line 286) | def save(self, filename: epath.Path): method restore (line 290) | def restore(self, filename: epath.Path): FILE: clu/data/dataset_iterator_test.py class DatasetIteratorTest (line 28) | class DatasetIteratorTest(parameterized.TestCase, tf.test.TestCase): method _create_iterator (line 30) | def _create_iterator(self, start_index: int, checkpoint: bool = True): method test_tf_iterator (line 40) | def test_tf_iterator(self): method test_tf_iterator_save_and_load (line 53) | def test_tf_iterator_save_and_load(self): method test_tf_iterator_save_and_load_no_checkpoint (line 70) | def test_tf_iterator_save_and_load_no_checkpoint(self): method test_peekable_dataset_iterator (line 84) | def test_peekable_dataset_iterator(self): method test_peekable_dataset_iterator_async (line 92) | def test_peekable_dataset_iterator_async(self, wait: bool, peek_first:... FILE: clu/deterministic_data.py class DatasetBuilder (line 82) | class DatasetBuilder(typing_extensions.Protocol): method as_dataset (line 85) | def as_dataset( class RemainderOptions (line 92) | class RemainderOptions(enum.Enum): function _shard_read_instruction (line 111) | def _shard_read_instruction( function get_read_instruction_for_host (line 175) | def get_read_instruction_for_host( function _preprocess_with_per_example_rng (line 272) | def _preprocess_with_per_example_rng(ds: tf.data.Dataset, function pad_dataset (line 303) | def pad_dataset(dataset: tf.data.Dataset, function create_dataset (line 362) | def create_dataset(dataset_builder: DatasetBuilder, function create_distributed_dataset (line 484) | def create_distributed_dataset( FILE: clu/deterministic_data_test.py class MyDatasetBuilder (line 35) | class MyDatasetBuilder: method as_dataset (line 39) | def as_dataset(self, split: tfds.core.ReadInstruction, shuffle_files: ... class FakeDatasetInfo (line 57) | class FakeDatasetInfo: method splits (line 62) | def splits(self): class DeterministicDataTest (line 69) | class DeterministicDataTest(tf.test.TestCase, parameterized.TestCase): method test_get_read_instruction_for_host_deprecated (line 86) | def test_get_read_instruction_for_host_deprecated(self, num_examples: ... method test_get_read_instruction_for_host (line 140) | def test_get_read_instruction_for_host(self, host_id: int, host_count:... method test_get_read_instruction_balance_remainder (line 168) | def test_get_read_instruction_balance_remainder(self, host_id: int, method test_get_read_instruction_for_host_fails (line 191) | def test_get_read_instruction_for_host_fails(self, host_id: int, method test_preprocess_with_per_example_rng (line 197) | def test_preprocess_with_per_example_rng(self): method test_create_dataset_padding (line 223) | def test_create_dataset_padding(self, pad_up_to_batches, cardinality): method test_create_dataset_padding_raises_error_cardinality (line 258) | def test_create_dataset_padding_raises_error_cardinality(self): method test_pad_dataset (line 278) | def test_pad_dataset(self): method test_pad_nested_dataset (line 292) | def test_pad_nested_dataset(self): method test_same_cardinality_on_all_hosts (line 309) | def test_same_cardinality_on_all_hosts(self, num_examples: int, method test_same_cardinality_on_all_hosts_with_pad (line 326) | def test_same_cardinality_on_all_hosts_with_pad(self, num_examples: int, FILE: clu/internal/utils.py function log_activity (line 30) | def log_activity(activity_name: str): function logged_with (line 47) | def logged_with(activity_name: str): function check_param (line 57) | def check_param(value, *, ndim=None, dtype=jnp.float32): function flatten_dict (line 77) | def flatten_dict( FILE: clu/internal/utils_test.py class TestError (line 23) | class TestError(BaseException): class HelpersTest (line 28) | class HelpersTest(absltest.TestCase): method test_log_activity (line 30) | def test_log_activity( method test_log_activity_fails (line 41) | def test_log_activity_fails( method test_logged_with (line 53) | def test_logged_with(self): method test_logged_with_fails (line 66) | def test_logged_with_fails(self): method test_check_param (line 80) | def test_check_param(self): method test_flatten_dict (line 91) | def test_flatten_dict(self): FILE: clu/metric_writers/async_writer.py function _wrap_exceptions (line 39) | def _wrap_exceptions(wrapped, instance, args, kwargs): class AsyncWriter (line 51) | class AsyncWriter(interface.MetricWriter): method __init__ (line 64) | def __init__(self, method write_summaries (line 78) | def write_summaries( method write_scalars (line 86) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): method write_images (line 90) | def write_images(self, step: int, images: Mapping[str, Array]): method write_videos (line 94) | def write_videos(self, step: int, videos: Mapping[str, Array]): method write_audios (line 98) | def write_audios( method write_texts (line 104) | def write_texts(self, step: int, texts: Mapping[str, str]): method write_histograms (line 108) | def write_histograms(self, method write_pointcloud (line 116) | def write_pointcloud( method write_hparams (line 132) | def write_hparams(self, hparams: Mapping[str, Any]): method flush (line 135) | def flush(self): method close (line 141) | def close(self): class AsyncMultiWriter (line 148) | class AsyncMultiWriter(multi_writer.MultiWriter): method __init__ (line 151) | def __init__(self, function ensure_flushes (line 159) | def ensure_flushes(*writers: interface.MetricWriter): FILE: clu/metric_writers/async_writer_test.py class AsyncWriterTest (line 27) | class AsyncWriterTest(tf.test.TestCase): method setUp (line 29) | def setUp(self): method test_write_summaries_async (line 34) | def test_write_summaries_async(self): method test_write_scalars_async (line 46) | def test_write_scalars_async(self): method test_write_images (line 61) | def test_write_images(self): method test_write_videos (line 68) | def test_write_videos(self): method test_write_pointcloud (line 75) | def test_write_pointcloud(self): method test_write_texts (line 96) | def test_write_texts(self): method test_ensure_flushes (line 101) | def test_ensure_flushes(self): method test_ensure_flushes_with_multiple_writers (line 117) | def test_ensure_flushes_with_multiple_writers(self): method test_flush_before_close (line 142) | def test_flush_before_close(self): method test_reraises_exception (line 147) | def test_reraises_exception(self): FILE: clu/metric_writers/interface.py class MetricWriter (line 33) | class MetricWriter(abc.ABC): method write_summaries (line 37) | def write_summaries( method write_scalars (line 53) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): method write_images (line 65) | def write_images(self, step: int, images: Mapping[str, Array]): method write_videos (line 83) | def write_videos(self, step: int, videos: Mapping[str, Array]): method write_audios (line 104) | def write_audios( method write_texts (line 126) | def write_texts(self, step: int, texts: Mapping[str, str]): method write_histograms (line 137) | def write_histograms(self, method write_pointcloud (line 156) | def write_pointcloud( method write_hparams (line 177) | def write_hparams(self, hparams: Mapping[str, Any]): method flush (line 187) | def flush(self): method close (line 191) | def close(self): FILE: clu/metric_writers/logging_writer.py class LoggingWriter (line 28) | class LoggingWriter(interface.MetricWriter): method __init__ (line 31) | def __init__(self, collection: Optional[str] = None): method write_summaries (line 37) | def write_summaries( method write_scalars (line 44) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): method write_images (line 51) | def write_images(self, step: int, images: Mapping[str, Array]): method write_videos (line 55) | def write_videos(self, step: int, videos: Mapping[str, Array]): method write_audios (line 59) | def write_audios( method write_texts (line 64) | def write_texts(self, step: int, texts: Mapping[str, str]): method write_histograms (line 67) | def write_histograms(self, method write_pointcloud (line 80) | def write_pointcloud( method write_hparams (line 101) | def write_hparams(self, hparams: Mapping[str, Any]): method flush (line 104) | def flush(self): method close (line 107) | def close(self): function _compute_histogram_as_tf (line 111) | def _compute_histogram_as_tf( function _get_histogram_as_string (line 152) | def _get_histogram_as_string(histo: np.ndarray, bins: np.ndarray): FILE: clu/metric_writers/logging_writer_test.py class LoggingWriterTest (line 22) | class LoggingWriterTest(tf.test.TestCase): method setUp (line 24) | def setUp(self): method test_write_scalars (line 28) | def test_write_scalars(self): method test_write_images (line 36) | def test_write_images(self): method test_write_videos (line 44) | def test_write_videos(self): method test_write_texts (line 52) | def test_write_texts(self): method test_write_histogram (line 59) | def test_write_histogram(self): method test_write_pointcloud (line 83) | def test_write_pointcloud(self): method test_write_hparams (line 106) | def test_write_hparams(self): method test_collection (line 113) | def test_collection(self): FILE: clu/metric_writers/multi_writer.py class MultiWriter (line 26) | class MultiWriter(interface.MetricWriter): method __init__ (line 29) | def __init__(self, writers: Sequence[interface.MetricWriter]): method write_summaries (line 32) | def write_summaries( method write_scalars (line 39) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): method write_images (line 43) | def write_images(self, step: int, images: Mapping[str, Array]): method write_videos (line 47) | def write_videos(self, step: int, videos: Mapping[str, Array]): method write_audios (line 51) | def write_audios( method write_texts (line 56) | def write_texts(self, step: int, texts: Mapping[str, str]): method write_histograms (line 60) | def write_histograms(self, method write_pointcloud (line 67) | def write_pointcloud( method write_hparams (line 80) | def write_hparams(self, hparams: Mapping[str, Any]): method flush (line 84) | def flush(self): method close (line 88) | def close(self): FILE: clu/metric_writers/multi_writer_test.py class MultiWriterTest (line 25) | class MultiWriterTest(tf.test.TestCase): method setUp (line 27) | def setUp(self): method test_write_scalars (line 35) | def test_write_scalars(self): method test_write_pointcloud (line 52) | def test_write_pointcloud(self): FILE: clu/metric_writers/tf/summary_writer.py class SummaryWriter (line 42) | class SummaryWriter(interface.MetricWriter): method __init__ (line 45) | def __init__(self, logdir: str): method write_summaries (line 50) | def write_summaries( method write_scalars (line 61) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): method write_images (line 66) | def write_images(self, step: int, images: Mapping[str, Array]): method write_videos (line 73) | def write_videos(self, step: int, videos: Mapping[str, Array]): method write_audios (line 78) | def write_audios( method write_texts (line 85) | def write_texts(self, step: int, texts: Mapping[str, str]): method write_histograms (line 90) | def write_histograms( method write_pointcloud (line 101) | def write_pointcloud( method write_hparams (line 121) | def write_hparams(self, hparams: Mapping[str, Any]): method flush (line 125) | def flush(self): method close (line 128) | def close(self): FILE: clu/metric_writers/tf/summary_writer_test.py function _load_summaries_data (line 27) | def _load_summaries_data(logdir): function _load_histograms_data (line 41) | def _load_histograms_data(logdir): function _load_scalars_data (line 60) | def _load_scalars_data(logdir: str): function _load_pointcloud_data (line 72) | def _load_pointcloud_data(logdir: str): function _load_hparams (line 87) | def _load_hparams(logdir: str): class SummaryWriterTest (line 101) | class SummaryWriterTest(tf.test.TestCase): method setUp (line 103) | def setUp(self): method test_write_summaries (line 108) | def test_write_summaries(self): method test_write_scalar (line 121) | def test_write_scalar(self): method test_write_histograms (line 129) | def test_write_histograms(self): method test_write_pointcloud (line 160) | def test_write_pointcloud(self): method test_hparams (line 178) | def test_hparams(self): method test_hparams_nested (line 187) | def test_hparams_nested(self): FILE: clu/metric_writers/torch_tensorboard_writer.py class TorchTensorboardWriter (line 32) | class TorchTensorboardWriter(interface.MetricWriter): method __init__ (line 35) | def __init__(self, logdir: str): method write_summaries (line 40) | def write_summaries( method write_scalars (line 48) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]): method write_images (line 52) | def write_images(self, step: int, images: Mapping[str, Array]): method write_videos (line 56) | def write_videos(self, step: int, videos: Mapping[str, Array]): method write_audios (line 61) | def write_audios( method write_texts (line 67) | def write_texts(self, step: int, texts: Mapping[str, str]): method write_histograms (line 72) | def write_histograms(self, method write_pointcloud (line 81) | def write_pointcloud( method write_hparams (line 95) | def write_hparams(self, hparams: Mapping[str, Any]): method flush (line 98) | def flush(self): method close (line 101) | def close(self): FILE: clu/metric_writers/torch_tensorboard_writer_test.py function _load_scalars_data (line 26) | def _load_scalars_data(logdir: str): function _load_histograms_data (line 38) | def _load_histograms_data(logdir: str) -> Dict[int, Dict[str, Any]]: class TorchTensorboardWriterTest (line 62) | class TorchTensorboardWriterTest(tf.test.TestCase): method setUp (line 64) | def setUp(self): method test_write_scalar (line 69) | def test_write_scalar(self): method test_write_histograms (line 77) | def test_write_histograms(self): FILE: clu/metric_writers/utils.py function _is_scalar (line 46) | def _is_scalar(value: Any) -> bool: function write_values (line 56) | def write_values( function create_default_writer (line 113) | def create_default_writer( FILE: clu/metric_writers/utils_test.py class HistogramMetric (line 39) | class HistogramMetric(clu.metrics.Metric): method compute_value (line 43) | def compute_value(self): class ImageMetric (line 48) | class ImageMetric(clu.metrics.Metric): method compute_value (line 51) | def compute_value(self): class AudioMetric (line 56) | class AudioMetric(clu.metrics.Metric): method compute_value (line 60) | def compute_value(self): class TextMetric (line 65) | class TextMetric(clu.metrics.Metric): method compute_value (line 68) | def compute_value(self): class HyperParamMetric (line 73) | class HyperParamMetric(clu.metrics.Metric): method compute_value (line 76) | def compute_value(self): class SummaryMetric (line 81) | class SummaryMetric(clu.metrics.Metric): method compute_value (line 85) | def compute_value(self): function _to_summary (line 89) | def _to_summary(metrics): function _to_list_of_dicts (line 93) | def _to_list_of_dicts(d): class ONEOF (line 97) | class ONEOF(object): method __init__ (line 100) | def __init__(self, container): method __eq__ (line 107) | def __eq__(self, o): method __ne__ (line 110) | def __ne__(self, o): method __repr__ (line 113) | def __repr__(self): class MetricWriterTest (line 117) | class MetricWriterTest(tf.test.TestCase, parameterized.TestCase): method test_write (line 119) | def test_write(self): method test_create_default_writer_summary_writer_is_added (line 198) | def test_create_default_writer_summary_writer_is_added(self): FILE: clu/metrics.py class FromFunCallable (line 76) | class FromFunCallable(Protocol): method __call__ (line 79) | def __call__(self, **kwargs: ArrayLike) -> Array | Mapping[str, Array]: function _assert_same_shape (line 88) | def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray): class Metric (line 97) | class Metric: method from_model_output (line 133) | def from_model_output(cls: type[M], *args, **kwargs) -> M: method merge (line 137) | def merge(self: M, other: M) -> M: method _reduce_merge (line 160) | def _reduce_merge(self: M, other: M) -> M: method compute (line 163) | def compute(self) -> jnp.ndarray: method empty (line 168) | def empty(cls: type[M]) -> M: method compute_value (line 172) | def compute_value(self) -> clu.values.Value: method reduce (line 176) | def reduce(self: M) -> M: method from_fun (line 235) | def from_fun(cls, fun: FromFunCallable): # No way to annotate return ... method from_output (line 314) | def from_output(cls, name: str): # No way to annotate return type class CollectingMetric (line 358) | class CollectingMetric(Metric): method empty (line 416) | def empty(cls) -> CollectingMetric: method merge (line 419) | def merge(self, other: CollectingMetric) -> CollectingMetric: method reduce (line 434) | def reduce(self) -> CollectingMetric: method compute (line 440) | def compute(self): # No return type annotation, so subclasses can ove... method from_outputs (line 444) | def from_outputs(cls, names: Sequence[str]) -> type[CollectingMetric]: class _ReductionCounter (line 465) | class _ReductionCounter(Metric): method empty (line 471) | def empty(cls) -> _ReductionCounter: method merge (line 474) | def merge(self, other: _ReductionCounter) -> _ReductionCounter: function _check_reduction_counter_ndim (line 478) | def _check_reduction_counter_ndim(reduction_counter: _ReductionCounter): class Collection (line 490) | class Collection: method create (line 512) | def create(cls, **metrics: type[Metric]) -> type[Collection]: method create_collection (line 538) | def create_collection(cls, **metrics: Metric) -> Collection: method empty (line 566) | def empty(cls: type[C]) -> C: method _from_model_output (line 576) | def _from_model_output(cls: type[C], **kwargs) -> C: method single_from_model_output (line 587) | def single_from_model_output(cls: type[C], **kwargs) -> C: method gather_from_model_output (line 602) | def gather_from_model_output(cls: type[C], axis_name="batch", **kwargs... method merge (line 617) | def merge(self: C, other: C) -> C: method reduce (line 624) | def reduce(self: C) -> C: method compute (line 656) | def compute(self) -> dict[str, jnp.ndarray]: method compute_values (line 665) | def compute_values(self) -> dict[str, clu.values.Value]: method unreplicate (line 674) | def unreplicate(self: C) -> C: class LastValue (line 692) | class LastValue(Metric): method __init__ (line 707) | def __init__( # pytype: disable=missing-parameter # jnp-array method empty (line 743) | def empty(cls) -> LastValue: method from_model_output (line 747) | def from_model_output( method merge (line 757) | def merge(self, other: LastValue) -> LastValue: method _reduce_merge (line 761) | def _reduce_merge(self, other: LastValue) -> LastValue: method value (line 770) | def value(self) -> jnp.ndarray: method compute (line 775) | def compute(self) -> Any: function _broadcast_masks (line 779) | def _broadcast_masks(values: jnp.ndarray, mask: jnp.ndarray | None): class Average (line 801) | class Average(Metric): method empty (line 820) | def empty(cls) -> Average: method from_model_output (line 824) | def from_model_output( method merge (line 837) | def merge(self, other: Average) -> Average: method compute (line 844) | def compute(self) -> Any: class Std (line 849) | class Std(Metric): method empty (line 861) | def empty(cls) -> Std: method from_model_output (line 868) | def from_model_output( method merge (line 882) | def merge(self, other: Std) -> Std: method compute (line 890) | def compute(self) -> Any: class Accuracy (line 906) | class Accuracy(Average): method from_model_output (line 916) | def from_model_output( FILE: clu/metrics_test.py class CollectingMetricAccuracy (line 32) | class CollectingMetricAccuracy( method compute (line 35) | def compute(self): class Collection (line 45) | class Collection(metrics.Collection): class CollectionMixed (line 51) | class CollectionMixed(metrics.Collection): class MetricsTest (line 56) | class MetricsTest(parameterized.TestCase): method setUp (line 58) | def setUp(self): method make_compute_metric (line 117) | def make_compute_metric(self, metric_class, reduce, jit=True): method test_metric_last_value_reduce (line 151) | def test_metric_last_value_reduce(self): method test_metric_last_value (line 177) | def test_metric_last_value(self): method test_metric_last_value_legacy_kwarg_value (line 192) | def test_metric_last_value_legacy_kwarg_value(self): method test_metric_last_value_tree_manipulation (line 198) | def test_metric_last_value_tree_manipulation(self): method test_from_fun_with_single_output (line 213) | def test_from_fun_with_single_output(self): method test_from_fun_with_mapping_output (line 229) | def test_from_fun_with_mapping_output(self): method test_average_masked (line 262) | def test_average_masked(self, values, mask, expected_result): method test_merge_asserts_shape (line 284) | def test_merge_asserts_shape(self, metric_cls): method test_accuracy (line 296) | def test_accuracy(self, reduce): method test_last_value_asserts_shape (line 301) | def test_last_value_asserts_shape(self): method test_loss_average (line 313) | def test_loss_average(self, reduce): method test_loss_std (line 328) | def test_loss_std(self, reduce): method test_collection_create (line 341) | def test_collection_create(self): method test_collection_create_custom_mask (line 350) | def test_collection_create_custom_mask(self): method test_collection_create_collection (line 374) | def test_collection_create_collection(self): method test_collection_single (line 394) | def test_collection_single(self, masked): method test_collection_gather (line 418) | def test_collection_gather(self, masked, all_gather_mock): method test_collection_gather_pmap (line 442) | def test_collection_gather_pmap(self, masked): method test_collection_asserts_replication (line 455) | def test_collection_asserts_replication(self): method test_collecting_metric (line 466) | def test_collecting_metric(self): method test_collecting_metric_reduce (line 480) | def test_collecting_metric_reduce(self): method test_collecting_metric_async (line 486) | def test_collecting_metric_async(self): method test_collecting_metric_tracer (line 505) | def test_collecting_metric_tracer(self): method test_collection_mixed_async (line 512) | def test_collection_mixed_async(self): method test_metric_empty_types_doesnt_cause_retrace (line 530) | def test_metric_empty_types_doesnt_cause_retrace(self): method test_tensor_aggregation_metrics_with_masks (line 569) | def test_tensor_aggregation_metrics_with_masks( FILE: clu/parameter_overview.py class _ParamRow (line 32) | class _ParamRow: class _ParamRowWithSharding (line 40) | class _ParamRowWithSharding(_ParamRow): class _ParamRowWithStats (line 45) | class _ParamRowWithStats(_ParamRow): class _ParamRowWithStatsAndSharding (line 51) | class _ParamRowWithStatsAndSharding(_ParamRowWithStats): function _mean_std_jit (line 56) | def _mean_std_jit(x): function _mean_std (line 60) | def _mean_std(x): function flatten_dict (line 66) | def flatten_dict( function _count_parameters (line 82) | def _count_parameters(params: _ParamsContainer) -> int: function _parameters_size (line 88) | def _parameters_size(params: _ParamsContainer) -> int: function count_parameters (line 98) | def count_parameters(params: _ParamsContainer) -> int: function _make_row (line 104) | def _make_row(name, value) -> _ParamRow: function _make_row_with_sharding (line 120) | def _make_row_with_sharding(name, value) -> _ParamRowWithSharding: function _make_row_with_stats (line 132) | def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats: function _make_row_with_stats_and_sharding (line 143) | def _make_row_with_stats_and_sharding( function _get_parameter_rows (line 154) | def _get_parameter_rows( function _default_table_value_formatter (line 209) | def _default_table_value_formatter(value): function make_table (line 221) | def make_table( function _get_parameter_overview (line 286) | def _get_parameter_overview( function get_parameter_overview (line 310) | def get_parameter_overview( function _log_parameter_overview (line 347) | def _log_parameter_overview( function log_parameter_overview (line 368) | def log_parameter_overview( FILE: clu/parameter_overview_test.py class CNN (line 72) | class CNN(nn.Module): method __call__ (line 75) | def __call__(self, x): class JaxParameterOverviewTest (line 79) | class JaxParameterOverviewTest(absltest.TestCase): method test_count_parameters_empty (line 81) | def test_count_parameters_empty(self): method test_count_parameters (line 84) | def test_count_parameters(self): method test_get_parameter_overview_empty (line 92) | def test_get_parameter_overview_empty(self): method test_get_parameter_overview (line 98) | def test_get_parameter_overview(self): method test_get_parameter_overview_shape_dtype_struct (line 126) | def test_get_parameter_overview_shape_dtype_struct(self): method test_printing_bool (line 134) | def test_printing_bool(self): FILE: clu/periodic_actions.py function _squareit (line 44) | def _squareit(x): function _format_secs (line 49) | def _format_secs(secs: float): class PeriodicAction (line 65) | class PeriodicAction(abc.ABC): method __init__ (line 74) | def __init__(self, method _init_and_check (line 99) | def _init_and_check(self, step: int, t: float): method _should_trigger (line 112) | def _should_trigger(self, step: int, t: float) -> bool: method _after_apply (line 123) | def _after_apply(self, step: int, t: float): method __call__ (line 128) | def __call__(self, step: int, t: Optional[float] = None) -> bool: method _apply (line 150) | def _apply(self, step: int, t: float): class ReportProgress (line 154) | class ReportProgress(PeriodicAction): method __init__ (line 157) | def __init__(self, method set_persistent_notes (line 197) | def set_persistent_notes(self, message: str): method _should_trigger (line 201) | def _should_trigger(self, step: int, t: float) -> bool: method _apply (line 205) | def _apply(self, step: int, t: float): method timed (line 227) | def timed(self, name: str, wait_jax_async_dispatch: bool = True): class Profile (line 304) | class Profile(PeriodicAction): method __init__ (line 309) | def __init__( method _should_trigger (line 350) | def _should_trigger(self, step: int, t: float) -> bool: method _apply (line 364) | def _apply(self, step: int, t: float): method _start_session (line 368) | def _start_session(self): method _end_session (line 376) | def _end_session(self, url: Optional[str]): class ProfileAllHosts (line 385) | class ProfileAllHosts(PeriodicAction): method __init__ (line 390) | def __init__(self, method _should_trigger (line 419) | def _should_trigger(self, step: int, t: float) -> bool: method _apply (line 422) | def _apply(self, step: int, t: float): method _start_session (line 426) | def _start_session(self): method _end_session (line 435) | def _end_session(self, url: Optional[str], *, step: int): class PeriodicCallback (line 443) | class PeriodicCallback(PeriodicAction): method __init__ (line 446) | def __init__(self, method __call__ (line 476) | def __call__(self, step: int, t: Optional[float] = None, **kwargs) -> ... method get_last_callback_result (line 488) | def get_last_callback_result(self): method _apply (line 492) | def _apply(self, step, t, **kwargs): FILE: clu/periodic_actions_test.py class ReportProgressTest (line 26) | class ReportProgressTest(parameterized.TestCase): method test_every_steps (line 28) | def test_every_steps(self): method test_every_secs (line 50) | def test_every_secs(self): method test_without_num_train_steps (line 72) | def test_without_num_train_steps(self): method test_with_persistent_notes (line 83) | def test_with_persistent_notes(self): method test_unknown_cardinality (line 96) | def test_unknown_cardinality(self): method test_called_every_step (line 107) | def test_called_every_step(self): method test_named (line 121) | def test_named(self, wait_jax_async_dispatch, mock_time): method test_write_metrics (line 156) | def test_write_metrics(self, time_mock): class DummyProfilerSession (line 175) | class DummyProfilerSession: method __init__ (line 178) | def __init__(self): method start_session (line 183) | def start_session(self): method end_session_and_get_url (line 186) | def end_session_and_get_url(self, tag): class ProfileTest (line 191) | class ProfileTest(absltest.TestCase): method test_every_steps (line 195) | def test_every_steps(self, mock_time, mock_profiler): class ProfileAllHostsTest (line 224) | class ProfileAllHostsTest(absltest.TestCase): method test_every_steps (line 227) | def test_every_steps(self, mock_profiler): class PeriodicCallbackTest (line 247) | class PeriodicCallbackTest(absltest.TestCase): method test_every_steps (line 249) | def test_every_steps(self): method test_every_secs (line 267) | def test_every_secs(self, mock_time): method test_on_steps (line 281) | def test_on_steps(self): method test_async_execution (line 290) | def test_async_execution(self): method test_error_async_is_forwarded (line 309) | def test_error_async_is_forwarded(self): method test_function_without_step_and_time (line 325) | def test_function_without_step_and_time(self): FILE: clu/platform/__init__.py function work_unit (line 35) | def work_unit() -> WorkUnit: FILE: clu/platform/interface.py class ArtifactType (line 22) | class ArtifactType(enum.Enum): class WorkUnit (line 31) | class WorkUnit(abc.ABC): method experiment_id (line 42) | def experiment_id(self): method id (line 47) | def id(self): method name (line 51) | def name(self): method set_notes (line 63) | def set_notes(self, msg: str): method set_task_status (line 67) | def set_task_status(self, msg: str): method create_artifact (line 71) | def create_artifact(self, artifact_type: ArtifactType, artifact: Any, FILE: clu/platform/local.py class LocalWorkUnit (line 26) | class LocalWorkUnit(WorkUnit): method experiment_id (line 30) | def experiment_id(self): method id (line 35) | def id(self): method set_notes (line 39) | def set_notes(self, msg: str): method set_task_status (line 43) | def set_task_status(self, msg: str): method create_artifact (line 47) | def create_artifact(self, artifact_type: ArtifactType, artifact: Any, FILE: clu/preprocess_spec.py class PreprocessOp (line 77) | class PreprocessOp(Protocol): method __call__ (line 90) | def __call__(self, features: Features) -> Features: class MapTransform (line 95) | class MapTransform(abc.ABC): method __new__ (line 109) | def __new__(cls, *args, **kwargs): method __call__ (line 121) | def __call__(self, features: D) -> D: method _transform (line 130) | def _transform(self, features: FlatFeatures) -> FlatFeatures: class RandomMapTransform (line 135) | class RandomMapTransform(MapTransform, abc.ABC): method __call__ (line 149) | def __call__(self, features: D) -> D: method _transform (line 162) | def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatF... class FilterTransform (line 167) | class FilterTransform(abc.ABC): method __call__ (line 169) | def __call__(self, dataset: tf.data.Dataset) -> tf.data.Dataset: method _predicate (line 175) | def _predicate(self, features: FlatFeatures) -> tf.Tensor: function get_all_ops (line 179) | def get_all_ops(module_name: str) -> List[Tuple[str, Type[PreprocessOp]]]: function _jax_supported_tf_types (line 204) | def _jax_supported_tf_types(): class OnlyJaxTypes (line 214) | class OnlyJaxTypes: method __call__ (line 228) | def __call__(self, features: Features) -> Features: class PreprocessFn (line 252) | class PreprocessFn: method __call__ (line 265) | def __call__(self, features: Features) -> Features: method __add__ (line 280) | def __add__(self, other: "PreprocessFn") -> "PreprocessFn": method __getitem__ (line 289) | def __getitem__(self, op_index: Union[int, slice]) -> "PreprocessFn": function _get_op_class (line 298) | def _get_op_class( function _parse_single_preprocess_op (line 315) | def _parse_single_preprocess_op( function parse (line 359) | def parse(spec: str, function _describe_features (line 374) | def _describe_features(features: Features) -> str: FILE: clu/preprocess_spec_test.py class ToFloat (line 27) | class ToFloat: method __call__ (line 29) | def __call__(self, features: Features) -> Features: class Rescale (line 34) | class Rescale: method __call__ (line 38) | def __call__(self, features: Features) -> Features: class AddRandomInteger (line 45) | class AddRandomInteger(preprocess_spec.RandomMapTransform): method _transform (line 47) | def _transform(self, features, seed): class PreprocessSpecTest (line 55) | class PreprocessSpecTest(parameterized.TestCase, tf.test.TestCase): method test_no_arguments (line 58) | def test_no_arguments(self): method test_positional_argument (line 63) | def test_positional_argument(self): method test_keyword_argument (line 69) | def test_keyword_argument(self): method test_invalid_op_name (line 75) | def test_invalid_op_name(self): method test_invalid_spec (line 82) | def test_invalid_spec(self): method test_pos_and_kw_arg (line 87) | def test_pos_and_kw_arg(self): method test_parsing_empty_string (line 95) | def test_parsing_empty_string(self): method test_multi_op_spec (line 100) | def test_multi_op_spec(self): method test_two_tensors (line 105) | def test_two_tensors(self): method test_only_jax_types (line 114) | def test_only_jax_types(self): method test_only_jax_types_nested_inputs (line 128) | def test_only_jax_types_nested_inputs(self): method test_not_only_jax_types (line 139) | def test_not_only_jax_types(self): method test_add_preprocess_fn (line 145) | def test_add_preprocess_fn(self): method test_slice_preprocess_fn (line 156) | def test_slice_preprocess_fn(self): method test_random_map_transform (line 166) | def test_random_map_transform(self): FILE: clu/profiler.py function start (line 29) | def start(logdir: str, options=None): function stop (line 41) | def stop() -> Optional[str]: function collect (line 49) | def collect(logdir: str, FILE: clu/values.py class Value (line 31) | class Value(Protocol): class Summary (line 41) | class Summary(Value): class Scalar (line 47) | class Scalar(Value): class Image (line 52) | class Image(Value): class Audio (line 65) | class Audio(Value): class Text (line 81) | class Text(Value): class Histogram (line 86) | class Histogram(Value): class HyperParam (line 93) | class HyperParam(Value):