Repository: google/CommonLoopUtils
Branch: main
Commit: 5150136b245e
Files: 52
Total size: 1.2 MB
Directory structure:
gitextract_g9bj2g2j/
├── .github/
│ └── workflows/
│ ├── build.yml
│ └── python-publish.yml
├── AUTHORS
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── clu/
│ ├── __init__.py
│ ├── asynclib.py
│ ├── asynclib_test.py
│ ├── checkpoint.py
│ ├── checkpoint_test.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── dataset_iterator.py
│ │ └── dataset_iterator_test.py
│ ├── deterministic_data.py
│ ├── deterministic_data_test.py
│ ├── internal/
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ └── utils_test.py
│ ├── metric_writers/
│ │ ├── __init__.py
│ │ ├── async_writer.py
│ │ ├── async_writer_test.py
│ │ ├── interface.py
│ │ ├── logging_writer.py
│ │ ├── logging_writer_test.py
│ │ ├── multi_writer.py
│ │ ├── multi_writer_test.py
│ │ ├── summary_writer.py
│ │ ├── tf/
│ │ │ ├── __init__.py
│ │ │ ├── summary_writer.py
│ │ │ └── summary_writer_test.py
│ │ ├── torch_tensorboard_writer.py
│ │ ├── torch_tensorboard_writer_test.py
│ │ ├── utils.py
│ │ └── utils_test.py
│ ├── metrics.py
│ ├── metrics_test.py
│ ├── parameter_overview.py
│ ├── parameter_overview_test.py
│ ├── periodic_actions.py
│ ├── periodic_actions_test.py
│ ├── platform/
│ │ ├── __init__.py
│ │ ├── interface.py
│ │ └── local.py
│ ├── preprocess_spec.py
│ ├── preprocess_spec_test.py
│ ├── profiler.py
│ ├── run_pytest.google.sh
│ └── values.py
├── clu_synopsis.ipynb
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/build.yml
================================================
# This workflow will install Python dependencies, run tests and lint.
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Build
on:
push:
branches:
- main
- 'test_*'
pull_request:
branches:
- main
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.10', '3.11']
steps:
- name: Cancel previous
uses: styfle/cancel-workflow-action@0.8.0
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install .
pip install .[test]
- name: Test with pytest and generate coverage report
run: |
pytest .
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
file: ./coverage.xml
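As written, the test step runs plain `pytest .`, which by itself does not write the `./coverage.xml` that the Codecov step uploads (unless coverage flags are configured elsewhere, e.g. in a `pyproject.toml` or `pytest.ini` not shown here). A hedged sketch of what the run line would look like with the `pytest-cov` plugin, assuming that plugin is available via the `test` extras:

```yaml
- name: Test with pytest and generate coverage report
  run: |
    pytest --cov=clu --cov-report=xml .
```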
================================================
FILE: .github/workflows/python-publish.yml
================================================
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Upload Python Package
on:
release:
types: [created]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
run: |
python setup.py sdist bdist_wheel
twine upload dist/*
================================================
FILE: AUTHORS
================================================
# This is the list of Common Loop Utils significant contributors.
#
# This does not necessarily list everyone who has contributed code,
# especially since many employees of one corporation may be contributing.
# To see the full list of contributors, see the revision history in
# source control.
Google LLC
================================================
FILE: CHANGELOG.md
================================================
# Changelog
## v0.0.1-alpha.1
Initial PyPI release
Current list of modules:
- `clu.checkpoint`
- `clu.deterministic_training`
- `clu.metric_writers`
- `clu.periodic_actions`
- `clu.platform`
- `clu.profiler`
## v0.0.1-alpha.2
- Adds `metrics` module and some minor changes.
## v0.0.1a3
- Added `metric_writers.TorchTensorboardWriter`
## v0.0.2
- Added preprocess_spec.
- Improvements to periodic_actions.
## v0.0.3
- `metric_writers`: Lets `SummaryWriter` write nested dictionaries.
- `internal`: Adds `async.Pool`.
- `preprocess_spec`: Support nested dictionaries.
- `profile`: Use JAX profiler APIs instead of TF profiler APIs.
## v0.0.4
`deterministic_data`
- Support non-positive input value for pad_up_to_batches.
- Support padding dataset when data dimension is unknown.
- Support TFDS specs in get_read_instruction_for_host.
- Allow changing drop_remainder for batching.
- Add RemainderOptions in deterministic_data.
`metric_writers`
- Support multiple writers in metric_writers.ensure_flushes.
`metrics`
- Makes internal.flatten_dict() work with ConfigDicts.
- Forwards mask model output to metrics created via `Metric.from_output()`.
- Forwards mask model output to metrics created via `Metric.from_fun()`.
- Added `Collections.unreplicate()`, `Collections.create()`.
`periodic_actions`
- Formats long time strings in '{days}d{hours}h{mins}m' format.
`preprocess_spec`
- Make feature description of features in PreprocessFn more compact.
- Better type check in `preprocess_spec.get_all_ops()`.
Documentation:
- Added `clu_synopsis.ipynb` Colab
## v0.0.5
- Log error instead of failing when `profiler.start()` raises an exception.
- Makes `periodic_actions.ProgressUpdate` show total number of steps.
- Makes `AsyncWriter` non-blocking wrt JAX async computations.
- Adds `clu_synopsis.ipynb` Colab as initial documentation.
- Restore Checkpoint without providing the state
- Makes `PreprocessFn` addable.
- Allow n-dimensional arrays (and masks) to be passed to Metrics.Average().
- Support slicing `PreprocessFn`.
## v0.0.6
- Makes `deterministic_data` work with `tfds>4.4.0` and `tfds<=4.4.0`.
This will be the last release supporting Python 3.6.
## v0.0.7
- Moves `clu.internal.asynclib` to `clu.asynclib`.
- Adds methods for writing raw tensors and audio to `MetricWriter`.
- Adds `clu.values` to annotate arrays with a modality.
- Adds `clu.data.DatasetIterator` - a generic interface between input
pipelines and training loops.
- Fixes various issues with `clu.metrics`.
This will be the last release supporting Python 3.7.
## v0.0.9
- Fix pytype failures related to teaching pytype about NumPy scalar types.
- Fix a couple of docstring typos.
- Updates README and clu_synopsis.ipynb
Last release before dropping support for Python 3.8 and 3.9
## v0.0.10
- `clu.parameter_overview` now supports JAX global arrays.
- Various small fixes in `clu.metrics` module.
- Removed some tensorflow dependencies.
## v0.0.11
- Removes numpy version pin
- Adds sharding annotations, dtype, total bytes to `parameter_overview`
- Makes `clu.metrics.Std` support same shapes as `clu.metrics.Average`
## v0.0.12
- Switch from `jax.tree_map` (deprecated since JAX 0.4.26) to
`jax.tree_util.tree_map`.
- Improvements to parameter overview.
================================================
FILE: CONTRIBUTING.md
================================================
# How to Contribute
At this time we are focused on supporting research done by Google Research and
are not accepting patches.
You are, however, free to start a fork of the project for your own purposes as
permitted by the license.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement (CLA). You (or your employer) retain the copyright to your
contribution; this simply gives us permission to use and redistribute your
contributions as part of the project. Head over to
<https://cla.developers.google.com/> to see your current agreements on file or
to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# CLU - Common Loop Utils
This repository contains common functionality for writing ML training loops. The
goal is to make training loops short and readable (by moving common tasks into
small libraries) without removing the flexibility required for research.
To get started, check out this Colab:
https://colab.research.google.com/github/google/CommonLoopUtils/blob/main/clu_synopsis.ipynb
If you're looking for usage examples, see:
https://github.com/google/flax/tree/main/examples
You can also find answers to common questions about CLU on Flax Github
discussions page:
https://github.com/google/flax/discussions
Note: At this point we are not accepting contributions. Please fork the
repository if you want to extend the libraries for your use case.
================================================
FILE: clu/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: clu/asynclib.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for async function calls."""
import collections
import concurrent.futures
import functools
import sys
import threading
from typing import Callable, List, Optional
from absl import logging
class AsyncError(Exception):
"""An exception that wraps another exception that ocurred asynchronously."""
class Pool:
"""Pool for wrapping functions to be executed asynchronously.
Synopsis:
from clu.internal import asynclib
pool = asynclib.Pool()
@pool
def fn():
time.sleep(1)
future = fn()
print(future.result())
fn() # This could re-raise an exception from the first execution.
print(pool.queue_length) # Would print "1" because there is one function in flight.
pool.flush() # This could re-raise an exception from the second execution.
"""
def __init__(self, thread_name_prefix: str = "",
max_workers: Optional[int] = None):
"""Creates a new pool that decorates functions for async execution.
Args:
thread_name_prefix: See documentation of `ThreadPoolExecutor`.
max_workers: See documentation of `ThreadPoolExecutor`. The default `None`
optimizes for parallelizability using the number of CPU cores. If you
specify `max_workers=1`, the async calls are executed in the same
order they were scheduled.
"""
self._pool = concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers, thread_name_prefix=thread_name_prefix)
self._max_workers = max_workers
self._thread_name_prefix = thread_name_prefix
self._errors = collections.deque()
self._errors_mutex = threading.Lock()
self._queue_length = 0
def _reraise(self) -> None:
if self._errors:
with self._errors_mutex:
exc_info = self._errors.popleft()
exc = exc_info[1].with_traceback(exc_info[2])
raise AsyncError(f"Error '{exc}' occurred ASYNCHRONOUSLY.") from exc
def close(self) -> None:
"""Closes this pool & raise a pending exception (if needed)."""
self._pool.shutdown(wait=True)
self._reraise()
def join(self) -> None:
"""Blocks until all functions are processed.
The pool can be used to schedule more functions after calling this function,
but there might be more exceptions
Side-effect:
If any of the functions raised an exception, then the first of these
exceptions is reraised.
"""
self._pool.shutdown(wait=True)
self._pool = concurrent.futures.ThreadPoolExecutor(
max_workers=self._max_workers,
thread_name_prefix=self._thread_name_prefix)
self._reraise()
@property
def queue_length(self) -> int:
"""Returns the number of functions that have not returned yet."""
return self._queue_length
@property
def has_errors(self) -> bool:
"""Returns True if there are any pending errors."""
return bool(self._errors)
def clear_errors(self) -> List[Exception]:
"""Clears all pending errors and returns them as a (possibly empty) list."""
with self._errors_mutex:
errors, self._errors = self._errors, collections.deque()
return list(errors)
def __call__(self, fn: Callable): # pylint: disable=g-bare-generic
"""Returns an async version of fn.
The function will be executed by this class's ThreadPoolExecutor. Any errors
will be stored and re-raised the next time any function executed through
this pool is called.
Note that even if there was a previous error, the function is still
scheduled upon re-execution of the wrapper returned by this function.
Args:
fn: Function to be wrapped.
Returns:
An async version of `fn`. The return value of that async version will be
a future (unless an exception was re-raised).
"""
def inner(*args, **kwargs):
def trap_errors(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Exception as e:
with self._errors_mutex:
self._errors.append(sys.exc_info())
logging.exception("Error in producer thread for %s",
self._thread_name_prefix)
raise e
finally:
self._queue_length -= 1
self._queue_length += 1
if not self.has_errors:
return self._pool.submit(trap_errors, *args, **kwargs)
self._pool.submit(trap_errors, *args, **kwargs)
self._reraise()
if isinstance(fn.__name__, str):
# Regular function.
return functools.wraps(fn)(inner)
# Mock or another weird function that fails with functools.wraps().
return inner
================================================
FILE: clu/asynclib_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for clu.asynclib."""
from unittest import mock
from absl.testing import absltest
from clu import asynclib
class AsyncWriterTest(absltest.TestCase):
def test_async_execution(self):
pool = asynclib.Pool()
counter = 0
@pool
def fn(counter_increment, return_value):
nonlocal counter
counter += counter_increment
return return_value
future = fn(1, return_value=2)
self.assertEqual(counter, 1)
self.assertEqual(future.result(), 2)
def test_reraise(self):
pool = asynclib.Pool()
@pool
def error():
raise ValueError("test")
error()
self.assertTrue(pool.has_errors)
with self.assertRaisesRegex(asynclib.AsyncError, "test"):
pool.join()
self.assertFalse(pool.has_errors)
@pool
def noop():
...
error()
self.assertTrue(pool.has_errors)
with self.assertRaisesRegex(asynclib.AsyncError, "test"):
noop()
self.assertFalse(pool.has_errors)
pool.join()
@mock.patch("concurrent.futures.ThreadPoolExecutor")
def test_queue_length(self, executor_mock):
pool_mock = mock.Mock()
in_flight = []
def execute_one():
in_flight.pop(0)()
def submit(fn, *args, **kwargs):
in_flight.append(lambda: fn(*args, **kwargs))
pool_mock.submit = submit
executor_mock.return_value = pool_mock
pool = asynclib.Pool()
@pool
def noop():
...
self.assertEqual(pool.queue_length, 0)
noop()
self.assertEqual(pool.queue_length, 1)
noop()
self.assertEqual(pool.queue_length, 2)
execute_one()
self.assertEqual(pool.queue_length, 1)
execute_one()
self.assertEqual(pool.queue_length, 0)
@mock.patch("concurrent.futures.ThreadPoolExecutor")
def test_flush(self, executor_mock):
pool_mock = mock.Mock()
pool_mock._in_flight = None
def execute_one():
pool_mock._in_flight.pop(0)()
def submit(fn, *args, **kwargs):
pool_mock._in_flight.append(lambda: fn(*args, **kwargs))
def create_pool(max_workers, thread_name_prefix):
del max_workers
del thread_name_prefix
pool_mock._in_flight = []
return pool_mock
def shutdown(wait=False):
if wait:
while pool_mock._in_flight:
execute_one()
pool_mock._in_flight = None
pool_mock.submit = submit
executor_mock.side_effect = create_pool
pool_mock.shutdown.side_effect = shutdown
pool = asynclib.Pool()
@pool
def noop():
...
self.assertEqual(pool.queue_length, 0)
noop()
self.assertEqual(pool.queue_length, 1)
noop()
pool.join()
self.assertEqual(pool.queue_length, 0)
noop()
self.assertEqual(pool.queue_length, 1)
if __name__ == "__main__":
absltest.main()
================================================
FILE: clu/checkpoint.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple checkpointing library for TF2/Flax.
The class `Checkpoint` is a simple wrapper around `tf.train.Checkpoint` that
also stores a `flax.struct.dataclass` instance in the same directory.
Synopsis:
from clu import checkpoint
import flax
@flax.struct.dataclass
class TrainState:
optimizer: flax.optim.Optimizer
step: int
ds = load_tf_dataset()
ds_iter = iter(ds)
ckpt = checkpoint.MultihostCheckpoint(base_directory, dict(ds_iter=ds_iter))
optimizer = create_flax_optimizer()
state = TrainState(optimizer=optimizer, step=0)
state = ckpt.restore_or_initialize(state) # Also restores `ds_iter`.
initial_step = int(state.step) + 1
# Need to replicate all data when training with multiple accelerators.
state = flax.jax_utils.replicate(state)
for step in range(initial_step, steps + 1):
state = update_step(state, next(ds_iter))
ckpt.save(flax.jax_utils.unreplicate(state))
Loading the model e.g. in a Colab:
from clu import checkpoint
import flax
from . import mnist_lib
state_dict = checkpoint.load_state_dict(base_directory)
params = state_dict['optimizer']['target']['params']
module = mnist_lib.MyArchitecture.partial(num_classes=10)
model = flax.deprecated.nn.Model(module, params)
"""
import collections
import os
import re
from typing import Any, Dict, Optional, TypeVar
from absl import logging
from clu.internal import utils
import flax
import jax
import tensorflow as tf
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
T = TypeVar("T")
SCHEME_RE = re.compile("^(?P<scheme>[a-z][a-z0-9.+-]+://)?(?P<path>.*)", re.I)
def safe_normpath(path: str) -> str:
"""Normalizes path safely to get around `gfile.glob()` limitations."""
d = SCHEME_RE.match(path).groupdict() # pytype: disable=attribute-error # re-none
return (d["scheme"] or "") + os.path.normpath(d["path"])
def load_state_dict(base_directory) -> Dict[str, Any]:
"""Restores `state` as dictionary from the latest checkpoint.
Synopsis:
data = checkpoint.load_state_dict(base_directory)
params = data['optimizer']['target']['params']
module = mnist_lib.MyArchitecture.partial(num_classes=10)
model = flax.deprecated.nn.Model(module, params)
Args:
base_directory: Directory from which the checkpoints should be restored. See
`Checkpoint.__init__()`.
Returns:
The deserialized Flax data, as a dictionary.
Raises:
FileNotFoundError: If there is no checkpoint to restore.
"""
return Checkpoint(base_directory).load_state(state=None)
class CheckpointInfo(
collections.namedtuple("CheckpointInfo", ("prefix", "number"))):
"""Helper class to parse a TensorFlow checkpoint path."""
CHECKPOINT_REGEX = r"^(?P<prefix>.*)-(?P<number>\d+)"
@classmethod
def initialize(cls, base_directory, checkpoint_name: str) -> "CheckpointInfo":
"""Creates a first CheckpointInfo (number=1)."""
return cls(f"{base_directory}/{checkpoint_name}", 1)
@classmethod
def from_path(cls, checkpoint: str) -> "CheckpointInfo":
"""Parses a checkpoint.
Args:
checkpoint: A checkpoint prefix, as can be found in the
`.latest_checkpoint` property of a `tf.train.CheckpointManager`.
Returns:
An instance of `CheckpointInfo` that represents `checkpoint`.
"""
m = re.match(cls.CHECKPOINT_REGEX, checkpoint)
if m is None:
raise RuntimeError(f"Invalid checkpoint format: {checkpoint}")
d = m.groupdict() # pytype: disable=attribute-error
return cls(d["prefix"], int(d["number"]))
def increment(self) -> "CheckpointInfo":
"""Returns a new CheckpointInfo with `number` increased by one."""
return CheckpointInfo(self.prefix, self.number + 1)
def __str__(self):
"""Does the opposite of `.from_path()`."""
return f"{self.prefix}-{self.number}"
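The `CheckpointInfo` round trip can be demonstrated with a self-contained restatement of the class (the checkpoint path used here is hypothetical): parse a TF-style checkpoint prefix, bump its number, and format it back.

```python
import collections
import re

# Self-contained restatement of CheckpointInfo above, for illustration only.
class CheckpointInfo(
    collections.namedtuple("CheckpointInfo", ("prefix", "number"))):
  CHECKPOINT_REGEX = r"^(?P<prefix>.*)-(?P<number>\d+)"

  @classmethod
  def from_path(cls, checkpoint: str) -> "CheckpointInfo":
    m = re.match(cls.CHECKPOINT_REGEX, checkpoint)
    if m is None:
      raise RuntimeError(f"Invalid checkpoint format: {checkpoint}")
    d = m.groupdict()
    return cls(d["prefix"], int(d["number"]))

  def increment(self) -> "CheckpointInfo":
    return CheckpointInfo(self.prefix, self.number + 1)

  def __str__(self):
    return f"{self.prefix}-{self.number}"

info = CheckpointInfo.from_path("/tmp/workdir/ckpt-7")
assert (info.prefix, info.number) == ("/tmp/workdir/ckpt", 7)
assert str(info.increment()) == "/tmp/workdir/ckpt-8"  # __str__ inverts from_path
```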
class Checkpoint:
"""A utility class for storing and loading TF2/Flax checkpoints.
Both the state of a `tf.data.Dataset` iterator and a `flax.struct.dataclass`
are stored on disk in the following files:
- {directory}/checkpoint
- {directory}/ckpt-{number}.index
- {directory}/ckpt-{number}.data@*
- {directory}/ckpt-{number}.flax
Where {number} starts at 1 and is then incremented by 1 for every new checkpoint.
The last file is the `flax.struct.dataclass`, serialized in Messagepack
format. The other files are explained in more detail in the Tensorflow
documentation:
https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint
"""
def __init__(self,
base_directory: str,
tf_state: Optional[Dict[str, Any]] = None,
*,
max_to_keep: int = 5,
checkpoint_name: str = "ckpt"):
"""Initializes a Checkpoint with a dictionary of TensorFlow Trackables.
Args:
base_directory: Directory under which the checkpoints will be stored. Use
a different base_directory in every task.
tf_state: A dictionary of TensorFlow `Trackable` to be serialized, for
example a dataset iterator.
max_to_keep: Number of checkpoints to keep in the directory. If there are
more checkpoints than specified by this number, then the oldest
checkpoints are removed.
checkpoint_name: Prefix of the checkpoint files (before `-{number}`).
"""
if tf_state is None:
tf_state = dict()
base_directory = safe_normpath(base_directory)
self.base_directory = base_directory
self.max_to_keep = max_to_keep
self.checkpoint_name = checkpoint_name
self.tf_checkpoint = tf.train.Checkpoint(**tf_state)
self.tf_checkpoint_manager = tf.train.CheckpointManager(
self.tf_checkpoint,
base_directory,
max_to_keep=max_to_keep,
checkpoint_name=checkpoint_name)
self.restored_from = None
def get_latest_checkpoint_to_restore_from(self):
"""Returns the latest checkpoint to restore from.
In the current implementation, this method simply returns the attribute
`latest_checkpoint`.
Subclasses can override this method to provide an alternative checkpoint to
restore from, for example for synchronization across multiple checkpoint
directories.
"""
return self.latest_checkpoint
@property
def latest_checkpoint(self) -> Optional[str]:
"""Latest checkpoint, see `tf.train.CheckpointManager.latest_checkpoint`.
Returns:
A string to the latest checkpoint. Note that this string is path-like but
it does not really describe a file, but rather a set of files that are
constructed from this string, by appending different file extensions. The
returned value is `None` if there is no previously stored checkpoint in
`base_directory` specified to `__init__()`.
"""
return self.tf_checkpoint_manager.latest_checkpoint
@property
def current_checkpoint(self) -> Optional[str]:
"""Returns current checkpoint.
Note that after instance creation this will point to "ckpt-0", which does
not actually exist. After the first save (either via `.save()` or via
`.restore_or_initialize()`) it will point to "ckpt-1". When the checkpoint
is loaded from a specific checkpoint (via `.restore(state, checkpoint)`)
then this property can be different from `.latest_checkpoint`.
Returns:
A string referring to the current checkpoint. See `.latest_checkpoint` for
a description of the format.
"""
latest_checkpoint = self.latest_checkpoint
if latest_checkpoint is None:
return None
checkpoint_info = CheckpointInfo.from_path(latest_checkpoint)
number = self.tf_checkpoint.save_counter.numpy()
return str(checkpoint_info._replace(number=number))
def _flax_path(self, checkpoint: str) -> str:
return "{}.flax".format(checkpoint)
def _next_checkpoint(self, checkpoint: Optional[str]) -> str:
if checkpoint is None:
return str(
CheckpointInfo.initialize(self.base_directory, self.checkpoint_name))
return str(CheckpointInfo.from_path(checkpoint).increment())
def _checkpoint_number(self, checkpoint: Optional[str]) -> Optional[int]:
if checkpoint is None:
return None
return CheckpointInfo.from_path(checkpoint).number
def _delete_future_checkpoints(self):
"""Deletes checkpoints that are newer than the currently loaded checkpoint.
This happens when the checkpoint was initialized from a checkpoint that was
not the latest checkpoint (e.g. when recovering from a pre-emption in a
`MultihostCheckpoint` where some workers finished writing their checkpoints
and others didn't).
"""
checkpoint = self.current_checkpoint
while True:
checkpoint = self._next_checkpoint(checkpoint)
paths = tf.io.gfile.glob(f"{checkpoint}.*")
if not paths:
break
for path in paths:
logging.info("Cleaning up future checkpoint file '%s'", path)
tf.io.gfile.remove(path)
@utils.logged_with("Checkpoint.save()")
def save(self, state) -> str:
"""Saves a new checkpoint in the directory.
Note that if the checkpoint was restored from an earlier checkpoint than the
latest available, then saving the checkpoint will overwrite and/or delete any
checkpoints later than the restored one.
For example, if there are checkpoints `(1, 2, 3)` and then checkpoint `1`
is restored, then calling `.save()` on that restored checkpoint will result
in `2` being overwritten and `3` being deleted.
This overwriting/deleting behavior allows for seamless integration with
`MultihostCheckpoint` after pre-emption (i.e. one of the workers might have
stored one more checkpoint, but that checkpoint is only available on that
one worker and must be overwritten when the training continues).
After such an overwrite, the attributes `.current_checkpoint` and
`.latest_checkpoint` will point to newly written checkpoint (in above case
`2`), but the list `.tf_checkpoint_manager.checkpoints` might be out of sync
and should not be used.
Args:
state: Flax checkpoint to be stored.
Returns:
The checkpoint identifier ({base_directory}/ckpt-{number}).
"""
self._delete_future_checkpoints()
next_checkpoint = self._next_checkpoint(self.current_checkpoint)
flax_path = self._flax_path(next_checkpoint)
logging.info("Storing next checkpoint '%s'", next_checkpoint)
if not tf.io.gfile.exists(self.base_directory):
tf.io.gfile.makedirs(self.base_directory)
with tf.io.gfile.GFile(flax_path, "wb") as f:
f.write(flax.serialization.to_bytes(state))
checkpoints_before_save = set(self.tf_checkpoint_manager.checkpoints)
# Write Tensorflow data last. This way Tensorflow checkpoint generation
# logic will make sure to only commit checkpoints if they complete
# successfully. A previously written `flax_path` would then simply be
# overwritten next time.
self.tf_checkpoint_manager.save()
# Clean up stale Flax checkpoints. TensorFlow automatically removes
# checkpoints older than `max_to_keep`, so we do the same for the Flax
# checkpoints.
stale_checkpoints = checkpoints_before_save - set(
self.tf_checkpoint_manager.checkpoints)
for checkpoint in stale_checkpoints:
if tf.io.gfile.exists(self._flax_path(checkpoint)):
tf.io.gfile.remove(self._flax_path(checkpoint))
assert self.current_checkpoint == next_checkpoint, (
"Expected next_checkpoint to match .current_checkpoint: "
f"{next_checkpoint} != {self.current_checkpoint}")
return self.current_checkpoint
@utils.logged_with("Checkpoint.restore_or_initialize()")
def restore_or_initialize(self, state: T) -> T:
"""Restores from the latest checkpoint, or creates a first checkpoint.
Args:
state: A data structure to be stored or to serve as a template. If the
checkpoint is restored (and not initialized), then the fields of `state`
must match the data previously stored. See
`flax.serialization.from_state_dict()` for details.
Returns:
The restored `state` object. Note that all TensorFlow `Trackable`s in
`tf_state` (see `__init__()`) are also updated.
"""
checkpoint = self.get_latest_checkpoint_to_restore_from()
if checkpoint is not None:
return self.restore(state, checkpoint)
logging.info("Storing initial version.")
self.save(state)
return state
def restore_dict(self, checkpoint: Optional[str] = None) -> Dict[str, Any]:
"""Restores last checkpoint and returns `state` as dictionary.
The only difference between this method and `.restore()` is the return type
annotation.
Args:
checkpoint: Checkpoint name that should be restored. Defaults to latest
available checkpoint. See `.latest_checkpoint` for a description of the
format of this string.
Returns:
The restored `state` object. Note that all TensorFlow `Trackable`s in
`tf_state` (see `__init__()`) are also updated.
Raises:
FileNotFoundError: If specified checkpoint does not exist, or if there
is no checkpoint to restore in case no checkpoint was specified.
"""
return self.restore(state=None, checkpoint=checkpoint)
def _checkpoint_or_latest(self, checkpoint: Optional[str] = None) -> str:
if checkpoint is None:
checkpoint = self.get_latest_checkpoint_to_restore_from()
if checkpoint is None:
raise FileNotFoundError(f"No checkpoint found at {self.base_directory}")
return checkpoint
def load_state(self,
state: Optional[T],
checkpoint: Optional[str] = None) -> T:
"""Restores Flax state from the latest checkpoint.
As opposed to `.restore()`, this function only reads the Flax checkpoint and
does not read the (potentially very large) TensorFlow state.
Args:
state: Data structure that will serve as a template for the
returned state. If the loaded data does not match that template, then an
exception is raised. It's also possible to specify `state=None`, in
which case a dictionary will be returned. See
`flax.serialization.from_state_dict()` for details.
checkpoint: Checkpoint name that should be restored. Defaults to latest
available checkpoint. See `.latest_checkpoint` for a description of the
format of this string.
Returns:
The restored `state` object. Unlike `.restore()`, this method does not
update the TensorFlow `Trackable`s in `tf_state` (see `__init__()`).
Raises:
FileNotFoundError: If specified checkpoint does not exist, or if there
is no checkpoint to restore in case no checkpoint was specified.
"""
flax_path = self._flax_path(self._checkpoint_or_latest(checkpoint))
if not tf.io.gfile.exists(flax_path):
raise FileNotFoundError(f"Checkpoint {checkpoint} does not exist")
with tf.io.gfile.GFile(flax_path, "rb") as f:
return flax.serialization.from_bytes(state, f.read())
def restore(self,
state: Optional[T],
checkpoint: Optional[str] = None) -> T:
"""Restores from the latest checkpoint.
Similar to `restore_or_initialize()`, but raises a `FileNotFoundError` if
there is no checkpoint.
Args:
state: Data structure that will serve as a template for the
returned state. If the loaded data does not match that template, then an
exception is raised. It's also possible to specify `state=None`, in
which case a dictionary will be returned. See
`flax.serialization.from_state_dict()` for details.
checkpoint: Checkpoint name that should be restored. Defaults to latest
available checkpoint. See `.latest_checkpoint` for a description of the
format of this string.
Returns:
The restored `state` object. Note that all TensorFlow `Trackable`s in
`tf_state` (see `__init__()`) are also updated.
Raises:
FileNotFoundError: If specified checkpoint does not exist, or if there
is no checkpoint to restore in case no checkpoint was specified.
"""
checkpoint = self._checkpoint_or_latest(checkpoint)
logging.info("Restoring checkpoint: %s", checkpoint)
state = self.load_state(state, checkpoint)
self.tf_checkpoint.restore(checkpoint)
logging.info("Restored save_counter=%d restored_checkpoint=%s",
self.tf_checkpoint.save_counter.numpy(),
checkpoint)
self.restored_from = checkpoint
return state
class MultihostCheckpoint(Checkpoint):
"""A subclass of `Checkpoint` that synchronizes between multiple JAX hosts.
If the training is split across multiple hosts, then the following race
condition can occur: if a host is pre-empted while writing a checkpoint, then the other
hosts will only be restarted with a small delay, and at that point they
probably already have finished writing their checkpoint. Upon restart, the
host that was interrupted while writing the checkpoint will load the latest
fully written checkpoint, which will be out of sync with the other hosts that
successfully wrote one more checkpoint.
This class also allows specifying a `multihost_base_directory` that is
identical for all hosts and is used to derive a host-specific directory.
"""
def __init__(self,
multihost_base_directory: str,
tf_state: Optional[Dict[str, Any]] = None,
*,
host_id: Optional[int] = None,
max_to_keep: int = 5,
checkpoint_name: str = "ckpt"):
"""Initializes a MultihostCheckpoint with a dict of TensorFlow Trackables.
Args:
multihost_base_directory: Directory that will be used to construct a
host-specific `base_directory` under which the checkpoints will be
stored. Usually a directory *within* the work unit's work directory (e.g.
`f"{workdir}/checkpoints"`). One directory per host will be created at
the same level as this base directory labeled
`f"{multihost_base_directory}-{host_id}"`.
tf_state: A dictionary of TensorFlow `Trackable` to be serialized, for
example a dataset iterator.
host_id: Host ID used to construct the `base_directory`. Taken from
`jax.process_index()` if not specified.
max_to_keep: Number of checkpoints to keep in the directory. If there are
more checkpoints than specified by this number, then the oldest
checkpoints are removed.
checkpoint_name: Prefix of the checkpoint files (before `-{number}`).
"""
if max_to_keep < 2:
raise ValueError("Requires multiple checkpoints (max_to_keep>=2).")
multihost_base_directory = multihost_base_directory.rstrip("/")
self.multihost_base_directory = multihost_base_directory
if host_id is None:
host_id = jax.process_index()
base_directory = f"{multihost_base_directory}-{host_id}"
super().__init__(
base_directory,
tf_state,
max_to_keep=max_to_keep,
checkpoint_name=checkpoint_name)
@utils.logged_with(
"MultihostCheckpoint.get_latest_checkpoint_to_restore_from()")
def get_latest_checkpoint_to_restore_from(self) -> Optional[str]:
"""Returns the latest checkpoint available on all hosts."""
base_directory_glob = f"{self.multihost_base_directory}-*"
base_directories = tf.io.gfile.glob(base_directory_glob)
if self.base_directory not in base_directories:
logging.info("%s not in %s", self.base_directory, base_directories)
return None
checkpoints = {}
common_numbers = None
all_numbers = set()
for base_directory in base_directories:
checkpoint_manager = tf.train.CheckpointManager(
tf.train.Checkpoint(),
base_directory,
max_to_keep=self.max_to_keep,
checkpoint_name=self.checkpoint_name)
numbers = [
CheckpointInfo.from_path(checkpoint).number
for checkpoint in checkpoint_manager.checkpoints
]
checkpoints[base_directory] = dict(
zip(numbers, checkpoint_manager.checkpoints))
numbers = set(numbers)
if common_numbers is None:
common_numbers = numbers
else:
common_numbers &= numbers
all_numbers |= numbers
logging.info(
"Checked checkpoint base_directories: %s - common_numbers=%s "
"- exclusive_numbers=%s", base_directories, common_numbers,
all_numbers.difference(common_numbers))
if not common_numbers:
return None
highest_number = sorted(common_numbers)[-1]
return checkpoints[self.base_directory][highest_number]
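The synchronization rule in `get_latest_checkpoint_to_restore_from()` boils down to a set intersection over per-host checkpoint numbers. A minimal sketch, with hypothetical host directories and numbers:

```python
# Only checkpoint numbers present on *all* hosts are safe to restore from,
# and the highest such number wins. This mirrors the common_numbers logic
# in MultihostCheckpoint.get_latest_checkpoint_to_restore_from().
def latest_common_number(numbers_per_host):
  common = None
  for numbers in numbers_per_host.values():
    common = set(numbers) if common is None else common & set(numbers)
  return max(common) if common else None

# Host 0 was pre-empted mid-write: it never committed checkpoint 5.
hosts = {"workdir-0": [3, 4], "workdir-1": [3, 4, 5]}
assert latest_common_number(hosts) == 4  # 5 exists only on one host
```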
================================================
FILE: clu/checkpoint_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for clu.checkpoint."""
import os
import tempfile
from unittest import mock
from clu import checkpoint
import flax
import tensorflow as tf
def _make_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
features = dict(x=inputs, y=labels)
return tf.data.Dataset.from_tensor_slices(features).repeat().batch(2)
@flax.struct.dataclass
class TrainState:
step: int
@flax.struct.dataclass
class TrainStateExtended:
step: int
name: str
class NotTrainState:
pass
def _checkpoint_number(path):
*parts, number = path.split("-")
del parts
return int(number)
class CheckpointTest(tf.test.TestCase):
def test_safe_normpath(self):
self.assertEqual(checkpoint.safe_normpath("./test_dir"), "test_dir")
self.assertEqual(checkpoint.safe_normpath(".//test_dir"), "test_dir")
self.assertEqual(checkpoint.safe_normpath("gs://test_dir"), "gs://test_dir")
self.assertEqual(
checkpoint.safe_normpath("gs://test_dir/"), "gs://test_dir")
def test_initialize_mkdir(self):
base_dir = os.path.join(tempfile.mkdtemp(), "test")
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir)
self.assertIsNone(ckpt.current_checkpoint)
self.assertIsNone(ckpt.latest_checkpoint)
self.assertFalse(os.path.isdir(base_dir))
state = ckpt.restore_or_initialize(state)
self.assertIsNotNone(ckpt.latest_checkpoint)
self.assertEqual(ckpt.latest_checkpoint, ckpt.current_checkpoint)
self.assertTrue(os.path.isdir(base_dir))
def test_restores_flax_state(self):
base_dir = tempfile.mkdtemp()
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=2)
# Initializes.
state = ckpt.restore_or_initialize(state)
state = TrainState(step=0)
# Restores step=1.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 1)
state = TrainState(step=2)
# Stores step=2.
path = ckpt.save(state)
self.assertEqual(_checkpoint_number(path), 2)
state = TrainState(step=0)
# Restores step=2.
state = ckpt.restore(state)
self.assertEqual(state.step, 2)
state = TrainState(step=3)
# Stores step=3
path2 = ckpt.save(state)
self.assertEqual(_checkpoint_number(path2), 3)
state = TrainState(step=0)
# Restores step=2.
state = ckpt.restore(state, path)
self.assertEqual(state.step, 2)
def test_load_state_dict(self):
base_dir = tempfile.mkdtemp()
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir)
# Initializes.
state = ckpt.restore_or_initialize(state)
# Load via load_state_dict().
flax_dict = checkpoint.load_state_dict(base_dir)
self.assertEqual(flax_dict, dict(step=1))
with self.assertRaisesRegex(FileNotFoundError, r"^No checkpoint found"):
checkpoint.load_state_dict(tempfile.mkdtemp())
def test_fails_when_restoring_subset(self):
base_dir = tempfile.mkdtemp()
state = TrainStateExtended(step=1, name="test")
ckpt = checkpoint.Checkpoint(base_dir)
# Initializes with TrainStateExtended.
state = ckpt.restore_or_initialize(state)
state = TrainState(step=0)
# Restores with TrainState.
with self.assertRaisesRegex(ValueError, r"^Unknown field"):
state = ckpt.restore_or_initialize(state)
def test_fails_when_restoring_superset(self):
base_dir = tempfile.mkdtemp()
ckpt = checkpoint.Checkpoint(base_dir)
state = TrainState(step=0)
# Initializes with TrainState.
state = ckpt.restore_or_initialize(state)
state = TrainStateExtended(step=1, name="test")
# Restores with TrainStateExtended.
with self.assertRaisesRegex(ValueError, r"^Missing field"):
state = ckpt.restore_or_initialize(state)
def test_restores_tf_state(self):
base_dir = tempfile.mkdtemp()
ds_iter = iter(_make_dataset())
ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
features0 = next(ds_iter) # Advance iterator by one.
del features0
state = TrainState(step=1)
# Initialize at features1.
state = ckpt.restore_or_initialize(state)
features1 = next(ds_iter)
features2 = next(ds_iter)
self.assertNotAllEqual(features1["x"], features2["x"])
self.assertNotAllEqual(features1["y"], features2["y"])
# Restore at features1.
state = ckpt.restore_or_initialize(state)
features1_restored = next(ds_iter)
self.assertAllEqual(features1["x"], features1_restored["x"])
self.assertAllEqual(features1["y"], features1_restored["y"])
# Save at features2.
path = ckpt.save(state)
self.assertEqual(_checkpoint_number(path), 2)
features2 = next(ds_iter)
features3 = next(ds_iter)
self.assertNotAllEqual(features2["x"], features3["x"])
self.assertNotAllEqual(features2["y"], features3["y"])
# Restore at features2.
state = ckpt.restore_or_initialize(state)
features2_restored = next(ds_iter)
self.assertAllEqual(features2["x"], features2_restored["x"])
self.assertAllEqual(features2["y"], features2_restored["y"])
# Restore at features2 as dictionary.
state = ckpt.restore_dict()
features2_restored = next(ds_iter)
self.assertAllEqual(features2["x"], features2_restored["x"])
self.assertAllEqual(features2["y"], features2_restored["y"])
def test_restore_flax_alone(self):
base_dir = tempfile.mkdtemp()
ds_iter = iter(_make_dataset())
ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
state = TrainState(step=1)
# Initializes.
state = ckpt.restore_or_initialize(state)
state = TrainState(step=0)
ckpt = checkpoint.Checkpoint(base_dir)
# Restores step=1.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 1)
def test_restore_dict(self):
base_dir = tempfile.mkdtemp()
ds_iter = iter(_make_dataset())
ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
with self.assertRaisesRegex(FileNotFoundError, r"No checkpoint found at"):
ckpt.restore_dict()
with self.assertRaisesRegex(FileNotFoundError,
r"Checkpoint invalid does not exist"):
ckpt.restore_dict(checkpoint="invalid")
state = TrainState(step=1)
ckpt.save(state)
state_dict = ckpt.restore_dict()
self.assertEqual(state_dict, dict(step=1))
first_checkpoint = ckpt.latest_checkpoint
new_state = TrainState(step=2)
ckpt.save(new_state)
self.assertEqual(
ckpt.restore_dict(checkpoint=first_checkpoint),
dict(step=1))
self.assertEqual(ckpt.restore_dict(), dict(step=2))
self.assertEqual(
ckpt.restore_dict(checkpoint=ckpt.latest_checkpoint),
dict(step=2))
def test_ignores_incomplete_checkpoint(self):
base_dir = tempfile.mkdtemp()
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir)
# Initializes.
state = ckpt.restore_or_initialize(state)
state = TrainState(step=0)
# Restores step=1.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 1)
state = TrainState(step=2)
# Failed save: step=2 is stored, but TensorFlow checkpoint fails.
ckpt.tf_checkpoint_manager.save = None
with self.assertRaisesRegex(TypeError,
r"'NoneType' object is not callable"):
ckpt.save(state)
files = os.listdir(base_dir)
self.assertIn("ckpt-2.flax", files)
self.assertNotIn("ckpt-2.index", files)
ckpt = checkpoint.Checkpoint(base_dir)
state = TrainState(step=0)
# Restores step=1.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 1)
# Stores step=2.
state = TrainState(step=2)
path = ckpt.save(state)
self.assertEqual(_checkpoint_number(path), 2)
files = os.listdir(base_dir)
self.assertIn("ckpt-2.flax", files)
self.assertIn("ckpt-2.index", files)
state = TrainState(step=0)
# Restores step=2.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 2)
def test_max_to_keep(self):
base_dir = tempfile.mkdtemp()
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=1)
state = ckpt.restore_or_initialize(state)
files1 = os.listdir(base_dir)
state = TrainState(step=2)
path = ckpt.save(state)
self.assertEqual(_checkpoint_number(path), 2)
files2 = os.listdir(base_dir)
self.assertEqual(len(files1), len(files2))
self.assertNotEqual(files1, files2)
def test_checkpoint_name(self):
base_dir = tempfile.mkdtemp()
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir, checkpoint_name="test")
path = ckpt.save(state)
self.assertIn("test", path)
def test_fails_if_not_registered(self):
base_dir = tempfile.mkdtemp()
not_state = NotTrainState()
ckpt = checkpoint.Checkpoint(base_dir)
with self.assertRaisesRegex(TypeError, r"serialize"):
ckpt.restore_or_initialize(not_state)
def test_overwrite(self):
base_dir = tempfile.mkdtemp()
tf_step = tf.Variable(1)
state = TrainState(step=1)
ckpt = checkpoint.Checkpoint(base_dir, dict(step=tf_step))
# Initialize step=1.
state = ckpt.restore_or_initialize(state)
self.assertEqual(state.step, 1)
self.assertEqual(tf_step.numpy(), 1)
checkpoint_info = checkpoint.CheckpointInfo.from_path(
ckpt.current_checkpoint)
# Stores steps 2, 3, 4, 5
for _ in range(4):
tf_step.assign_add(1)
state = state.replace(step=state.step + 1)
ckpt.save(state)
latest_checkpoint = str(checkpoint_info._replace(number=5))
self.assertEqual(ckpt.current_checkpoint, latest_checkpoint)
self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
# Restores at step=1
ckpt = checkpoint.Checkpoint(base_dir, dict(step=tf_step))
state = ckpt.restore(state, checkpoint=str(checkpoint_info))
self.assertEqual(state.step, 1)
self.assertEqual(tf_step.numpy(), 1)
self.assertNotEqual(ckpt.current_checkpoint, ckpt.latest_checkpoint)
self.assertEqual(ckpt.current_checkpoint, str(checkpoint_info))
self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
# Overwrites step=2, deletes 3, 4, 5.
tf_step.assign_add(1)
state = state.replace(step=state.step + 1)
ckpt.save(state)
latest_checkpoint = str(checkpoint_info._replace(number=2))
self.assertEqual(ckpt.current_checkpoint, latest_checkpoint)
self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
class MultihostCheckpointTest(tf.test.TestCase):
@mock.patch("jax.process_index")
def test_initialize_mkdir(self, process_index_mock):
multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
state = TrainState(step=1)
process_index_mock.return_value = 0
base_dir = f"{multihost_base_dir}-0"
ckpt = checkpoint.MultihostCheckpoint(multihost_base_dir)
self.assertIsNone(ckpt.latest_checkpoint)
self.assertFalse(os.path.isdir(base_dir))
state = ckpt.restore_or_initialize(state)
self.assertIsNotNone(ckpt.latest_checkpoint)
self.assertTrue(os.path.isdir(base_dir))
@mock.patch("jax.process_index")
def test_synchronize_multiple_hosts(self, process_index_mock):
multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
state = TrainState(step=1)
process_index_mock.return_value = 0
ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir)
process_index_mock.return_value = 1
ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir)
# Initialize both at step=1.
state_0 = ckpt_0.restore_or_initialize(state)
state_1 = ckpt_1.restore_or_initialize(state)
# Update both at step=2.
state_0 = state_0.replace(step=2)
ckpt_0.save(state_0)
state_1 = state_1.replace(step=2)
ckpt_1.save(state_1)
# Update ckpt_1 at step=3.
state_1 = state_1.replace(step=3)
ckpt_1.save(state_1)
# Reload both at step=2.
process_index_mock.return_value = 0
ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir)
process_index_mock.return_value = 1
ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir)
self.assertEqual(ckpt_0.latest_checkpoint,
ckpt_0.get_latest_checkpoint_to_restore_from())
self.assertNotEqual(ckpt_1.latest_checkpoint,
ckpt_1.get_latest_checkpoint_to_restore_from())
state_0 = ckpt_0.restore_or_initialize(state)
state_1 = ckpt_1.restore_or_initialize(state)
self.assertEqual(state_0.step, 2)
self.assertEqual(state_1.step, 2)
def test_preemption(self):
multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
state = TrainState(step=1)
state0 = state.replace(step=0)
ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
# Initialize both at step=1.
state_0 = ckpt_0.restore_or_initialize(state)
state_1 = ckpt_1.restore_or_initialize(state)
self.assertEqual(state_0.step, 1)
self.assertEqual(state_1.step, 1)
# Restore both at step=1.
state_0 = ckpt_0.restore_or_initialize(state0)
state_1 = ckpt_1.restore_or_initialize(state0)
self.assertEqual(state_0.step, 1)
self.assertEqual(state_1.step, 1)
# Update only ckpt_0 to step=2.
state_0 = state_0.replace(step=2)
ckpt_0.save(state_0)
# Load both checkpoints at last common step=1.
ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
state_0 = ckpt_0.restore_or_initialize(state)
state_1 = ckpt_1.restore_or_initialize(state)
self.assertEqual(state_0.step, 1)
self.assertEqual(state_1.step, 1)
# Store both at step=2.
state_0 = state_0.replace(step=2)
state_1 = state_1.replace(step=2)
ckpt_0.save(state_0)
ckpt_1.save(state_1)
# Restore both at step=2.
state_0 = ckpt_0.restore_or_initialize(state0)
state_1 = ckpt_1.restore_or_initialize(state0)
self.assertEqual(state_0.step, 2)
self.assertEqual(state_1.step, 2)
if __name__ == "__main__":
tf.test.main()
================================================
FILE: clu/data/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DatasetIterator is an interface for input pipelines."""
# pylint: disable=g-multiple-import
# pylint: disable=unused-import
from clu.data.dataset_iterator import (
Array,
ArraySpec,
DatasetIterator,
Element,
ElementSpec,
TfDatasetIterator,
PeekableDatasetIterator,
PyTree,
)
================================================
FILE: clu/data/dataset_iterator.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interface for dataset iterators.
This module provides the DatasetIterator interface. The intention is that
several frameworks providing datasets can implement this interface without
knowing anything about the framework used for the model and the training loop.
Likewise, training loops can assume they get a DatasetIterator object and do
not need to care about the specifics of the input pipelines.
This module does not depend on TensorFlow. The interface is generic and users
don't have to use `tf.data` to construct a DatasetIterator. However, if they
use `tf.data` they can simply wrap their `tf.data.Dataset` object with
`TfDatasetIterator` to satisfy the interface.
"""
from __future__ import annotations
import abc
import collections.abc
import concurrent.futures
import dataclasses
import os
import threading
import typing
from typing import Any, Mapping, Optional, Sequence, Tuple, TypeVar, Union
from absl import logging
from clu import asynclib
from etils import epath
import jax.numpy as jnp # Just for type checking.
import numpy as np
import numpy.typing as npt
Array = Union[np.ndarray, jnp.ndarray]
# Sizes of dimensions, None means the dimension size is unknown.
Shape = Tuple[Optional[int], ...]
@dataclasses.dataclass(frozen=True)
class ArraySpec:
"""Describes an array via its dtype and shape."""
dtype: npt.DTypeLike
shape: Shape
def __repr__(self):
return f"ArraySpec(dtype={np.dtype(self.dtype).name}, shape={self.shape})"
def __str__(self):
return f"{np.dtype(self.dtype).name}{list(self.shape)}"
# Elements are PyTrees with NumPy/JAX arrays.
# Anything can be a PyTree (it's either a container or leaf). We define
# PyTree[T] as a PyTree where all leaves are of type T.
# See https://jax.readthedocs.io/en/latest/pytrees.html.
L = TypeVar("L") # pylint: disable=invalid-name
PyTree = Union[L, Sequence["PyTree[L]"], Mapping[str, "PyTree[L]"]]
Element = PyTree[Array]
ElementSpec = PyTree[ArraySpec]
class DatasetIterator(collections.abc.Iterator): # pytype: disable=ignored-abstractmethod
"""Generic interface for iterating over a dataset.
This does not support __getitem__ since it cannot be implemented efficiently
for many datasets. However, datasets should allow starting the iterator from
an arbitrary position.
The element_spec property helps consumers to validate the input without
reading data. This is similar to `tf.data.Dataset.element_spec`.
Subclasses may decide not to read/write checkpoints if their state is
sufficiently tracked externally (e.g. input pipelines that can be correctly
restarted from the step number).
"""
def get_next(self) -> Element:
"""Returns the next element."""
logging.error(
"DatasetIterator.get_next() is deprecated. Please use next().")
# Subclasses should implement __next__() and remove calls to get_next().
return next(self)
def reset(self):
"""Resets the iterator back to the beginning."""
raise NotImplementedError
@property
@abc.abstractmethod
def element_spec(self) -> ElementSpec:
"""Returns the spec elements."""
raise NotImplementedError()
def save(self, filename: epath.Path):
"""Saves the state of the iterator to a file.
This should only handle this iterator - not iterators in other processes.
Args:
filename: Name of the checkpoint.
"""
raise NotImplementedError()
def restore(self, filename: epath.Path):
"""Restores the iterator from a file (if available).
This should only handle this iterator - not iterators in other processes.
Args:
filename: Name of the checkpoint.
"""
raise NotImplementedError()
def load(self, filename: epath.Path):
logging.error("DatasetIterator.load() is deprecated. Please use restore().")
return self.restore(filename)
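The interface above can be implemented without any TensorFlow dependency. A minimal in-memory sketch (a hypothetical class, not part of clu) that provides the same method names over a trivial integer dataset:

```python
import numpy as np


class RangeIterator:
  """Hypothetical minimal DatasetIterator-style iterator over integers.

  Provides __next__, reset() and element_spec as in the interface above;
  element_spec is approximated here by a plain dict of (dtype, shape) tuples.
  """

  def __init__(self, num_examples: int):
    self._num_examples = num_examples
    self._index = 0

  def __iter__(self):
    return self

  def __next__(self):
    if self._index >= self._num_examples:
      raise StopIteration
    element = {"index": np.asarray(self._index, dtype=np.int64)}
    self._index += 1
    return element

  def reset(self):
    # Restart from the beginning, as DatasetIterator.reset() promises.
    self._index = 0

  @property
  def element_spec(self):
    return {"index": (np.int64, ())}


it = RangeIterator(3)
assert [int(e["index"]) for e in it] == [0, 1, 2]
it.reset()
assert int(next(it)["index"]) == 0
```

Because state is just the current index, save()/restore() for such an iterator could simply write and read that integer.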
class TfDatasetIterator(DatasetIterator):
"""DatasetIterator for wrapping a `tf.data.Dataset`."""
def __init__(self, dataset, *, checkpoint: bool):
"""Wraps `tf.data.Dataset` object into the `DatasetIterator` interface.
Warning: Do not wrap this iterator to do asynchronous prefetching if you
use `checkpoint=True`. tf.data iterators must be saved
synchronously.
Args:
dataset: The dataset to wrap. Elements are converted to NumPy arrays but
no additional prefetching is done. tf.data should automatically prefetch
elements (to CPU memory).
checkpoint: Whether to checkpoint the dataset iterator object.
Checkpointing dataset iterators is required for handling job
pre-emptions but depending on your input pipeline can result in very
large checkpoints. If set to False, save() and restore() are no-ops.
"""
try:
# Since this is the only class in this module using TF we only import
# tensorflow if needed.
if typing.TYPE_CHECKING:
tf = Any
else:
import tensorflow as tf # pylint: disable=g-import-not-at-top
except ImportError as e:
raise RuntimeError("When using TfDatasetIterator your binary must "
"depend on //third_party/py/tensorflow.") from e
self._tf = tf
if not isinstance(dataset, tf.data.Dataset):
raise ValueError("`dataset` must be an instance of `tf.data.Dataset` "
f"but got {type(dataset)}.")
self._dataset = dataset
self._checkpoint = checkpoint
assert self.element_spec # Verify element spec.
self.iterator = iter(dataset)
self._ckpt = tf.train.Checkpoint(ds=self.iterator)
def get_next(self) -> Element:
return next(self)
def __next__(self) -> Element:
return {k: np.asarray(v) for k, v in next(self.iterator).items()}
def reset(self):
self.iterator = iter(self._dataset)
self._ckpt = self._tf.train.Checkpoint(ds=self.iterator)
@property
def element_spec(self) -> ElementSpec:
element_spec = self._dataset.element_spec
if not isinstance(element_spec, dict):
raise ValueError("Dataset elements must be flat dictionaries but got "
f"{element_spec}.")
invalid_features = [
k for k, v in element_spec.items()
if not isinstance(v, self._tf.TensorSpec)
]
if invalid_features:
raise ValueError(f"Features {invalid_features} are not tensors. Dataset "
"elements must be flat dictionaries of tensors.")
return {
k: ArraySpec(dtype=v.dtype.as_numpy_dtype, shape=tuple(v.shape))
for k, v in element_spec.items()
}
def save(self, filename: epath.Path):
if self._checkpoint:
self._ckpt.write(os.fspath(filename))
def restore(self, filename: epath.Path):
if self._checkpoint:
self._ckpt.read(os.fspath(filename)).assert_consumed()
class PeekableDatasetIterator(DatasetIterator):
"""Wraps a DatasetIterator to provide a peek() method.
This allows looking at the next element, which can be useful in two scenarios:
a) Get the structure of elements if the element_spec property is not
supported.
b) Request the next element without consuming it. This is especially handy to
trigger reading of the first element while the model is being initialized.
Example use case:
>>> pool = clu.asynclib.Pool()
>>> @pool
>>> def warmup_input_pipeline():
>>> train_iter.peek()
>>> first_batch_ready = warmup_input_pipeline()
>>> # Do other stuff...
>>> first_batch_ready.result() # wait for input pipeline to be ready.
"""
def __init__(self, it: DatasetIterator):
self._it = it
# Mutex for self._it.
self._mutex = threading.Lock()
self._peek: Optional[Element] = None
self._pool = None
self._peek_future = None
def __next__(self) -> Element:
with self._mutex:
if self._peek is None:
return next(self._it)
peek = self._peek
self._peek = None
return peek
def reset(self):
with self._mutex:
self._it.reset()
self._peek = None
self._pool = None
self._peek_future = None
@property
def element_spec(self) -> ElementSpec:
return self._it.element_spec
def peek(self) -> Element:
"""Returns the next element without consuming it.
This will get the next element from the underlying iterator. The element
is stored and returned on the next call of __next__().
Returns:
The next element.
"""
if self._peek is None:
self._peek = next(self)
return self._peek
def peek_async(self) -> concurrent.futures.Future[Element]:
"""Same as peek() but returns the Future of the element.
Users can call this to warm up the iterator.
Returns:
Future with the next element. The element is also kept and returned on the
next call of __next__().
"""
with self._mutex:
if self._peek_future is None:
if self._pool is None:
self._pool = asynclib.Pool(max_workers=1)
self._peek_future = self._pool(self.peek)()
return self._peek_future
def save(self, filename: epath.Path):
with self._mutex:
self._it.save(filename)
def restore(self, filename: epath.Path):
with self._mutex:
self._it.restore(filename)
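The peek/consume hand-off in `PeekableDatasetIterator` can be sketched without threads for a plain Python iterator (a simplified standalone sketch; the real class additionally guards the underlying iterator with a mutex and uses `clu.asynclib.Pool` for `peek_async`):

```python
class PeekableSketch:
  """Simplified peek() wrapper; a cached element is handed back by next()."""

  def __init__(self, it):
    self._it = it
    self._peek = None  # Holds the element fetched by peek(), if any.

  def __next__(self):
    if self._peek is None:
      return next(self._it)
    # Return the cached element and clear the cache.
    peek, self._peek = self._peek, None
    return peek

  def peek(self):
    # Fetch and cache the next element without consuming it.
    if self._peek is None:
      self._peek = next(self._it)
    return self._peek


it = PeekableSketch(iter([1, 2, 3]))
assert it.peek() == 1  # does not consume
assert next(it) == 1   # returns the cached element
assert next(it) == 2
```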
================================================
FILE: clu/data/dataset_iterator_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dataset_iterator."""
import itertools
import pathlib
import tempfile
from absl.testing import parameterized
from clu.data import dataset_iterator
import numpy as np
import tensorflow as tf
INDEX = "_index"
class DatasetIteratorTest(parameterized.TestCase, tf.test.TestCase):
def _create_iterator(self, start_index: int, checkpoint: bool = True):
"""Create an iterator over some prime numbers with index."""
primes = tf.constant([2, 3, 5, 7, 11, 13, 17, 19, 23, 29])
ds = tf.data.Dataset.range(start_index, 10)
ds = ds.map(lambda i: {INDEX: i, "prime": primes[i]})
# Remove index 1 and 3.
ds = ds.filter(lambda x: tf.logical_and(x["prime"] != 3, x["prime"] != 7))
ds = ds.batch(2, drop_remainder=True)
return dataset_iterator.TfDatasetIterator(ds, checkpoint=checkpoint)
def test_tf_iterator(self):
it = self._create_iterator(0)
self.assertEqual(
it.element_spec, {
INDEX: dataset_iterator.ArraySpec(np.int64, (2,)),
"prime": dataset_iterator.ArraySpec(np.int32, (2,))
})
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
it.reset()
# Iterator starts from the beginning.
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
def test_tf_iterator_save_and_load(self):
it = self._create_iterator(0)
next(it)
next(it)
next(it)
work_dir = pathlib.Path(tempfile.mkdtemp())
filename = work_dir / "ckpt"
it.save(filename)
self.assertTrue((work_dir / "ckpt.index").exists())
it = self._create_iterator(0)
# Iterator is at the beginning (batch 1).
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
it.load(filename)
# Iterator is at the end (batch 4).
self.assertEqual(next(it), {INDEX: [8, 9], "prime": [23, 29]})
def test_tf_iterator_save_and_load_no_checkpoint(self):
it = self._create_iterator(0, checkpoint=False)
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
work_dir = pathlib.Path(tempfile.mkdtemp())
filename = work_dir / "ckpt"
it.save(filename) # Should be a no-op and not create a checkpoint.
self.assertFalse((work_dir / "ckpt.index").exists())
it = self._create_iterator(0, checkpoint=False)
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
it.restore(filename) # Should be a no-op, iterator just continues.
self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
def test_peekable_dataset_iterator(self):
it = self._create_iterator(0)
it = dataset_iterator.PeekableDatasetIterator(it)
self.assertEqual(it.peek(), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
@parameterized.parameters(itertools.product([True, False], [True, False]))
def test_peekable_dataset_iterator_async(self, wait: bool, peek_first: bool):
it = self._create_iterator(0)
it = dataset_iterator.PeekableDatasetIterator(it)
future = it.peek_async()
self.assertIsNone(it._peek)
if wait:
future.result()
self.assertIsNotNone(it._peek)
if peek_first:
self.assertEqual(it.peek(), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
if __name__ == "__main__":
tf.test.main()
================================================
FILE: clu/deterministic_data.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Helper functions for building deterministic tf.data input pipelines.
The function `create_dataset()` makes it easy to build a `tf.data` based input
pipeline that allows for completely reproducible results based on a single
initial random seed. The caller must take care to create a unique initial seed
on every host that is then passed to `create_dataset()`, where further unique
random keys are derived for every batch. Within a single batch, this key is
exposed as the special feature "rng" and can be used to implement stateless
preprocessing functions.
The function `get_read_instruction_for_host()` makes it easy to split a dataset
evenly between multiple hosts in a SPMD setup with multiple machines. Within a
single host, every batch is usually distributed to all the attached accelerators
(the first value of the `batch_dims` argument to `create_dataset()`).
The function `create_distributed_dataset()` finally is intended to be used in
conjunction with a `tf.distribute.Strategy`.
Synopsis for deterministic training with multiple hosts:
import jax
from clu import deterministic_data
rng = jax.random.PRNGKey(42) # Global RNG (e.g. from config)
rng = jax.random.fold_in(rng, jax.process_index()) # Derive RNG for this host.
dataset_builder = tfds.builder(...)
split = deterministic_data.get_read_instruction_for_host(
"train", dataset_builder.info.splits["train"].num_examples)
ds = deterministic_data.create_dataset(
dataset_builder,
split=split,
rng=rng
)
ds_iter = iter(ds)
for _ in range(num_train_steps):
batch = jax.tree_util.tree_map(lambda x: x._numpy(), next(ds_iter))
# (training step)
"""
import enum
import functools
import operator
from typing import Callable, Dict, Optional, Sequence, Union
from absl import logging
import jax
import jax.numpy as jnp
import numpy as np
from packaging import version
import tensorflow as tf
import tensorflow_datasets as tfds
import typing_extensions
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
Tensor = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]
Features = Dict[str, Tensor]
AUTOTUNE = tf.data.experimental.AUTOTUNE
_use_split_info = version.parse("4.4.0") < version.parse(
tfds.version.__version__)
class DatasetBuilder(typing_extensions.Protocol):
"""Protocol for dataset builders (subset of tfds.core.DatasetBuilder)."""
def as_dataset(
self, split: Union[str, tfds.core.ReadInstruction], shuffle_files: bool,
read_config: tfds.ReadConfig,
decoders: Optional[Dict[str, tfds.decode.Decoder]]) -> tf.data.Dataset:
...
class RemainderOptions(enum.Enum):
"""How to handle examples not divisible by number of processes.
Possible values:
- DROP: Examples not divisible by process count will be dropped. Every host
receives the same number of examples.
- BALANCE_ON_PROCESSES: Examples not divisible by process count will be
distributed evenly on processes, by increasing process number. For example,
if there are 4 processes and 7 examples, then processes 0, 1, 2 will have
2 examples, and process 3 will have 1 example.
- ON_FIRST_PROCESS: Examples not divisible by process count will be assigned
to process 0.
"""
DROP = 0
BALANCE_ON_PROCESSES = 1
ON_FIRST_PROCESS = 2
def _shard_read_instruction(
absolute_instruction,
*,
split_infos: Dict[str, Union[int, tfds.core.SplitInfo]],
host_id: int,
host_count: int,
remainder_options: RemainderOptions,
) -> tfds.core.ReadInstruction:
"""Shards a single ReadInstruction. See get_read_instruction_for_host()."""
start = absolute_instruction.from_ or 0
if _use_split_info:
end = absolute_instruction.to or (
split_infos[absolute_instruction.splitname].num_examples) # pytype: disable=attribute-error
else:
end = absolute_instruction.to or split_infos[absolute_instruction.splitname]
assert end >= start, f"start={start}, end={end}"
num_examples = end - start
examples_per_host = num_examples // host_count
shard_start = start + examples_per_host * host_id
shard_end = start + examples_per_host * (host_id + 1)
# Handle remaining examples.
num_unused_examples = num_examples % host_count
assert num_unused_examples >= 0, num_unused_examples
assert num_unused_examples < host_count, num_unused_examples
if num_unused_examples > 0:
if remainder_options == RemainderOptions.DROP:
logging.warning("Dropping %d examples of %d examples (host count: %d).",
num_unused_examples, num_examples, host_count)
elif remainder_options == RemainderOptions.BALANCE_ON_PROCESSES:
shard_start += min(host_id, num_unused_examples)
shard_end += min(host_id + 1, num_unused_examples)
elif remainder_options == RemainderOptions.ON_FIRST_PROCESS:
shard_end += num_unused_examples
if host_id > 0:
shard_start += num_unused_examples
else:
raise ValueError(f"Invalid remainder_options: {remainder_options}")
return tfds.core.ReadInstruction(
absolute_instruction.splitname,
from_=shard_start,
to=shard_end,
unit="abs")
_DEPRECATE_MSG = """
`get_read_instruction_for_host` is DEPRECATED.
Migration instruction: Use `tfds.split_for_jax_process` which is simpler
and natively supported by TFDS.
```
split = tfds.split_for_jax_process('train[75%:]', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)
```
See: https://www.tensorflow.org/datasets/splits#tfdseven_splits_multi-host_training
"""
def get_read_instruction_for_host(
split: str,
num_examples: Optional[int] = None,
*,
dataset_info: Optional[tfds.core.DatasetInfo] = None,
host_id: Optional[int] = None,
host_count: Optional[int] = None,
drop_remainder: Optional[bool] = None,
remainder_options: RemainderOptions = RemainderOptions.DROP
) -> tfds.core.ReadInstruction:
"""Returns a `ReadInstruction` of the data ranges for this host.
`get_read_instruction_for_host` is DEPRECATED. Please use
`tfds.split_for_jax_process` or `tfds.even_split`. See:
https://www.tensorflow.org/datasets/splits#tfdseven_splits_multi-host_training
In a distributed setting all hosts should get the same number of examples.
This can exclude a few (< host_count) examples.
The examples are distributed evenly across the hosts, and remaining examples
are distributed to the hosts with the lowest id.
Assuming a single epoch, the number of batches (e.g. for
`create_dataset(pad_up_to_batches)`) can be computed by:
batches = int(np.ceil(num_examples / global_batch_size))
Args:
split: Name of the dataset split to use or TFDS spec (e.g.
`train[:800]+validation[:100]`). If you use the spec you must pass
dataset_info. For specs with multiple splits each split is sharded
independently of the other splits.
num_examples: Deprecated - use dataset_info instead. Number of examples of
the split.
dataset_info: TFDS dataset info; used to get the number of examples per
split.
host_id: Optional, host index in [0, host_count). Defaults to
`jax.process_index()`.
host_count: Optional, number of hosts. Defaults to `jax.process_count()`.
drop_remainder: Deprecated - use remainder_options instead.
remainder_options: The options to handle the remaining examples.
"""
logging.warning(_DEPRECATE_MSG)
if num_examples is not None:
logging.warning(
"`num_examples` is deprecated. Please pass `dataset_info` instead.")
if drop_remainder is not None:
remainder_options = (
RemainderOptions.DROP
if drop_remainder else RemainderOptions.BALANCE_ON_PROCESSES)
logging.warning(
"`drop_remainder` is deprecated. Please pass `remainder_options` "
"instead. `remainder_options` is reset with %s.", remainder_options)
if dataset_info is None:
if any(special in split for special in ["[", "]", "+"]):
raise ValueError(
f"Sharding split {split} requires passing `dataset_info`.")
if host_id is None:
host_id = jax.process_index()
if host_count is None:
host_count = jax.process_count()
if host_id < 0 or host_id >= host_count or host_count < 1:
raise ValueError(
f"Invalid combination of host_id ({host_id}) and host_count "
f"({host_count}).")
if _use_split_info:
if dataset_info is None:
split_infos = {
split: tfds.core.SplitInfo(
name=split,
shard_lengths=[num_examples],
num_bytes=0,
),
}
else:
split_infos = dataset_info.splits
else:
if dataset_info is None:
split_infos = {split: num_examples}
else:
split_infos = {k: v.num_examples for k, v in dataset_info.splits.items()}
read_instruction = tfds.core.ReadInstruction.from_spec(split)
sharded_read_instructions = []
for ri in read_instruction.to_absolute(split_infos):
sharded_read_instructions.append(
_shard_read_instruction(
ri,
split_infos=split_infos,
host_id=host_id,
host_count=host_count,
remainder_options=remainder_options))
return functools.reduce(operator.add, sharded_read_instructions)
def _preprocess_with_per_example_rng(ds: tf.data.Dataset,
preprocess_fn: Callable[[Features],
Features], *,
rng: jnp.ndarray) -> tf.data.Dataset:
"""Maps `ds` using the preprocess_fn and a deterministic RNG per example.
Args:
ds: Dataset containing Python dictionary with the features. The 'rng'
feature should not exist.
preprocess_fn: Preprocessing function that takes a Python dictionary of
tensors and returns a Python dictionary of tensors. The function should be
convertible into a TF graph.
rng: Base RNG to use. Per example RNGs will be derived from this by folding
in the example index.
Returns:
The dataset mapped by the `preprocess_fn`.
"""
def _fn(example_index: int, features: Features) -> Features:
example_index = tf.cast(example_index, tf.int32)
features["rng"] = tf.random.experimental.stateless_fold_in(
tf.cast(rng, tf.int64), example_index)
processed = preprocess_fn(features)
if isinstance(processed, dict) and "rng" in processed:
del processed["rng"]
return processed
return ds.enumerate().map(_fn, num_parallel_calls=AUTOTUNE)
def pad_dataset(dataset: tf.data.Dataset,
*,
batch_dims: Sequence[int],
pad_up_to_batches: Optional[int] = None,
cardinality: Optional[int] = None):
"""Adds padding to a dataset.
Args:
dataset: The dataset to be padded.
batch_dims: List of sizes of the batch dimensions. Multiple batch dimensions can
be used to provide inputs for multiple devices. E.g.
[jax.local_device_count(), batch_size // jax.device_count()].
pad_up_to_batches: Set this option to process the entire dataset. When set,
the dataset is first padded to the specified number of batches. A new
feature called "mask" is added to every batch. This feature is set to
`True` for every example that comes from `dataset_builder`, and to `False`
for every example that is padded to get to the specified number of
batches. Note that the specified `dataset_builder` and `split` must result
in at least `pad_up_to_batches` (possibly partial) batches. If `None`,
derives from `batch_dims` and `cardinality` such that `pad_up_to_batches *
batch_dims == cardinality`. Note that `cardinality` is what you pass in,
not necessarily the original full dataset size if you decide to shard it
per host.
cardinality: Number of examples in the dataset. Only needed when the
cardinality cannot be retrieved via `ds.cardinality()` (e.g. because of
using `ds.filter()`).
Returns:
The padded dataset, with the added feature "mask" that is set to `True` for
examples from the original `dataset` and to `False` for padded examples.
"""
if not isinstance(dataset.element_spec, dict):
raise ValueError("The dataset must have dictionary elements.")
if cardinality is None:
cardinality = dataset.cardinality()
if cardinality == tf.data.UNKNOWN_CARDINALITY:
raise ValueError(
"Cannot determine dataset cardinality. This can happen when you use "
"a `.filter()` on the dataset. Please provide the cardinality as an "
"argument to `create_dataset()`.")
if "mask" in dataset.element_spec:
raise ValueError("Dataset already contains a feature named \"mask\".")
if pad_up_to_batches is None:
pad_up_to_batches = int(np.ceil(cardinality / np.prod(batch_dims)))
filler_element = tf.nest.map_structure(
lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)
filler_element["mask"] = [False]
filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)
dataset = dataset.map(
lambda features: dict(mask=True, **features), num_parallel_calls=AUTOTUNE)
padding = pad_up_to_batches * np.prod(batch_dims) - int(cardinality)
assert padding >= 0, (
f"Invalid padding={padding} (batch_dims={batch_dims}, cardinality="
f"{cardinality}, pad_up_to_batches={pad_up_to_batches})")
return dataset.concatenate(filler_dataset.repeat(padding))
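The amount of padding computed above follows directly from the batch shape. A standalone arithmetic sketch (hypothetical helper, no TF needed):

```python
import numpy as np


def padding_needed(cardinality: int, batch_dims, pad_up_to_batches=None):
  """Number of filler examples pad_dataset appends (sketch of the math above)."""
  # All batch dimensions together consume prod(batch_dims) examples per step.
  batch_size = int(np.prod(batch_dims))
  if pad_up_to_batches is None:
    # "auto" behavior: just enough batches to cover every real example.
    pad_up_to_batches = int(np.ceil(cardinality / batch_size))
  padding = pad_up_to_batches * batch_size - cardinality
  assert padding >= 0, padding
  return padding


# 10 examples with batch_dims=[2, 3]: 6 examples per step, so 2 batches
# and 2 filler examples (whose "mask" feature would be False).
print(padding_needed(10, [2, 3]))  # 2
```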
def create_dataset(dataset_builder: DatasetBuilder,
*,
split: Union[str, tfds.core.ReadInstruction],
batch_dims: Sequence[int] = (),
rng: Union[None, jnp.ndarray, tf.Tensor] = None,
filter_fn: Optional[Callable[[Features], bool]] = None,
preprocess_fn: Optional[Callable[[Features],
Features]] = None,
decoders: Optional[Dict[str, tfds.decode.Decoder]] = None,
cache: bool = False,
num_epochs: Optional[int] = None,
shuffle: bool = True,
shuffle_buffer_size: int = 10_000,
prefetch_size: int = 4,
pad_up_to_batches: Optional[Union[int, str]] = None,
cardinality: Optional[int] = None,
drop_remainder: bool = True) -> tf.data.Dataset:
"""Creates standard input pipeline (shuffle, preprocess, batch).
Args:
dataset_builder: Dataset builder object with an as_dataset() method. E.g.
instance of `tfds.core.DatasetBuilder` as returned by `tfds.builder(...)`.
split: Specifies which split of the data to load. Passed on to
`tfds.DatasetBuilder.as_dataset()`. See also the
[split API guide](https://www.tensorflow.org/datasets/splits). In a multi
host setup, this parameter can conveniently be generated by the function
`get_read_instruction_for_host()`.
batch_dims: List of sizes of the batch dimensions. Multiple batch dimensions can
be used to provide inputs for multiple devices. E.g.
[jax.local_device_count(), batch_size // jax.device_count()].
rng: A jax.random.PRNGKey or a tf.Tensor for TF stateless seeds to use for
seeding shuffle operations and preprocessing ops. Must be set if
shuffling.
filter_fn: Optional function to filter the decoded examples. This happens
before the preprocessing.
preprocess_fn: Function for preprocessing individual examples (which should
be a Python dictionary of tensors).
decoders: Optional dictionary of decoders passed to as_dataset.
cache: Whether to cache the unprocessed dataset in memory.
num_epochs: Number of epochs for which to repeat the dataset. None to repeat
forever.
shuffle: Whether to shuffle the dataset (both on file and example level).
shuffle_buffer_size: Number of examples in the shuffle buffer.
prefetch_size: The number of elements in the final dataset to prefetch in
the background. This should be a small (say <10) positive integer or
tf.data.experimental.AUTOTUNE.
pad_up_to_batches: Set this option to process the entire dataset. - If set
with an integer, the dataset is first padded to the specified number of
batches. A new feature called "mask" is added to every batch. This feature
is set to `True` for every example that comes from `dataset_builder`, and
to `False` for every example that is padded. Note that the specified
`dataset_builder` and `split` must result in at least `pad_up_to_batches`
(possibly partial) batches. - If set with "auto", derives from
`batch_dims` and `cardinality` such that `pad_up_to_batches * batch_dims
== cardinality`. - If `None`, the dataset won't be padded.
cardinality: Number of examples in the dataset. Only needed when
`pad_up_to_batches` is specified and the cardinality cannot be retrieved
via `ds.cardinality()` (e.g. because of `ds.filter()`).
drop_remainder: Whether to drop remainders when batching.
Returns:
The dataset with preprocessed and batched examples.
"""
rng_available = rng is not None
if not rng_available and shuffle:
raise ValueError("Please set 'rng' when shuffling.")
if rng_available:
if isinstance(rng, tf.Tensor):
rngs = [x.numpy() for x in tf.random.experimental.stateless_split(rng, 3)]
else:
rngs = list(jax.random.key_data(jax.random.split(rng, 3)))
else:
rngs = 3 * [[None, None]]
dataset_options = tf.data.Options()
dataset_options.experimental_optimization.map_parallelization = True
dataset_options.threading.private_threadpool_size = 48
dataset_options.threading.max_intra_op_parallelism = 1
read_config = tfds.ReadConfig(
shuffle_seed=rngs.pop()[0], options=dataset_options)
ds = dataset_builder.as_dataset(
split=split,
shuffle_files=shuffle,
read_config=read_config,
decoders=decoders)
if filter_fn is not None:
ds = ds.filter(filter_fn)
if cache:
ds = ds.cache()
if shuffle:
ds = ds.shuffle(shuffle_buffer_size, seed=rngs.pop()[0])
ds = ds.repeat(num_epochs)
if preprocess_fn is not None:
if rng_available:
ds = _preprocess_with_per_example_rng(ds, preprocess_fn, rng=rngs.pop())
else:
ds = ds.map(preprocess_fn, num_parallel_calls=AUTOTUNE)
if pad_up_to_batches is not None:
assert isinstance(pad_up_to_batches, int) or pad_up_to_batches == "auto"
ds = pad_dataset(
ds,
batch_dims=batch_dims,
pad_up_to_batches=(None if pad_up_to_batches == "auto" else
pad_up_to_batches),
cardinality=cardinality)
if batch_dims:
for batch_size in reversed(batch_dims):
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
return ds.prefetch(prefetch_size)
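`create_dataset` above derives three independent streams from the single `rng` (file shuffling, example shuffling, per-example preprocessing). The same discipline can be sketched with NumPy's `SeedSequence` instead of JAX/TF stateless keys (an analogy to illustrate the pattern, not the actual key derivation used above):

```python
import numpy as np

# One root seed -> three independent streams, mirroring
# `rngs = list(jax.random.key_data(jax.random.split(rng, 3)))` above.
root = np.random.SeedSequence(42)
file_seed_seq, shuffle_seed_seq, preprocess_seq = root.spawn(3)

file_shuffle_rng = np.random.default_rng(file_seed_seq)
example_shuffle_rng = np.random.default_rng(shuffle_seed_seq)

# Per-example streams derived from the preprocessing stream, mirroring the
# stateless_fold_in(example_index) pattern in _preprocess_with_per_example_rng.
per_example_rngs = [np.random.default_rng(s) for s in preprocess_seq.spawn(4)]

# Every stream is fully determined by the root seed, so the whole pipeline
# is reproducible from a single integer.
draws = [int(rng.integers(0, 100)) for rng in per_example_rngs]
```

The design point this mirrors: shuffling and preprocessing never share a seed, so changing e.g. the shuffle buffer does not perturb per-example augmentations.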
StrOrReadInstruction = Union[str, tfds.core.ReadInstruction]
def create_distributed_dataset(
dataset_builder,
*,
split: Union[StrOrReadInstruction, Callable[[int, int],
StrOrReadInstruction]],
global_batch_size: int,
strategy: tf.distribute.Strategy,
rng: Optional[tf.Tensor] = None,
filter_fn: Optional[Callable[[Features], bool]] = None,
preprocess_fn: Optional[Callable[[Features], Features]] = None,
decoders: Optional[Dict[str, tfds.decode.Decoder]] = None,
cache: bool = False,
num_epochs: Optional[int] = None,
shuffle: bool = True,
shuffle_buffer_size: int = 10_000,
prefetch_size: int = 4,
pad_up_to_batches: Optional[int] = None,
cardinality: Optional[int] = None,
drop_remainder: bool = True) -> tf.data.Dataset:
"""Creates standard input pipeline (shuffle, preprocess, batch).
Args:
dataset_builder: Dataset builder object with an as_dataset() method. E.g.
instance of `tfds.core.DatasetBuilder` as returned by `tfds.builder(...)`.
split: Split name to use, will be passed to as_dataset(). To read different
data chunks on different replicas pass a callable that accepts the host_id
and host_count and returns a split name.
global_batch_size: Global batch size for all input pipelines together.
strategy: Distribution strategy for distributing the dataset.
rng: A tf.Tensor with a stateless random key to seed shuffle operations and
preprocessing ops.
filter_fn: Optional function to filter the decoded examples. This happens
before the preprocessing.
preprocess_fn: Function for preprocessing individual examples (which should
be a Python dictionary of tensors).
decoders: Optional dictionary of decoders passed to as_dataset.
cache: Whether to cache the unprocessed dataset in memory.
num_epochs: Number of epochs for which to repeat the dataset. None to repeat
forever.
shuffle: Whether to shuffle the dataset (both on the file and example
level).
shuffle_buffer_size: Number of examples in the shuffle buffer.
prefetch_size: The number of elements in the final dataset to prefetch in
the background. This should be a small (say <10) positive integer or
tf.data.experimental.AUTOTUNE.
pad_up_to_batches: Set this option to process the entire dataset. When set,
the dataset is first padded to the specified number of batches. A new
feature called "mask" is added to every batch. This feature is set to
`True` for every example that comes from `dataset_builder`, and to `False`
for every example that is padded to get to the specified number of
batches. Note that the specified `dataset_builder` and `split` must
provide at least `pad_up_to_batches` (possibly partial) batches.
cardinality: Number of examples in the dataset. Only needed when
`pad_up_to_batches` is specified and the cardinality cannot be retrieved
via `ds.cardinality()` (e.g. because of `ds.filter()`).
drop_remainder: Whether to drop remainders when batching.
Returns:
The dataset with preprocessed and batched examples.
"""
def dataset_fn(input_context: tf.distribute.InputContext):
"""Returns the dataset for a single worker."""
logging.info("dataset_fn(input_context=%s)", input_context)
if rng is None:
local_rng = None
else:
local_rng = tf.random.experimental.stateless_fold_in(
rng, input_context.input_pipeline_id)
if callable(split):
local_split = split(input_context.input_pipeline_id,
input_context.num_input_pipelines)
else:
local_split = split
per_replica_batch_size = input_context.get_per_replica_batch_size(
global_batch_size)
return create_dataset(
dataset_builder=dataset_builder,
split=local_split,
batch_dims=[per_replica_batch_size],
rng=local_rng,
filter_fn=filter_fn,
preprocess_fn=preprocess_fn,
decoders=decoders,
cache=cache,
num_epochs=num_epochs,
shuffle=shuffle,
shuffle_buffer_size=shuffle_buffer_size,
prefetch_size=prefetch_size,
pad_up_to_batches=pad_up_to_batches,
cardinality=cardinality,
drop_remainder=drop_remainder)
return strategy.distribute_datasets_from_function(dataset_fn)
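The `pad_up_to_batches` / `"mask"` behavior described in the docstring above can be sketched in plain Python. This is an illustrative stand-in only (the real implementation pads `tf.data` datasets; the helper name is hypothetical):

```python
# Illustrative sketch, not the TF implementation: pad a list of feature
# dicts up to `pad_up_to_batches` full batches and add the "mask" feature
# that marks real vs. padded examples.
def pad_examples_sketch(examples, batch_size, pad_up_to_batches):
  total = batch_size * pad_up_to_batches
  if total < len(examples):
    raise ValueError("Cannot pad to fewer examples than the dataset has.")
  # Real examples get mask=True.
  padded = [dict(ex, mask=True) for ex in examples]
  # Padded examples are all-zero features with mask=False.
  zero_example = {k: 0 for k in examples[0]}
  padded += [dict(zero_example, mask=False)
             for _ in range(total - len(examples))]
  return padded
```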
================================================
FILE: clu/deterministic_data_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the deterministic_data module."""
import dataclasses
import itertools
import math
from typing import Dict
from unittest import mock
from absl.testing import parameterized
from clu import deterministic_data
import jax
from packaging import version
import tensorflow as tf
import tensorflow_datasets as tfds
_use_split_info = version.parse("4.4.0") < version.parse(
tfds.version.__version__)
@dataclasses.dataclass
class MyDatasetBuilder:
name2len: Dict[str, int] # Number of examples per split.
def as_dataset(self, split: tfds.core.ReadInstruction, shuffle_files: bool,
read_config: tfds.ReadConfig, decoders) -> tf.data.Dataset:
del shuffle_files, read_config, decoders
if _use_split_info:
split_infos = {
k: tfds.core.SplitInfo(name=k, shard_lengths=[v], num_bytes=0)
for k, v in self.name2len.items()
}
instructions = split.to_absolute(split_infos)
else:
instructions = split.to_absolute(self.name2len)
assert len(instructions) == 1
from_ = instructions[0].from_ or 0
to = instructions[0].to or self.name2len[instructions[0].splitname]
return tf.data.Dataset.range(from_, to).map(lambda i: {"index": i})
@dataclasses.dataclass
class FakeDatasetInfo:
train_size: int = 9
test_size: int = 8
@property
def splits(self):
return {
"train": tfds.core.SplitInfo("train", [self.train_size], 0),
"test": tfds.core.SplitInfo("test", [self.test_size], 0)
}
class DeterministicDataTest(tf.test.TestCase, parameterized.TestCase):
"""Tests for deterministic_data module."""
@parameterized.parameters(
(9, 0, 1, True, "test[0:9]"),
(9, 0, 2, True, "test[0:4]"),
(9, 1, 2, True, "test[4:8]"), # Last example gets dropped.
(9, 0, 3, True, "test[0:3]"),
(9, 1, 3, True, "test[3:6]"),
(9, 2, 3, True, "test[6:9]"),
(9, 0, 1, False, "test[0:9]"),
(9, 0, 2, False, "test[0:5]"), # First host gets an extra example.
(9, 1, 2, False, "test[5:9]"),
(8, 0, 3, False, "test[0:3]"), # First 2 hosts get 1 example each.
(8, 1, 3, False, "test[3:6]"),
(8, 2, 3, False, "test[6:8]"),
)
def test_get_read_instruction_for_host_deprecated(self, num_examples: int,
host_id: int,
host_count: int,
drop_remainder: bool,
expected_spec: str):
expected = tfds.core.ReadInstruction.from_spec(expected_spec)
actual = deterministic_data.get_read_instruction_for_host(
"test",
num_examples,
host_id=host_id,
host_count=host_count,
drop_remainder=drop_remainder)
if _use_split_info:
split_infos = {
"test": tfds.core.SplitInfo(
name="test",
shard_lengths=[9],
num_bytes=0,
)}
else:
split_infos = {"test": 9}
self.assertEqual(
expected.to_absolute(split_infos), actual.to_absolute(split_infos))
@parameterized.parameters(
# host_id, host_count, drop_remainder, spec, expected_spec_for_host
# train split has 9 examples.
(0, 1, True, "train", "train[0:9]"),
(0, 2, True, "train", "train[0:4]"),
(1, 2, True, "train", "train[4:8]"), # Last example gets dropped.
(0, 3, True, "train", "train[0:3]"),
(1, 3, True, "train", "train[3:6]"),
(2, 3, True, "train", "train[6:9]"),
(0, 1, False, "train", "train[0:9]"),
(0, 2, False, "train", "train[0:5]"), # First host gets an extra example.
(1, 2, False, "train", "train[5:9]"),
# test split has 8 examples.
(0, 3, False, "test", "test[0:3]"), # First 2 hosts get 1 example each.
(1, 3, False, "test", "test[3:6]"),
(2, 3, False, "test", "test[6:8]"),
# Subsplits.
(0, 2, True, "train[:50%]", "train[0:2]"),
(1, 2, True, "train[:50%]", "train[2:4]"),
(0, 2, True, "train[3:7]", "train[3:5]"),
(1, 2, True, "train[3:7]", "train[5:7]"),
(0, 2, True, "train[3:8]", "train[3:5]"), # Last example gets dropped.
(1, 2, True, "train[3:8]", "train[5:7]"),
# 2 splits.
(0, 2, True, "train[3:7]+test", "train[3:5]+test[0:4]"),
(1, 2, True, "train[3:7]+test", "train[5:7]+test[4:8]"),
# First host gets an extra example.
(0, 2, False, "train[3:8]+test[:5]", "train[3:6]+test[0:3]"),
(1, 2, False, "train[3:8]+test[:5]", "train[6:8]+test[3:5]"),
)
def test_get_read_instruction_for_host(self, host_id: int, host_count: int,
drop_remainder: bool, spec: str,
expected_spec_for_host: str):
actual_spec_for_host = deterministic_data.get_read_instruction_for_host(
spec,
dataset_info=FakeDatasetInfo(),
host_id=host_id,
host_count=host_count,
drop_remainder=drop_remainder)
expected_spec_for_host = tfds.core.ReadInstruction.from_spec(
expected_spec_for_host)
self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host))
@parameterized.parameters(
# host_id, host_count, balance_remainder, spec, expected_spec_for_host
# test split has 10 examples.
(0, 1, True, "test", "test[0:10]"),
(0, 1, False, "test", "test[0:10]"),
(0, 4, True, "test", "test[0:3]"),
(1, 4, True, "test", "test[3:6]"),
(2, 4, True, "test", "test[6:8]"),
(3, 4, True, "test", "test[8:10]"),
(0, 4, False, "test", "test[0:4]"),
(1, 4, False, "test", "test[4:6]"),
(2, 4, False, "test", "test[6:8]"),
(3, 4, False, "test", "test[8:10]"),
)
def test_get_read_instruction_balance_remainder(self, host_id: int,
host_count: int,
balance_remainder: bool,
spec: str,
expected_spec_for_host: str):
actual_spec_for_host = deterministic_data.get_read_instruction_for_host(
spec,
dataset_info=FakeDatasetInfo(test_size=10),
host_id=host_id,
host_count=host_count,
remainder_options=deterministic_data.RemainderOptions
.BALANCE_ON_PROCESSES if balance_remainder else
deterministic_data.RemainderOptions.ON_FIRST_PROCESS)
expected_spec_for_host = tfds.core.ReadInstruction.from_spec(
expected_spec_for_host)
self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host))
@parameterized.parameters(
(0, 0), # No hosts.
(1, 1),  # Only one host (host_id is zero-based).
(-1, 1), # Negative host_id.
(5, 2), # host_id bigger than number of hosts.
)
def test_get_read_instruction_for_host_fails(self, host_id: int,
host_count: int):
with self.assertRaises(ValueError):
deterministic_data.get_read_instruction_for_host(
"test", 11, host_id=host_id, host_count=host_count)
def test_preprocess_with_per_example_rng(self):
def preprocess_fn(features):
features["b"] = tf.random.stateless_uniform([], features["rng"])
return features
rng = jax.random.PRNGKey(42)
ds_in = tf.data.Dataset.from_tensor_slices({"a": [37.2, 31.2, 39.0]})
ds_out = deterministic_data._preprocess_with_per_example_rng(
ds_in, preprocess_fn, rng=rng)
self.assertAllClose([
{
"a": 37.2,
"b": 0.79542184
},
{
"a": 31.2,
"b": 0.45482683
},
{
"a": 39.0,
"b": 0.85335636
},
], list(ds_out))
@parameterized.parameters(*itertools.product([2, "auto"], [True, False]))
def test_create_dataset_padding(self, pad_up_to_batches, cardinality):
dataset_builder = mock.Mock()
dataset = tf.data.Dataset.from_tensor_slices(
dict(x=tf.ones((12, 10)), y=tf.ones(12)))
dataset_builder.as_dataset.return_value = dataset
batch_dims = (2, 5)
ds = deterministic_data.create_dataset(
dataset_builder,
split="(ignored)",
batch_dims=batch_dims,
num_epochs=1,
shuffle=False,
pad_up_to_batches=pad_up_to_batches,
cardinality=12 if cardinality else None,
)
ds_iter = iter(ds)
self.assertAllClose(
dict(
x=tf.ones((2, 5, 10)),
y=tf.ones((2, 5)),
mask=tf.ones((2, 5), bool),
), next(ds_iter))
self.assertAllClose(
dict(
x=tf.reshape(
tf.concat([tf.ones(
(2, 10)), tf.zeros((8, 10))], axis=0), (2, 5, 10)),
y=tf.reshape(tf.concat([tf.ones(2), tf.zeros(8)], axis=0), (2, 5)),
mask=tf.reshape(
tf.concat(
[tf.ones(2, bool), tf.zeros(8, bool)], axis=0), (2, 5)),
), next(ds_iter))
with self.assertRaises(StopIteration):
next(ds_iter)
def test_create_dataset_padding_raises_error_cardinality(self):
dataset_builder = mock.Mock()
dataset = tf.data.Dataset.from_tensor_slices(
dict(x=tf.ones((12, 10)), y=tf.ones(12)))
dataset = dataset.filter(lambda x: True)
dataset_builder.as_dataset.return_value = dataset
batch_dims = (2, 5)
with self.assertRaisesRegex(
ValueError,
r"^Cannot determine dataset cardinality."):
deterministic_data.create_dataset(
dataset_builder,
split="(ignored)",
batch_dims=batch_dims,
num_epochs=1,
shuffle=False,
pad_up_to_batches=2,
cardinality=None,
)
def test_pad_dataset(self):
dataset = tf.data.Dataset.from_tensor_slices(
dict(x=tf.ones((12, 10)), y=tf.ones(12)))
padded_dataset = deterministic_data.pad_dataset(
dataset, batch_dims=[20], pad_up_to_batches=2, cardinality=12)
self.assertAllClose(
dict(
x=tf.concat([tf.ones(
(12, 10)), tf.zeros((8, 10))], axis=0),
y=tf.concat([tf.ones(12), tf.zeros(8)], axis=0),
mask=tf.concat(
[tf.ones(12, bool), tf.zeros(8, bool)], axis=0)),
next(iter(padded_dataset.batch(20))))
def test_pad_nested_dataset(self):
dataset = tf.data.Dataset.from_tensor_slices(
{"x": {"z": (tf.ones((12, 10)), tf.ones(12))},
"y": tf.ones((12, 4))})
def expected(*dims):
return tf.concat([tf.ones((12,) + dims), tf.zeros((8,) + dims)], axis=0)
padded_dataset = deterministic_data.pad_dataset(
dataset, batch_dims=[20], pad_up_to_batches=2, cardinality=12)
self.assertAllClose(
{"x": {"z": (expected(10), expected())},
"y": expected(4),
"mask": tf.concat([tf.ones(12, bool), tf.zeros(8, bool)], axis=0)},
next(iter(padded_dataset.batch(20))))
@parameterized.parameters(*itertools.product(range(20), range(1, 4)))
def test_same_cardinality_on_all_hosts(self, num_examples: int,
host_count: int):
builder = MyDatasetBuilder({"train": num_examples})
cardinalities = []
for host_id in range(host_count):
split = deterministic_data.get_read_instruction_for_host(
split="train",
num_examples=num_examples,
host_id=host_id,
host_count=host_count,
drop_remainder=True)
ds = deterministic_data.create_dataset(
builder, split=split, batch_dims=[2], shuffle=False, num_epochs=1)
cardinalities.append(ds.cardinality().numpy().item())
self.assertLen(set(cardinalities), 1)
@parameterized.parameters(*itertools.product(range(20), range(1, 4)))
def test_same_cardinality_on_all_hosts_with_pad(self, num_examples: int,
host_count: int):
builder = MyDatasetBuilder({"train": num_examples})
# All hosts should have the same number of batches.
batch_size = 2
pad_up_to_batches = int(math.ceil(num_examples / (batch_size * host_count)))
assert pad_up_to_batches * batch_size * host_count >= num_examples
cardinalities = []
for host_id in range(host_count):
split = deterministic_data.get_read_instruction_for_host(
split="train",
num_examples=num_examples,
host_id=host_id,
host_count=host_count,
drop_remainder=False)
ds = deterministic_data.create_dataset(
builder,
split=split,
batch_dims=[batch_size],
shuffle=False,
num_epochs=1,
pad_up_to_batches=pad_up_to_batches)
cardinalities.append(ds.cardinality().numpy().item())
self.assertLen(set(cardinalities), 1)
if __name__ == "__main__":
tf.test.main()
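The per-host ranges asserted in the parameterized tests above follow a simple arithmetic pattern. A stand-alone sketch of that arithmetic for the balanced-remainder case (the helper name is illustrative, not part of CLU):

```python
# Illustrative sketch of the per-host split arithmetic the tests verify:
# with drop_remainder, every host gets floor(n / host_count) examples;
# otherwise the first `remainder` hosts get one extra example each.
def host_read_range(num_examples, host_id, host_count, drop_remainder):
  if host_count < 1 or not 0 <= host_id < host_count:
    raise ValueError(f"Invalid host_id={host_id}, host_count={host_count}")
  per_host = num_examples // host_count
  if drop_remainder:
    return host_id * per_host, (host_id + 1) * per_host
  remainder = num_examples % host_count
  start = host_id * per_host + min(host_id, remainder)
  return start, start + per_host + (1 if host_id < remainder else 0)
```

For example, 9 examples on 2 hosts with `drop_remainder=True` yields `[0:4]` and `[4:8]` (the last example is dropped), matching the test cases above.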
================================================
FILE: clu/internal/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: clu/internal/utils.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Small utilities used by CLU libraries."""
import contextlib
import sys
import time
from typing import Any, List, Mapping, Tuple, Union
from absl import logging
import jax.numpy as jnp
import numpy as np
import wrapt
@contextlib.contextmanager
def log_activity(activity_name: str):
"""Logs `activity_name` and timing information (or exception)."""
t0 = time.time()
logging.info("%s ...", activity_name)
try:
yield
finally:
dt = time.time() - t0
exc, *_ = sys.exc_info()
if exc is not None:
logging.exception("%s FAILED after %.2fs with %s.", activity_name, dt,
exc.__name__)
else:
logging.info("%s finished after %.2fs.", activity_name, dt)
def logged_with(activity_name: str):
"""Returns a decorator wrapping a function with `log_activity()`."""
@wrapt.decorator
def decorator(wrapped, instance, args, kwargs):
del instance # not used
with log_activity(activity_name):
return wrapped(*args, **kwargs)
return decorator
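A minimal sketch of the pattern `log_activity` implements, with a plain list as the log sink instead of absl logging (the function name and `sink` parameter are illustrative):

```python
import contextlib
import sys
import time

@contextlib.contextmanager
def log_activity_sketch(activity_name, sink):
  # Record the start, then the success/failure outcome with elapsed time,
  # mirroring log_activity above but appending to `sink` instead of
  # calling absl logging.
  t0 = time.time()
  sink.append(f"{activity_name} ...")
  try:
    yield
  finally:
    dt = time.time() - t0
    exc, *_ = sys.exc_info()
    if exc is not None:
      sink.append(f"{activity_name} FAILED after {dt:.2f}s with "
                  f"{exc.__name__}.")
    else:
      sink.append(f"{activity_name} finished after {dt:.2f}s.")
```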
def check_param(value, *, ndim=None, dtype=jnp.float32):
"""Raises a `ValueError` if `value` does not match ndim/dtype.
Args:
value: Value to be tested.
ndim: Expected dimensions.
dtype: Expected dtype.
Raises:
A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value`
is not an instance of `jnp.ndarray`.
"""
if not isinstance(value, (np.ndarray, jnp.ndarray)):
raise ValueError(f"Expected np.array or jnp.array, got type={type(value)}")
if ndim is not None and value.ndim != ndim:
raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}")
if dtype is not None and value.dtype != dtype:
raise ValueError(f"Expected dtype={dtype}, got dtype={value.dtype}")
def flatten_dict(
d: Mapping[str, Any], prefix: Tuple[str, ...] = ()
) -> List[Tuple[str, Union[int, float, str]]]:
"""Returns a sequence of flattened (k, v) pairs for tfsummary.hparams().
Args:
d: A dict-like object that has an `.item()` method.
prefix: Prefix to add to keys in `d`.
Returns:
Sequence of (k, v) pairs where k is the flattened key with individual
subkeys separated by dots. `None` values are replaced by the empty string.
"""
ret = []
for k, v in d.items():
# Note `ml_collections.ConfigDict` is not (yet) a `Mapping`.
if isinstance(v, Mapping) or hasattr(v, "items"):
ret += flatten_dict(v, prefix + (k,))
elif isinstance(v, (list, tuple)):
ret += flatten_dict({str(idx): value for idx, value in enumerate(v)},
prefix + (k,))
else:
ret.append((".".join(prefix + (k,)), v if v is not None else ""))
return ret
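For illustration, here is the flattening behavior documented above applied to plain nested dicts. This is a minimal stand-alone reimplementation (the real `flatten_dict` also accepts `ml_collections.ConfigDict`):

```python
# Illustrative stand-alone version of flatten_dict for plain dicts,
# showing dotted-key flattening, list/tuple index expansion, and the
# None -> "" replacement.
def flatten_dict_sketch(d, prefix=()):
  ret = []
  for k, v in d.items():
    if hasattr(v, "items"):
      ret += flatten_dict_sketch(v, prefix + (k,))
    elif isinstance(v, (list, tuple)):
      ret += flatten_dict_sketch(
          {str(i): x for i, x in enumerate(v)}, prefix + (k,))
    else:
      ret.append((".".join(prefix + (k,)), v if v is not None else ""))
  return ret
```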
================================================
FILE: clu/internal/utils_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from absl.testing import absltest
from clu.internal import utils
import jax.numpy as jnp
import ml_collections
class TestError(BaseException):
__test__ = False
pass
class HelpersTest(absltest.TestCase):
def test_log_activity(
self,
):
with self.assertLogs() as logs:
with utils.log_activity("test_activity"):
pass
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(logs.output[1],
r"^INFO:absl:test_activity finished after \d+.\d\ds.$")
def test_log_activity_fails(
self,
):
with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long
with self.assertLogs() as logs:
with utils.log_activity("test_activity"):
raise TestError()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(logs.output[1],
r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")
def test_logged_with(self):
@utils.logged_with("test_activity")
def test():
pass
with self.assertLogs() as logs:
test()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(logs.output[1],
r"^INFO:absl:test_activity finished after \d+.\d\ds.$")
def test_logged_with_fails(self):
@utils.logged_with("test_activity")
def test():
raise TestError()
with self.assertRaises(TestError): # pylint: disable=g-error-prone-assert-raises, line-too-long
with self.assertLogs() as logs:
test()
self.assertLen(logs.output, 2)
self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
self.assertRegex(logs.output[1],
r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")
def test_check_param(self):
a = jnp.array(0.)
with self.assertRaisesRegex(ValueError, r"^Expected np.array or jnp.array"):
utils.check_param(None, ndim=1)
with self.assertRaisesRegex(ValueError, r"^Expected ndim"):
utils.check_param(a, ndim=1)
with self.assertRaisesRegex(ValueError, r"^Expected dtype"):
utils.check_param(a, ndim=0, dtype=jnp.int32)
utils.check_param(a, ndim=0) # should work
utils.check_param(a, ndim=0, dtype=jnp.float32) # should also work
def test_flatten_dict(self):
self.assertEqual(
utils.flatten_dict(
ml_collections.ConfigDict({
"x": 1,
"y": None,
"z": ml_collections.ConfigDict({
"a": "bc",
})
})), [("x", 1), ("y", ""), ("z.a", "bc")])
if __name__ == "__main__":
absltest.main()
================================================
FILE: clu/metric_writers/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Metric writers write ML model outputs during model training and evaluation.
This module introduces the MetricWriter interface. MetricWriters allow users
to write out metrics about ML models during training and evaluation (e.g. loss,
accuracy).
There is a MetricWriter implementation for each back end (e.g. TensorFlow
summaries) and classes that work on top of other MetricWriters to
write to multiple writers at once or to write asynchronously.
Note: The current interface might not contain write() methods for all possible
data types. We are open to extending the interface to other data types
(e.g. audio).
Usage:
writer = MyMetricWriterImplementation()
# Before training.
writer.write_hparams({"learning_rate": 0.001, "batch_size": 64})
# Start training loop.
for step in range(num_train_steps):
loss = train_step()
if step % 50 == 0:
writer.write_scalars(step, {"loss": loss})
accuracy = evaluate()
writer.write_scalars(step, {"accuracy": accuracy})
# Make sure all values were written.
writer.flush() # or use metric_writers.ensure_flushes() context.
"""
# pylint: disable=unused-import
# pylint: disable=g-importing-member
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.async_writer import AsyncWriter
from clu.metric_writers.async_writer import ensure_flushes
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
from clu.metric_writers.utils import create_default_writer
from clu.metric_writers.utils import write_values
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
================================================
FILE: clu/metric_writers/async_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter that writes metrics in a separate thread.
- The order of the write calls is preserved.
- Users need to call `flush()` or use the `ensure_flushes()` context to make sure
that all metrics have been written.
- Errors while writing in the background thread will be re-raised in the main
thread on the next write_*() call.
"""
from collections.abc import Mapping, Sequence
import contextlib
from typing import Any, Optional
from clu import asynclib
from clu.metric_writers import interface
from clu.metric_writers import multi_writer
import wrapt
Array = interface.Array
Scalar = interface.Scalar
@wrapt.decorator
def _wrap_exceptions(wrapped, instance, args, kwargs):
del instance
try:
return wrapped(*args, **kwargs)
except asynclib.AsyncError as e:
raise asynclib.AsyncError(
"Consider re-running the code without AsyncWriter (e.g. creating a "
"writer using "
"`clu.metric_writers.create_default_writer(asynchronous=False)`)"
) from e
class AsyncWriter(interface.MetricWriter):
"""MetricWriter that performs write operations in a separate thread.
All write operations will be executed in a background thread. If an exception
occurs in the background thread it will be raised on the main thread on the
next call of one of the write_* methods.
Use num_workers > 1 at your own risk: if the underlying writer is not
thread-safe or does not expect out-of-order events, this can cause problems.
If num_workers is None then the ThreadPool will use `os.cpu_count()`
threads.
"""
def __init__(self,
writer: interface.MetricWriter,
*,
num_workers: Optional[int] = 1):
super().__init__()
self._writer = writer
# By default, we have a thread pool with a single worker to ensure that
# calls to the function are run in order (but in a background thread).
self._num_workers = num_workers
self._pool = asynclib.Pool(
thread_name_prefix="AsyncWriter", max_workers=num_workers)
@_wrap_exceptions
def write_summaries(
self, step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None):
self._pool(self._writer.write_summaries)(
step=step, values=values, metadata=metadata)
@_wrap_exceptions
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
self._pool(self._writer.write_scalars)(step=step, scalars=scalars)
@_wrap_exceptions
def write_images(self, step: int, images: Mapping[str, Array]):
self._pool(self._writer.write_images)(step=step, images=images)
@_wrap_exceptions
def write_videos(self, step: int, videos: Mapping[str, Array]):
self._pool(self._writer.write_videos)(step=step, videos=videos)
@_wrap_exceptions
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
self._pool(self._writer.write_audios)(
step=step, audios=audios, sample_rate=sample_rate)
@_wrap_exceptions
def write_texts(self, step: int, texts: Mapping[str, str]):
self._pool(self._writer.write_texts)(step=step, texts=texts)
@_wrap_exceptions
def write_histograms(self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None):
self._pool(self._writer.write_histograms)(
step=step, arrays=arrays, num_buckets=num_buckets)
@_wrap_exceptions
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
self._pool(self._writer.write_pointcloud)(
step=step,
point_clouds=point_clouds,
point_colors=point_colors,
configs=configs,
)
@_wrap_exceptions
def write_hparams(self, hparams: Mapping[str, Any]):
self._pool(self._writer.write_hparams)(hparams=hparams)
def flush(self):
try:
self._pool.join()
finally:
self._writer.flush()
def close(self):
try:
self.flush()
finally:
self._writer.close()
class AsyncMultiWriter(multi_writer.MultiWriter):
"""AsyncMultiWriter writes to multiple writers, each in its own background thread."""
def __init__(self,
writers: Sequence[interface.MetricWriter],
*,
num_workers: Optional[int] = 1):
super().__init__([AsyncWriter(w, num_workers=num_workers) for w in writers])
@contextlib.contextmanager
def ensure_flushes(*writers: interface.MetricWriter):
"""Context manager which ensures that one or more writers are flushed."""
try:
# The caller should not need to use the yielded value, but we yield
# the first writer to stay backwards compatible for a single writer.
yield writers[0]
finally:
for writer in writers:
writer.flush()
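The ordering guarantee stated in the module docstring above comes from using a single-worker pool: calls are queued and executed FIFO in one background thread. A stdlib-only sketch of the same pattern (the class name is illustrative; AsyncWriter itself uses `asynclib.Pool`):

```python
from concurrent.futures import ThreadPoolExecutor

class OrderedAsyncSink:
  """Queues writes on a single background thread, preserving call order."""

  def __init__(self):
    # max_workers=1 ensures submitted calls run one at a time, in FIFO
    # order, like AsyncWriter's default single-worker pool.
    self._pool = ThreadPoolExecutor(max_workers=1)
    self.written = []

  def write(self, step, value):
    self._pool.submit(self.written.append, (step, value))

  def flush(self):
    # shutdown(wait=True) blocks until all queued writes have completed.
    # (A real flush keeps the pool usable; this sketch is one-shot.)
    self._pool.shutdown(wait=True)
```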
================================================
FILE: clu/metric_writers/async_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for AsyncWriter."""
import time
from unittest import mock
from clu import asynclib
from clu.metric_writers import async_writer
from clu.metric_writers import interface
import numpy as np
import tensorflow as tf
class AsyncWriterTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.sync_writer = mock.create_autospec(interface.MetricWriter)
self.writer = async_writer.AsyncWriter(self.sync_writer)
def test_write_summaries_async(self):
self.writer.write_summaries(
11,
{"a": np.eye(3, dtype=np.uint8),
"b": np.eye(2, dtype=np.float32)},
{"a": np.ones((2, 3)).tobytes()})
self.writer.flush()
self.sync_writer.write_summaries.assert_called_with(
step=11,
values={"a": mock.ANY, "b": mock.ANY},
metadata={"a": mock.ANY})
def test_write_scalars_async(self):
self.writer.write_scalars(0, {"a": 3, "b": 0.15})
self.writer.write_scalars(2, {"a": 5, "b": 0.007})
self.writer.flush()
self.sync_writer.write_scalars.assert_has_calls([
mock.call(step=0, scalars={
"a": 3,
"b": 0.15
}),
mock.call(step=2, scalars={
"a": 5,
"b": 0.007
})
])
def test_write_images(self):
images = np.zeros((2, 28, 28, 3))
self.writer.write_images(4, {"input_images": images})
self.writer.flush()
self.sync_writer.write_images.assert_called_with(4,
{"input_images": mock.ANY})
def test_write_videos(self):
videos = np.zeros((2, 4, 28, 28, 3))
self.writer.write_videos(4, {"input_videos": videos})
self.writer.flush()
self.sync_writer.write_videos.assert_called_with(4,
{"input_videos": mock.ANY})
def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
self.writer.write_pointcloud(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
self.writer.flush()
self.sync_writer.write_pointcloud.assert_called_with(
step=0,
point_clouds={"pcd": mock.ANY},
point_colors={"pcd": mock.ANY},
configs={"config": mock.ANY},
)
def test_write_texts(self):
self.writer.write_texts(4, {"samples": "bla"})
self.writer.flush()
self.sync_writer.write_texts.assert_called_with(4, {"samples": "bla"})
def test_ensure_flushes(self):
with async_writer.ensure_flushes(self.writer) as writer:
writer.write_scalars(0, {"a": 3, "b": 0.15})
writer.write_scalars(2, {"a": 5, "b": 0.007})
self.sync_writer.write_scalars.assert_has_calls([
mock.call(step=0, scalars={
"a": 3,
"b": 0.15
}),
mock.call(step=2, scalars={
"a": 5,
"b": 0.007
})
])
self.sync_writer.flush.assert_called_once()
def test_ensure_flushes_with_multiple_writers(self):
sync_writer1 = mock.create_autospec(interface.MetricWriter)
writer1 = async_writer.AsyncWriter(sync_writer1)
sync_writer2 = mock.create_autospec(interface.MetricWriter)
writer2 = async_writer.AsyncWriter(sync_writer2)
with async_writer.ensure_flushes(writer1, writer2):
writer1.write_scalars(0, {"a": 3, "b": 0.15})
writer2.write_scalars(2, {"a": 5, "b": 0.007})
sync_writer1.write_scalars.assert_has_calls(
[mock.call(step=0, scalars={
"a": 3,
"b": 0.15
})])
sync_writer2.write_scalars.assert_has_calls(
[mock.call(step=2, scalars={
"a": 5,
"b": 0.007
})])
sync_writer1.flush.assert_called_once()
sync_writer2.flush.assert_called_once()
def test_flush_before_close(self):
self.writer.close()
self.sync_writer.flush.assert_called()
self.sync_writer.close.assert_called()
def test_reraises_exception(self):
self.sync_writer.write_scalars.side_effect = ValueError("foo")
self.writer.write_scalars(0, {"a": 3, "b": 0.15})
time.sleep(0.1)
with self.assertRaisesRegex(asynclib.AsyncError, "Consider re-running"):
self.writer.write_scalars(2, {"a": 5, "b": 0.007})
if __name__ == "__main__":
tf.test.main()
================================================
FILE: clu/metric_writers/interface.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for unified reporting of model metrics across various logging formats.
This library provides a MetricWriter for each logging format (SummaryWriter,
LoggingWriter, etc.) and composing MetricWriters to add support for asynchronous
logging or writing to multiple formats.
"""
import abc
from collections.abc import Mapping
from typing import Any, Optional, Union
import jax.numpy as jnp
import numpy as np
Array = Union[np.ndarray, jnp.ndarray]
Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray]
class MetricWriter(abc.ABC):
"""MetricWriter interface."""
@abc.abstractmethod
def write_summaries(
self, step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None):
"""Saves an arbitrary tensor summary.
Useful when working with custom plugins or constructing a summary directly.
Args:
step: Step at which the values occurred.
values: Mapping from tensor keys to tensors.
metadata: Optional SummaryMetadata, as a proto or serialized bytes.
Note that markdown formatting is rendered by tensorboard.
"""
@abc.abstractmethod
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
"""Write scalar values for the step.
Consecutive calls to this method can provide different sets of scalars.
Repeated writes for the same metric at the same step are not allowed.
Args:
step: Step at which the scalar values occurred.
scalars: Mapping from metric name to value.
"""
@abc.abstractmethod
def write_images(self, step: int, images: Mapping[str, Array]):
"""Write images for the step.
Consecutive calls to this method can provide different sets of images.
Repeated writes for the same image key at the same step are not allowed.
Warning: Not all MetricWriter implementations support writing images!
Args:
step: Step at which the images occurred.
images: Mapping from image key to images. Images should have the shape [N,
H, W, C] or [H, W, C], where H is the height, W is the width and C the
number of channels (1 or 3). N is the number of images that will be
written. Image dimensions can differ between different image keys but
not between different steps for the same image key.
"""
@abc.abstractmethod
def write_videos(self, step: int, videos: Mapping[str, Array]):
"""Write videos for the step.
Warning: Logging only.
Not all MetricWriter implementations support writing videos!
Consecutive calls to this method can provide different sets of videos.
Repeated writes for the same video key at the same step are not allowed.
Args:
step: Step at which the videos occurred.
videos: Mapping from video key to videos. videos should have the shape
[N, T, H, W, C] or [T, H, W, C], where T is time, H is the height,
W is the width and C the number of channels (1 or 3). N is the number
of videos that will be written. Video dimensions can differ between
different video keys but not between different steps for the same
video key.
"""
@abc.abstractmethod
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
"""Write audios for the step.
Consecutive calls to this method can provide different sets of audios.
Repeated writes for the same audio key at the same step are not allowed.
Warning: Not all MetricWriter implementations support writing audios!
Args:
step: Step at which the audios occurred.
audios: Mapping from audio key to audios. Audios should have the shape
[N, T, C], where T is the time length and C the number of channels
(1 - mono, 2 - stereo, >= 3 - surround; not all writers support any
number of channels). N is the number of audios that will be written.
Audio dimensions can differ between different audio keys but not between
different steps for the same audio key. Values should be floating-point
values in [-1, +1].
sample_rate: Sample rate for the audios.
"""
@abc.abstractmethod
def write_texts(self, step: int, texts: Mapping[str, str]):
"""Writes text snippets for the step.
Warning: Not all MetricWriter implementations support writing text!
Args:
step: Step at which the text snippets occurred.
texts: Mapping from name to text snippet.
"""
@abc.abstractmethod
def write_histograms(self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None):
"""Writes histograms for the step.
Consecutive calls to this method can provide different sets of arrays.
Repeated writes for the same metric at the same step are not allowed.
Warning: Not all MetricWriter implementations support writing histograms!
Args:
step: Step at which the arrays were generated.
arrays: Mapping from name to arrays to summarize.
num_buckets: Number of buckets used to create the histogram of the arrays.
The default number of buckets depends on the particular implementation
of the MetricWriter.
"""
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
"""Writes point cloud summaries.
Args:
step: Step at which the point cloud was generated.
point_clouds: Mapping from point clouds key to point cloud of shape [N, 3]
array of point coordinates.
point_colors: Mapping from point colors key to [N, 3] array of point
colors.
configs: A dictionary of configuration options for the point cloud.
"""
raise NotImplementedError()
@abc.abstractmethod
def write_hparams(self, hparams: Mapping[str, Any]):
"""Write hyperparameters.
Do not call twice.
Args:
hparams: Flat mapping from hyperparameter name to value.
"""
@abc.abstractmethod
def flush(self):
"""Tells the MetricWriter to write out any cached values."""
@abc.abstractmethod
def close(self):
"""Flushes and closes the MetricWriter.
Calling any method on MetricWriter after MetricWriter.close()
is undefined behavior.
"""
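The contract stated above for write_scalars (repeated writes for the same metric at the same step are not allowed) can be illustrated with a minimal standalone sketch. This in-memory writer is hypothetical and is not part of clu; it mirrors only the scalar portion of the MetricWriter interface rather than subclassing the actual ABC.

```python
from collections.abc import Mapping


class InMemoryScalarWriter:
    """Illustrative stand-in mirroring the MetricWriter scalar contract:
    at most one write per metric name per step."""

    def __init__(self):
        self._scalars = {}  # step -> {metric name: value}

    def write_scalars(self, step: int, scalars: Mapping[str, float]):
        seen = self._scalars.setdefault(step, {})
        for name, value in scalars.items():
            if name in seen:
                raise ValueError(f"duplicate write of {name!r} at step {step}")
            seen[name] = value

    def flush(self):
        pass  # Nothing is buffered in this sketch.


writer = InMemoryScalarWriter()
writer.write_scalars(0, {"loss": 0.9})
writer.write_scalars(1, {"loss": 0.7, "accuracy": 0.58})
writer.flush()
print(writer._scalars[1]["accuracy"])  # 0.58
```

A second `write_scalars(1, {"loss": ...})` call would raise, which is the behavior the interface docstring forbids.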
================================================
FILE: clu/metric_writers/logging_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter that writes all values to INFO log."""
from collections.abc import Mapping
from typing import Any, Optional
from absl import logging
from clu.metric_writers import interface
import numpy as np
Array = interface.Array
Scalar = interface.Scalar
class LoggingWriter(interface.MetricWriter):
"""MetricWriter that writes all values to INFO log."""
def __init__(self, collection: Optional[str] = None):
if collection:
self._collection_str = f" collection={collection}"
else:
self._collection_str = ""
def write_summaries(
self, step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None):
logging.info("[%d]%s Got raw tensors: %s.", step, self._collection_str,
{k: v.shape for k, v in values.items()})
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
values = [
f"{k}={v:.6g}" if isinstance(v, float) else f"{k}={v}"
for k, v in sorted(scalars.items())
]
logging.info("[%d]%s %s", step, self._collection_str, ", ".join(values))
def write_images(self, step: int, images: Mapping[str, Array]):
logging.info("[%d]%s Got images: %s.", step, self._collection_str,
{k: v.shape for k, v in images.items()})
def write_videos(self, step: int, videos: Mapping[str, Array]):
logging.info("[%d]%s Got videos: %s.", step, self._collection_str,
{k: v.shape for k, v in videos.items()})
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
logging.info("[%d]%s Got audios: %s.", step, self._collection_str,
{k: v.shape for k, v in audios.items()})
def write_texts(self, step: int, texts: Mapping[str, str]):
logging.info("[%d]%s Got texts: %s.", step, self._collection_str, texts)
def write_histograms(self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None):
num_buckets = num_buckets or {}
for key, value in arrays.items():
histo, bins = _compute_histogram_as_tf(
np.asarray(value), num_buckets=num_buckets.get(key))
if histo is not None:
logging.info("[%d]%s Histogram for %r = {%s}", step,
self._collection_str, key,
_get_histogram_as_string(histo, bins))
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Any] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
logging.info(
"[%d]%s Got point clouds: %s, point_colors: %s, configs: %s.",
step,
self._collection_str,
{k: v.shape for k, v in point_clouds.items()},
(
{k: v.shape for k, v in point_colors.items()}
if point_colors is not None
else None
),
configs,
)
def write_hparams(self, hparams: Mapping[str, Any]):
logging.info("[Hyperparameters]%s %s", self._collection_str, hparams)
def flush(self):
logging.flush()
def close(self):
self.flush()
def _compute_histogram_as_tf(
array: np.ndarray,
num_buckets: Optional[int] = None
) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
"""Compute the histogram of the input array as TF would do.
Args:
array: Input data. The histogram is computed over the flattened array.
num_buckets: The number of equal-width bins used to create the histogram.
Returns:
histo: A numpy array with the values of the histogram.
bins: A numpy array with the bin edges (its length is length(histo)+1).
If the histogram cannot be built because the array is empty, returns
(None, None).
"""
# See DEFAULT_BUCKET_COUNT in tensorboard/plugins/histogram/summary_v2.py
num_buckets = num_buckets or 30
if num_buckets < 2:
logging.log_first_n(logging.WARNING,
"num_buckets was automatically changed from %d to 2", 1,
num_buckets)
num_buckets = 2
if array.size == 0:
return None, None
range_max = np.max(array)
range_min = np.min(array)
if np.isclose(range_max, range_min, rtol=1e-5, atol=1e-8):
histo = np.asarray([array.size], dtype=np.int64)
bins = np.asarray([range_max - 0.5, range_max + 0.5], dtype=np.float64)
else:
histo, bins = np.histogram(
array, bins=num_buckets, range=(range_min, range_max))
bins = np.asarray(bins, dtype=np.float64)
return histo, bins
def _get_histogram_as_string(histo: np.ndarray, bins: np.ndarray):
# First items are right-open (i.e. [a, b)).
items = [
f"[{bins[i]:.3g}, {bins[i+1]:.3g}): {count}"
for i, count in enumerate(histo[:-1])
]
# Last item is right-closed (i.e. [a, b]).
items.append(f"[{bins[-2]:.3g}, {bins[-1]:.3g}]: {histo[-1]}")
return ", ".join(items)
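_compute_histogram_as_tf above collapses a (near-)constant array into a single unit-width bucket centered on the value. That branch can be reproduced with NumPy alone; this is a sketch of the logic, not an import of the helper:

```python
import numpy as np


def degenerate_histogram(array):
    # Mirrors the near-constant branch of _compute_histogram_as_tf: a single
    # bucket of width 1.0 centered on the value, containing every element.
    array = np.asarray(array)
    range_min, range_max = np.min(array), np.max(array)
    assert np.isclose(range_max, range_min, rtol=1e-5, atol=1e-8)
    histo = np.asarray([array.size], dtype=np.int64)
    bins = np.asarray([range_max - 0.5, range_max + 0.5], dtype=np.float64)
    return histo, bins


histo, bins = degenerate_histogram([0.1] * 5)
print(histo, bins)  # one bucket covering [-0.4, 0.6] with count 5
```

This matches the `"Histogram for 'c' = {[-0.4, 0.6]: 5}"` expectation in the LoggingWriter test below.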
================================================
FILE: clu/metric_writers/logging_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the LoggingWriter."""
from clu.metric_writers import logging_writer
import numpy as np
import tensorflow as tf
class LoggingWriterTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.writer = logging_writer.LoggingWriter()
def test_write_scalars(self):
with self.assertLogs(level="INFO") as logs:
self.writer.write_scalars(0, {"a": 3, "b": 0.15})
self.writer.write_scalars(2, {"a": 0.0000005, "b": 0.007})
self.assertEqual(
logs.output,
["INFO:absl:[0] a=3, b=0.15", "INFO:absl:[2] a=5e-07, b=0.007"])
def test_write_images(self):
images = np.zeros((2, 28, 28, 3))
with self.assertLogs(level="INFO") as logs:
self.writer.write_images(4, {"input_images": images})
self.assertEqual(
logs.output,
["INFO:absl:[4] Got images: {'input_images': (2, 28, 28, 3)}."])
def test_write_videos(self):
videos = np.zeros((2, 4, 28, 28, 3))
with self.assertLogs(level="INFO") as logs:
self.writer.write_videos(4, {"input_videos": videos})
self.assertEqual(
logs.output,
["INFO:absl:[4] Got videos: {'input_videos': (2, 4, 28, 28, 3)}."])
def test_write_texts(self):
with self.assertLogs(level="INFO") as logs:
self.writer.write_texts(4, {"samples": "bla"})
self.assertEqual(
logs.output,
["INFO:absl:[4] Got texts: {'samples': 'bla'}."])
def test_write_histogram(self):
with self.assertLogs(level="INFO") as logs:
self.writer.write_histograms(
step=4,
arrays={
"a": np.asarray([-0.1, 0.1, 0.3]),
"b": np.arange(31),
"c": np.asarray([0.1, 0.1, 0.1, 0.1, 0.1]),
},
num_buckets={
"a": 2,
"c": 1
})
# Note: There are 31 distinct values [0, 1, ..., 30], and 30 buckets by
# default. Last bucket gets 2 values.
expected_histo_b = ", ".join([f"[{i}, {i + 1}): 1" for i in range(29)] +
["[29, 30]: 2"])
self.assertEqual(logs.output, [
"INFO:absl:[4] Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}",
f"INFO:absl:[4] Histogram for 'b' = {{{expected_histo_b}}}",
"WARNING:absl:num_buckets was automatically changed from 1 to 2",
"INFO:absl:[4] Histogram for 'c' = {[-0.4, 0.6]: 5}",
])
def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
with self.assertLogs(level="INFO") as logs:
self.writer.write_pointcloud(
step=4,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"configs": config},
)
self.assertEqual(
logs.output,
[
"INFO:absl:[4] Got point clouds: {'pcd': (1, 1024, 3)},"
" point_colors: {'pcd': (1, 1024, 3)}, configs: {'configs':"
" {'material': 'PointCloudMaterial', 'size': 0.09}}."
],
)
def test_write_hparams(self):
with self.assertLogs(level="INFO") as logs:
self.writer.write_hparams({"learning_rate": 0.1, "batch_size": 128})
self.assertEqual(logs.output, [
"INFO:absl:[Hyperparameters] {'learning_rate': 0.1, 'batch_size': 128}"
])
def test_collection(self):
writer = logging_writer.LoggingWriter(collection="train")
with self.assertLogs(level="INFO") as logs:
writer.write_scalars(0, {"a": 3, "b": 0.15})
writer.write_images(4, {"input_images": np.zeros((2, 28, 28, 3))})
writer.write_texts(4, {"samples": "bla"})
writer.write_histograms(
step=4,
arrays={
"a": np.asarray([-0.1, 0.1, 0.3]),
},
num_buckets={
"a": 2,
})
writer.write_hparams({"learning_rate": 0.1})
self.assertEqual(logs.output, [
"INFO:absl:[0] collection=train a=3, b=0.15",
"INFO:absl:[4] collection=train Got images: {'input_images': (2, 28, 28, 3)}.",
"INFO:absl:[4] collection=train Got texts: {'samples': 'bla'}.",
"INFO:absl:[4] collection=train Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}",
"INFO:absl:[Hyperparameters] collection=train {'learning_rate': 0.1}",
])
if __name__ == "__main__":
tf.test.main()
================================================
FILE: clu/metric_writers/multi_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter that writes to multiple MetricWriters."""
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from clu.metric_writers import interface
Array = interface.Array
Scalar = interface.Scalar
class MultiWriter(interface.MetricWriter):
"""MetricWriter that writes to multiple writers at once."""
def __init__(self, writers: Sequence[interface.MetricWriter]):
self._writers = tuple(writers)
def write_summaries(
self, step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None):
for w in self._writers:
w.write_summaries(step, values, metadata)
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
for w in self._writers:
w.write_scalars(step, scalars)
def write_images(self, step: int, images: Mapping[str, Array]):
for w in self._writers:
w.write_images(step, images)
def write_videos(self, step: int, videos: Mapping[str, Array]):
for w in self._writers:
w.write_videos(step, videos)
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
for w in self._writers:
w.write_audios(step, audios, sample_rate=sample_rate)
def write_texts(self, step: int, texts: Mapping[str, str]):
for w in self._writers:
w.write_texts(step, texts)
def write_histograms(self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None):
for w in self._writers:
w.write_histograms(step, arrays, num_buckets)
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
for w in self._writers:
w.write_pointcloud(
step, point_clouds, point_colors=point_colors, configs=configs
)
def write_hparams(self, hparams: Mapping[str, Any]):
for w in self._writers:
w.write_hparams(hparams)
def flush(self):
for w in self._writers:
w.flush()
def close(self):
for w in self._writers:
w.close()
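The fan-out pattern used by MultiWriter can be verified with hand-rolled recorder writers. This is a standalone sketch with hypothetical RecordingWriter/FanOutWriter classes; the real MultiWriter forwards the entire MetricWriter interface, not just scalars and flush:

```python
class RecordingWriter:
    # Minimal stand-in that records every call it receives.
    def __init__(self):
        self.calls = []

    def write_scalars(self, step, scalars):
        self.calls.append(("write_scalars", step, dict(scalars)))

    def flush(self):
        self.calls.append(("flush",))


class FanOutWriter:
    # Same forwarding shape as MultiWriter.write_scalars/flush above.
    def __init__(self, writers):
        self._writers = tuple(writers)

    def write_scalars(self, step, scalars):
        for w in self._writers:
            w.write_scalars(step, scalars)

    def flush(self):
        for w in self._writers:
            w.flush()


a, b = RecordingWriter(), RecordingWriter()
multi = FanOutWriter([a, b])
multi.write_scalars(0, {"a": 3, "b": 0.15})
multi.flush()
print(a.calls == b.calls)  # True: both writers saw identical calls
```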
================================================
FILE: clu/metric_writers/multi_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for MultiWriter."""
from unittest import mock
from clu.metric_writers import interface
from clu.metric_writers import multi_writer
import numpy as np
import tensorflow as tf
class MultiWriterTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.writers = [
mock.create_autospec(interface.MetricWriter),
mock.create_autospec(interface.MetricWriter)
]
self.writer = multi_writer.MultiWriter(self.writers)
def test_write_scalars(self):
self.writer.write_scalars(0, {"a": 3, "b": 0.15})
self.writer.write_scalars(2, {"a": 5, "b": 0.007})
self.writer.flush()
for w in self.writers:
w.write_scalars.assert_has_calls([
mock.call(step=0, scalars={
"a": 3,
"b": 0.15
}),
mock.call(step=2, scalars={
"a": 5,
"b": 0.007
})
])
w.flush.assert_called()
def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
self.writer.write_pointcloud(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
self.writer.flush()
for w in self.writers:
w.write_pointcloud.assert_called_with(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
w.flush.assert_called()
if __name__ == "__main__":
tf.test.main()
================================================
FILE: clu/metric_writers/summary_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter for writing to TF summary files."""
# pylint: disable=unused-import
from .tf.summary_writer import SummaryWriter
================================================
FILE: clu/metric_writers/tf/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Package __init__ file."""
================================================
FILE: clu/metric_writers/tf/summary_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter for writing to TF summary files.
Only works in eager mode. Does not work with PyTorch code; use
TorchTensorboardWriter instead.
"""
from collections.abc import Mapping
from typing import Any, Optional
from absl import logging
from clu.internal import utils
from clu.metric_writers import interface
from etils import epy
import tensorflow as tf
with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from tensorboard.plugins.hparams import api as hparams_api
from tensorboard.plugins.mesh import summary as mesh_summary # pylint: disable=line-too-long
# pylint: enable=g-import-not-at-top
Array = interface.Array
Scalar = interface.Scalar
class SummaryWriter(interface.MetricWriter):
"""MetricWriter that writes TF summary files."""
def __init__(self, logdir: str):
super().__init__()
self._summary_writer = tf.summary.create_file_writer(logdir)
def write_summaries(
self,
step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None,
):
with self._summary_writer.as_default():
for key, value in values.items():
md = metadata.get(key) if metadata is not None else None
tf.summary.write(key, value, step=step, metadata=md)
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
with self._summary_writer.as_default():
for key, value in scalars.items():
tf.summary.scalar(key, value, step=step)
def write_images(self, step: int, images: Mapping[str, Array]):
with self._summary_writer.as_default():
for key, value in images.items():
if len(value.shape) == 3:
value = value[None]
tf.summary.image(key, value, step=step, max_outputs=value.shape[0])
def write_videos(self, step: int, videos: Mapping[str, Array]):
logging.log_first_n(
logging.WARNING,
"SummaryWriter does not support writing videos.", 1)
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
with self._summary_writer.as_default():
for key, value in audios.items():
tf.summary.audio(key, value, sample_rate=sample_rate, step=step,
max_outputs=value.shape[0])
def write_texts(self, step: int, texts: Mapping[str, str]):
with self._summary_writer.as_default():
for key, value in texts.items():
tf.summary.text(key, value, step=step)
def write_histograms(
self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None,
):
with self._summary_writer.as_default():
for key, value in arrays.items():
buckets = None if num_buckets is None else num_buckets.get(key)
tf.summary.histogram(key, value, step=step, buckets=buckets)
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
with self._summary_writer.as_default():
for key, vertices in point_clouds.items():
colors = None if point_colors is None else point_colors.get(key)
config = None if configs is None else configs.get(key)
mesh_summary.mesh(
key,
vertices=vertices,
colors=colors,
step=step,
config_dict=config,
)
def write_hparams(self, hparams: Mapping[str, Any]):
with self._summary_writer.as_default():
hparams_api.hparams(dict(utils.flatten_dict(hparams)))
def flush(self):
self._summary_writer.flush()
def close(self):
self._summary_writer.close()
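write_images above promotes a single [H, W, C] image to a batch before calling tf.summary.image. That rank check can be exercised without TensorFlow; the NumPy-only sketch below (with a hypothetical ensure_batched helper) isolates just that branch:

```python
import numpy as np


def ensure_batched(image):
    # Mirrors the rank check in write_images: a rank-3 [H, W, C] array gains
    # a leading batch axis, so tf.summary.image always sees [N, H, W, C].
    image = np.asarray(image)
    if image.ndim == 3:
        image = image[None]
    return image


print(ensure_batched(np.zeros((28, 28, 3))).shape)     # (1, 28, 28, 3)
print(ensure_batched(np.zeros((2, 28, 28, 3))).shape)  # (2, 28, 28, 3)
```

The batch size then also serves as `max_outputs`, so every image in the batch is written.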
================================================
FILE: clu/metric_writers/tf/summary_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for SummaryWriter."""
import collections
import os
from clu.metric_writers.tf import summary_writer
import numpy as np
import tensorflow as tf
from tensorboard.plugins.hparams import plugin_data_pb2
def _load_summaries_data(logdir):
"""Loads raw summaries data from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = collections.defaultdict(dict)
metadata = collections.defaultdict(dict)
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
data[event.step][value.tag] = tf.make_ndarray(value.tensor)
if value.HasField("metadata"):
metadata[event.step][value.tag] = value.metadata.SerializeToString()
return data, metadata
def _load_histograms_data(logdir):
"""Loads tensor summaries from events in a logdir."""
# Note: new versions of histograms don't use the HistogramProto type, but
# they are written as tensors representing the bounds and counts of buckets,
# with plugin_name = "histogram".
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = {}
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
current_steps, current_tensors = data.get(value.tag, ([], []))
data[value.tag] = (current_steps + [event.step],
current_tensors + [tf.make_ndarray(value.tensor)])
return {
tag: (np.stack(steps), np.stack(tensors))
for tag, (steps, tensors) in data.items()
}
def _load_scalars_data(logdir: str):
"""Loads scalar summaries from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = collections.defaultdict(dict)
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
data[event.step][value.tag] = tf.make_ndarray(value.tensor).flat[0]
return data
def _load_pointcloud_data(logdir: str):
"""Loads pointcloud summaries from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = collections.defaultdict(dict)
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
if value.metadata.plugin_data.plugin_name == "mesh":
if "config" not in value.tag:
data[event.step][value.tag] = tf.make_ndarray(value.tensor)
else:
data[event.step][value.tag] = value.metadata.plugin_data.content
return data
def _load_hparams(logdir: str):
"""Loads hparams summaries from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
hparams = []
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
if value.metadata.plugin_data.plugin_name == "hparams":
hparams.append(plugin_data_pb2.HParamsPluginData.FromString(
value.metadata.plugin_data.content))
return hparams
class SummaryWriterTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.logdir = self.get_temp_dir()
self.writer = summary_writer.SummaryWriter(self.logdir)
def test_write_summaries(self):
self.writer.write_summaries(
11,
{"a": np.eye(3, dtype=np.uint8),
"b": np.eye(2, dtype=np.float32)},
{"a": np.ones((2, 3)).tobytes()})
self.writer.flush()
data, metadata = _load_summaries_data(self.logdir)
self.assertAllClose(
data[11],
{"a": np.eye(3, dtype=np.uint8), "b": np.eye(2, dtype=np.float32)})
self.assertIn("a", metadata[11])
def test_write_scalar(self):
self.writer.write_scalars(11, {"a": 0.6, "b": 15})
self.writer.write_scalars(20, {"a": 0.8, "b": 12})
self.writer.flush()
data = _load_scalars_data(self.logdir)
self.assertAllClose(data[11], {"a": 0.6, "b": 15})
self.assertAllClose(data[20], {"a": 0.8, "b": 12})
def test_write_histograms(self):
self.writer.write_histograms(
0, {
"a": np.asarray([0.3, 0.1, 0.5, 0.7, 0.1]),
"b": np.asarray([-0.1, 0.3, 0.2, 0.4, 0.4]),
}, num_buckets={"a": 2, "b": 2})
self.writer.write_histograms(
2, {
"a": np.asarray([0.2, 0.4, 0.5, 0.1, -0.1]),
"b": np.asarray([0.7, 0.3, 0.2, 0.1, 0.0]),
}, num_buckets={"a": 2, "b": 2})
self.writer.flush()
data = _load_histograms_data(self.logdir)
# In the histograms, each tuple represents
# (bucket_min, bucket_max, bucket_count), where bucket_min is inclusive and
# bucket_max is exclusive (except the last bucket_max which is inclusive).
expected_histograms_a = [
# Step 0.
[(0.1, 0.4, 3), (0.4, 0.7, 2)],
# Step 1.
[(-0.1, 0.2, 2), (0.2, 0.5, 3)],
]
self.assertAllClose(data["a"], ([0, 2], expected_histograms_a))
expected_histograms_b = [
# Step 0.
[(-0.1, 0.15, 1), (0.15, 0.4, 4)],
# Step 1.
[(0.0, 0.35, 4), (0.35, 0.7, 1)],
]
self.assertAllClose(data["b"], ([0, 2], expected_histograms_b))
def test_write_pointcloud(self):
point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
config = {
"material": "PointCloudMaterial",
"size": 0.09,
}
self.writer.write_pointcloud(
step=0,
point_clouds={"pcd": point_clouds},
point_colors={"pcd": point_colors},
configs={"config": config},
)
self.writer.flush()
data = _load_pointcloud_data(self.logdir)
self.assertAllClose(data[0]["pcd_VERTEX"], point_clouds)
self.assertAllClose(data[0]["pcd_COLOR"], point_colors)
def test_hparams(self):
self.writer.write_hparams(dict(batch_size=512, num_epochs=90))
hparams = _load_hparams(self.logdir)
self.assertLen(hparams, 1)
hparams_dict = hparams[0].session_start_info.hparams
self.assertLen(hparams_dict, 2)
self.assertEqual(512, hparams_dict["batch_size"].number_value)
self.assertEqual(90, hparams_dict["num_epochs"].number_value)
def test_hparams_nested(self):
config = {
"list": [1, 2],
"tuple": (3, 4),
"subconfig": {
"value": "a",
"list": [10, 20],
},
}
self.writer.write_hparams(config)
hparams = _load_hparams(self.logdir)
self.assertLen(hparams, 1)
hparams_dict = hparams[0].session_start_info.hparams
self.assertLen(hparams_dict, 7)
self.assertEqual(1, hparams_dict["list.0"].number_value)
self.assertEqual(2, hparams_dict["list.1"].number_value)
self.assertEqual(3, hparams_dict["tuple.0"].number_value)
self.assertEqual(4, hparams_dict["tuple.1"].number_value)
self.assertEqual("a", hparams_dict["subconfig.value"].string_value)
self.assertEqual(10, hparams_dict["subconfig.list.0"].number_value)
self.assertEqual(20, hparams_dict["subconfig.list.1"].number_value)
if __name__ == "__main__":
tf.test.main()
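test_hparams_nested relies on clu.internal.utils.flatten_dict producing dot-joined keys for nested mappings and sequences. The behavior implied by the assertions can be sketched as follows; this `flatten` function is an illustrative reimplementation inferred from the test, not the actual helper:

```python
def flatten(config, prefix=""):
    # Nested dicts, lists and tuples collapse into dot-joined keys,
    # matching keys like "subconfig.list.0" asserted in the test above.
    items = (
        config.items() if isinstance(config, dict) else enumerate(config)
    )
    flat = {}
    for key, value in items:
        full_key = f"{prefix}{key}"
        if isinstance(value, (dict, list, tuple)):
            flat.update(flatten(value, full_key + "."))
        else:
            flat[full_key] = value
    return flat


config = {
    "list": [1, 2],
    "tuple": (3, 4),
    "subconfig": {"value": "a", "list": [10, 20]},
}
flat = flatten(config)
print(sorted(flat))  # 7 dot-joined keys, as the test asserts
```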
================================================
FILE: clu/metric_writers/torch_tensorboard_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MetricWriter for Pytorch summary files.
Use this writer for the Pytorch-based code.
"""
from collections.abc import Mapping
from typing import Any, Optional
from absl import logging
from clu.metric_writers import interface
from torch.utils import tensorboard
Array = interface.Array
Scalar = interface.Scalar
class TorchTensorboardWriter(interface.MetricWriter):
"""MetricWriter that writes Pytorch summary files."""
def __init__(self, logdir: str):
super().__init__()
self._writer = tensorboard.SummaryWriter(log_dir=logdir)
def write_summaries(
self, step: int,
values: Mapping[str, Array],
metadata: Optional[Mapping[str, Any]] = None):
logging.log_first_n(
logging.WARNING,
"TorchTensorboardWriter does not support writing raw summaries.", 1)
def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
for key, value in scalars.items():
self._writer.add_scalar(key, value, global_step=step)
def write_images(self, step: int, images: Mapping[str, Array]):
for key, value in images.items():
self._writer.add_image(key, value, global_step=step, dataformats="HWC")
def write_videos(self, step: int, videos: Mapping[str, Array]):
logging.log_first_n(
logging.WARNING,
"TorchTensorBoardWriter does not support writing videos.", 1)
def write_audios(
self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
for key, value in audios.items():
self._writer.add_audio(
key, value, global_step=step, sample_rate=sample_rate)
def write_texts(self, step: int, texts: Mapping[str, str]):
raise NotImplementedError(
"TorchTensorBoardWriter does not support writing texts."
)
def write_histograms(self,
step: int,
arrays: Mapping[str, Array],
num_buckets: Optional[Mapping[str, int]] = None):
for tag, values in arrays.items():
bins = None if num_buckets is None else num_buckets.get(tag)
self._writer.add_histogram(
tag, values, global_step=step, bins="auto", max_bins=bins)
def write_pointcloud(
self,
step: int,
point_clouds: Mapping[str, Array],
*,
point_colors: Mapping[str, Array] | None = None,
configs: Mapping[str, str | float | bool | None] | None = None,
):
logging.log_first_n(
logging.WARNING,
"TorchTensorBoardWriter does not support writing point clouds.",
1,
)
def write_hparams(self, hparams: Mapping[str, Any]):
self._writer.add_hparams(hparams, {})
def flush(self):
self._writer.flush()
def close(self):
self._writer.close()
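Each write method above fans a mapping out into one per-tag call on the wrapped torch `SummaryWriter`. The same fan-out with an in-memory stand-in (a hypothetical `RecordingWriter`, used here so the pattern is demonstrable without torch installed):

```python
import collections

class RecordingWriter:
    """Records scalars as step -> {tag: value}, mirroring the per-key fan-out
    of TorchTensorboardWriter.write_scalars (one add_scalar call per tag)."""

    def __init__(self):
        self.records = collections.defaultdict(dict)

    def write_scalars(self, step, scalars):
        for tag, value in scalars.items():
            # Stands in for: self._writer.add_scalar(tag, value, global_step=step)
            self.records[step][tag] = value

writer = RecordingWriter()
writer.write_scalars(11, {"a": 0.6, "b": 15})
writer.write_scalars(20, {"a": 0.8, "b": 12})
```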
================================================
FILE: clu/metric_writers/torch_tensorboard_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TorchTensorboardWriter."""
import collections
import os
from typing import Any, Dict
from clu.metric_writers import torch_tensorboard_writer
import numpy as np
import tensorflow as tf
def _load_scalars_data(logdir: str):
"""Loads scalar summaries from events in a logdir."""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = collections.defaultdict(dict)
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
for value in event.summary.value:
data[event.step][value.tag] = value.simple_value
return data
def _load_histograms_data(logdir: str) -> Dict[int, Dict[str, Any]]:
"""Loads histograms summaries from events in a logdir.
Args:
logdir: a directory to find logs
Returns:
A generated histograms in a shape step -> tag -> histo.
"""
paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
data = {}
for path in paths:
for event in tf.compat.v1.train.summary_iterator(path):
if event.step not in data:
data[event.step] = {}
step_data = {}
for value in event.summary.value:
print(" value:", value)
step_data[value.tag] = value.histo
data[event.step].update(step_data)
return data
class TorchTensorboardWriterTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.logdir = self.get_temp_dir()
self.writer = torch_tensorboard_writer.TorchTensorboardWriter(self.logdir)
def test_write_scalar(self):
self.writer.write_scalars(11, {"a": 0.6, "b": 15})
self.writer.write_scalars(20, {"a": 0.8, "b": 12})
self.writer.flush()
data = _load_scalars_data(self.logdir)
self.assertAllClose(data[11], {"a": 0.6, "b": 15})
self.assertAllClose(data[20], {"a": 0.8, "b": 12})
def test_write_histograms(self):
self.writer.write_histograms(
0, {
"a": np.asarray([0.3, 0.1, 0.5, 0.7, 0.1]),
"b": np.asarray([-0.1, 0.3, 0.2, 0.4, 0.4]),
}, num_buckets={"a": 2, "b": 2})
self.writer.write_histograms(
2, {
"a": np.asarray([0.2, 0.4, 0.5, 0.1, -0.1]),
"b": np.asarray([0.7, 0.3, 0.2, 0.1, 0.0]),
}, num_buckets={"a": 2, "b": 2})
self.writer.flush()
data = _load_histograms_data(self.logdir)
self.assertNear(data[0]["a"].min, 0.1, 0.001)
self.assertNear(data[0]["a"].max, 0.7, 0.001)
self.assertNear(data[0]["b"].min, -0.1, 0.001)
self.assertNear(data[0]["b"].max, 0.4, 0.001)
self.assertNear(data[2]["a"].min, -0.1, 0.001)
self.assertNear(data[2]["a"].max, 0.5, 0.001)
self.assertNear(data[2]["b"].min, 0.0, 0.001)
self.assertNear(data[2]["b"].max, 0.7, 0.001)
if __name__ == "__main__":
tf.test.main()
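`_load_scalars_data` above folds an event stream into a step -> tag -> value mapping, merging events that share a step. The same accumulation over plain tuples (a hypothetical event list, so no TensorFlow is needed):

```python
import collections

# Each entry is (step, {tag: value}); separate events for one step merge,
# as when several events files in a logdir cover the same step.
events = [(0, {"a": 0.5}), (0, {"b": 1.0}), (2, {"a": 0.7})]

data = collections.defaultdict(dict)
for step, summary in events:
    for tag, value in summary.items():
        data[step][tag] = value
```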
================================================
FILE: clu/metric_writers/utils.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a generic write interface.
The write helper accepts a MetricWriter object and a Mapping[str,
clu.metrics.Metric], and automatically writes to the appropriate typed write
method of the writer depending on the type of the metric.
"""
# pylint: disable=g-importing-member
import collections
import getpass
import os
import re
from typing import Any, List, Mapping, Optional, Tuple, Union
from absl import flags
from absl import logging
from clu import values
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
from etils import epath
import jax.numpy as jnp
import numpy as np
FLAGS = flags.FLAGS
def _is_scalar(value: Any) -> bool:
if isinstance(value, values.Scalar) or isinstance(
value, (int, float, np.number)
):
return True
if isinstance(value, (np.ndarray, jnp.ndarray)):
return value.ndim == 0 or value.size <= 1
return False
def write_values(
writer: MetricWriter,
step: int,
metrics: Mapping[
str, Union[values.Value, values.ArrayType, values.ScalarType]
],
):
"""Writes all provided metrics.
Allows providing a mapping of name to Value object, where each Value
specifies a type. The appropriate write method can then be called depending
on the type.
Args:
writer: MetricWriter object
step: Step at which the arrays were generated.
metrics: Mapping from name to clu.values.Value object.
"""
writes = collections.defaultdict(dict)
histogram_num_buckets = collections.defaultdict(int)
for k, v in metrics.items():
if isinstance(v, values.Summary):
writes[
(writer.write_summaries, frozenset({"metadata": v.metadata}.items()))
][k] = v.value
elif _is_scalar(v):
if isinstance(v, values.Scalar):
writes[(writer.write_scalars, frozenset())][k] = v.value
else:
writes[(writer.write_scalars, frozenset())][k] = v
elif isinstance(v, values.Image):
writes[(writer.write_images, frozenset())][k] = v.value
elif isinstance(v, values.Text):
writes[(writer.write_texts, frozenset())][k] = v.value
elif isinstance(v, values.HyperParam):
writes[(writer.write_hparams, frozenset())][k] = v.value
elif isinstance(v, values.Histogram):
writes[(writer.write_histograms, frozenset())][k] = v.value
histogram_num_buckets[k] = v.num_buckets
elif isinstance(v, values.Audio):
writes[(
writer.write_audios,
frozenset({"sample_rate": v.sample_rate}.items()),
)][k] = v.value
else:
raise ValueError("Metric: ", k, " has unsupported value: ", v)
for (fn, extra_args), vals in writes.items():
if fn == writer.write_histograms:
# for write_histograms, the num_buckets arg is a Dict indexed by name
writer.write_histograms(step, vals, num_buckets=histogram_num_buckets)
else:
fn(step, vals, **dict(extra_args))
def create_default_writer(
logdir: Optional[epath.PathLike] = None,
*,
just_logging: bool = False,
asynchronous: bool = True,
collection: Optional[str] = None,
) -> MultiWriter:
"""Create the default writer for the platform.
  On most platforms this will create a MultiWriter that writes to multiple
  backends (logging, TF summaries, etc.).
Args:
    logdir: Logging dir to use for TF summary files. If empty/None, the
      returned writer will not write TF summary files.
    just_logging: If True, only use a LoggingWriter. This is useful in
      multi-host setups when only the first host should write metrics and all
      other hosts should only write to their own logs.
    asynchronous: If True, return an AsyncMultiWriter to not block when writing
metrics.
    collection: A string which, if provided, indicates that all provided
      metrics should be written to the same collection (grouping).
Returns:
A `MetricWriter` according to the platform and arguments.
"""
if just_logging:
if asynchronous:
return AsyncMultiWriter([LoggingWriter(collection=collection)])
else:
return MultiWriter([LoggingWriter(collection=collection)])
writers = [LoggingWriter(collection=collection)]
if logdir is not None:
logdir = epath.Path(logdir)
if collection is not None:
logdir /= collection
writers.append(SummaryWriter(os.fspath(logdir)))
if asynchronous:
return AsyncMultiWriter(writers)
return MultiWriter(writers)
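`write_values` above batches metrics by `(write method, frozen kwargs)` so each typed method is called once per distinct keyword set (e.g. one `write_audios` call per sample rate). A minimal sketch of that grouping, with plain tuples standing in for `clu.values` types (hypothetical shapes, for illustration only):

```python
import collections

def group_writes(metrics):
    """Groups metric values by (method name, frozen kwargs), as write_values
    groups by (writer method, frozenset of extra args)."""
    writes = collections.defaultdict(dict)
    for name, (method, value, kwargs) in metrics.items():
        writes[(method, frozenset(kwargs.items()))][name] = value
    return writes

metrics = {
    "loss": ("write_scalars", 0.3, {}),
    "accuracy": ("write_scalars", 0.9, {}),
    "clip": ("write_audios", [0.1, 0.2], {"sample_rate": 16000}),
}
groups = group_writes(metrics)
# Both scalars land in one group; the audio value gets its own keyed group.
```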
================================================
FILE: clu/metric_writers/utils_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for interface."""
# pylint: disable=g-importing-member
import itertools
from typing import Any
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from clu import values
from clu.metric_writers import utils
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.async_writer import AsyncWriter
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
import clu.metrics
import flax.struct
import jax.numpy as jnp
import tensorflow as tf
@flax.struct.dataclass
class HistogramMetric(clu.metrics.Metric):
value: jnp.ndarray
num_buckets: int
def compute_value(self):
return values.Histogram(self.value, self.num_buckets)
@flax.struct.dataclass
class ImageMetric(clu.metrics.Metric):
value: jnp.ndarray
def compute_value(self):
return values.Image(self.value)
@flax.struct.dataclass
class AudioMetric(clu.metrics.Metric):
value: jnp.ndarray
sample_rate: int
def compute_value(self):
return values.Audio(self.value, self.sample_rate)
@flax.struct.dataclass
class TextMetric(clu.metrics.Metric):
value: str
def compute_value(self):
return values.Text(self.value)
@flax.struct.dataclass
class HyperParamMetric(clu.metrics.Metric):
value: float
def compute_value(self):
return values.HyperParam(self.value)
@flax.struct.dataclass
class SummaryMetric(clu.metrics.Metric):
value: jnp.ndarray
metadata: Any
def compute_value(self):
return values.Summary(self.value, self.metadata)
def _to_summary(metrics):
return {k: v.value for k, v in metrics.items()}
def _to_list_of_dicts(d):
return [{k: v} for k, v in d.items()]
class ONEOF(object):
"""ONEOF(options_list) check value in options_list."""
def __init__(self, container):
if not hasattr(container, "__contains__"):
raise TypeError(f"{container!r} is not a container")
if not container:
raise ValueError(f"{container!r} is empty")
self._c = container
def __eq__(self, o):
return o in self._c
def __ne__(self, o):
return o not in self._c
def __repr__(self):
return "<ONEOF({})>".format(",".join(repr(i) for i in self._c))
class MetricWriterTest(tf.test.TestCase, parameterized.TestCase):
def test_write(self):
writer = mock.Mock(spec_set=MetricWriter)
step = 3
num_buckets = 4
sample_rate = 10
scalar_metrics = {
"loss": clu.metrics.Average.from_model_output(jnp.asarray([1, 2, 3])),
"accuracy": clu.metrics.LastValue.from_model_output(jnp.asarray([5])),
}
image_metrics = {
"image": ImageMetric(jnp.asarray([[4, 5], [1, 2]])),
}
histogram_metrics = {
"hist": HistogramMetric(
value=jnp.asarray([7, 8]), num_buckets=num_buckets
),
"hist2": HistogramMetric(
value=jnp.asarray([9, 10]), num_buckets=num_buckets
),
}
audio_metrics = {
"audio": AudioMetric(
value=jnp.asarray([1, 5]), sample_rate=sample_rate
),
"audio2": AudioMetric(
value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2
),
}
text_metrics = {
"text": TextMetric(value="hello"),
}
hparam_metrics = {
"lr": HyperParamMetric(value=0.01),
}
summary_metrics = {
"summary": SummaryMetric(
value=jnp.asarray([2, 3, 10]), metadata="some info"
),
"summary2": SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5),
}
metrics = {
**scalar_metrics,
**image_metrics,
**histogram_metrics,
**audio_metrics,
**text_metrics,
**hparam_metrics,
**summary_metrics,
}
metrics = {k: m.compute_value() for k, m in metrics.items()}
utils.write_values(writer, step, metrics)
writer.write_scalars.assert_called_once_with(
step, {k: m.compute() for k, m in scalar_metrics.items()}
)
writer.write_images.assert_called_once_with(
step, _to_summary(image_metrics)
)
writer.write_histograms.assert_called_once_with(
step,
_to_summary(histogram_metrics),
num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()},
)
writer.write_audios.assert_called_with(
step,
ONEOF(_to_list_of_dicts(_to_summary(audio_metrics))),
sample_rate=ONEOF([sample_rate, sample_rate + 2]),
)
writer.write_texts.assert_called_once_with(step, _to_summary(text_metrics))
writer.write_hparams.assert_called_once_with(
step, _to_summary(hparam_metrics)
)
writer.write_summaries.assert_called_with(
step,
ONEOF(_to_list_of_dicts(_to_summary(summary_metrics))),
metadata=ONEOF(["some info", 5]),
)
def test_create_default_writer_summary_writer_is_added(self):
writer = utils.create_default_writer(
logdir=self.get_temp_dir(), asynchronous=False
)
self.assertTrue(any(isinstance(w, SummaryWriter) for w in writer._writers))
writer = utils.create_default_writer(logdir=None, asynchronous=False)
self.assertFalse(any(isinstance(w, SummaryWriter) for w in writer._writers))
if __name__ == "__main__":
absltest.main()
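`ONEOF` exists because `mock.assert_called_with` only checks the most recent call, while `write_values` may invoke a method once per distinct kwargs set, in no guaranteed order. A small self-contained demo of the matcher pattern used by `test_write` above:

```python
from unittest import mock

class ONEOF:
    """Equality matcher: compares equal to any member of the container."""
    def __init__(self, container):
        self._c = container
    def __eq__(self, other):
        return other in self._c

writer = mock.Mock()
# Two calls with different sample rates, as write_values would make:
writer.write_audios(3, {"audio": [1, 5]}, sample_rate=10)
writer.write_audios(3, {"audio2": [1, 5]}, sample_rate=12)

# assert_called_with sees only the last call, so match it against the set
# of acceptable argument values rather than one fixed call:
writer.write_audios.assert_called_with(
    3,
    ONEOF([{"audio": [1, 5]}, {"audio2": [1, 5]}]),
    sample_rate=ONEOF([10, 12]),
)
```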
================================================
FILE: clu/metrics.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functional metric computation library.
This library defines a functional metric computation interface `Metric` that
relies on metrics accumulating intermediate values (in a possibly distributed
manner), and then computes the final metric value from these intermediate
values. Note that most metrics can be created via `Average.from_fun()`. See also
`CollectingMetric` that collects all model outputs with a given name and lends
itself to metric computation in Python.
Some common metrics, such as accuracy and loss average/standard deviation, and a
`Collection` with the same interface, are also provided.
The "model output" is a dictionary of values with unique keys that all have a
specific meaning (such as `loss`, `logits`, and `labels`) and every metric
depends on at least one such model output by name. These outputs are usually
expected to be instances of `jnp.ndarray`.
Synopsis:
# Note: Metrics do *not* work with `from __future__ import annotations`
from clu import metrics
import flax
import jax
@flax.struct.dataclass # required for jax.tree_*
class MyCollection(metrics.Collection):
accuracy: metrics.Accuracy
loss: metrics.Average.from_output("loss")
loss_std: metrics.Std.from_output("loss")
@jax.pmap
def eval_step(variables, metrics, inputs, labels):
loss, logits = get_loss_and_logits(variables, inputs, labels)
return metrics.merge(MyCollection.gather_from_model_output(
loss=loss, logits=logits, labels=labels))
def evaluate(variables_p, test_ds):
metrics = MyCollection.empty()
for inputs, labels in test_ds:
metrics = eval_step(variables_p, metrics, inputs, labels)
return metrics.unreplicate().compute()
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
import inspect
from typing import Any, TypeVar, Protocol
from absl import logging
from clu.internal import utils
import clu.values
import flax
import jax
import jax.numpy as jnp
import numpy as np
Array = jax.Array
ArrayLike = jax.typing.ArrayLike
class FromFunCallable(Protocol):
"""The type of functions that can be passed to `Metrics.from_fun()`."""
def __call__(self, **kwargs: ArrayLike) -> Array | Mapping[str, Array]:
"""Returns the argument/arguments passed to the base from_model_output()."""
# TODO(b/200953513): Migrate away from logging imports (on module level)
# to logging the actual usage. See b/200953513.
def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray):
"""Raises a `ValueError` if shapes of `a` and `b` don't match."""
if a.shape != b.shape:
raise ValueError(f"Expected same shape: {a.shape} != {b.shape}")
M = TypeVar("M", bound="Metric")
class Metric:
"""Interface for computing metrics from intermediate values.
Refer to `Collection` for computing multiple metrics at the same time.
Synopsis:
    import flax
    import jax.numpy as jnp
    import numpy as np
@flax.struct.dataclass
class Average(Metric):
total: jnp.ndarray
count: jnp.ndarray
@classmethod
def from_model_output(cls, value: jnp.ndarray, **_) -> Metric:
return cls(total=value.sum(), count=np.prod(value.shape))
def merge(self, other: Metric) -> Metric:
return type(self)(
total=self.total + other.total,
count=self.count + other.count,
)
def compute(self):
return self.total / self.count
average = None
    for value in data:
update = Average.from_model_output(value)
average = update if average is None else average.merge(update)
print(average.compute())
"""
@classmethod
def from_model_output(cls: type[M], *args, **kwargs) -> M:
"""Creates a `Metric` from model outputs."""
raise NotImplementedError("Must override from_model_output()")
def merge(self: M, other: M) -> M:
"""Returns `Metric` that is the accumulation of `self` and `other`.
Args:
other: A `Metric` whose intermediate values should be accumulated onto the
values of `self`. Note that in a distributed setting, `other` will
typicall
gitextract_g9bj2g2j/ ├── .github/ │ └── workflows/ │ ├── build.yml │ └── python-publish.yml ├── AUTHORS ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── clu/ │ ├── __init__.py │ ├── asynclib.py │ ├── asynclib_test.py │ ├── checkpoint.py │ ├── checkpoint_test.py │ ├── data/ │ │ ├── __init__.py │ │ ├── dataset_iterator.py │ │ └── dataset_iterator_test.py │ ├── deterministic_data.py │ ├── deterministic_data_test.py │ ├── internal/ │ │ ├── __init__.py │ │ ├── utils.py │ │ └── utils_test.py │ ├── metric_writers/ │ │ ├── __init__.py │ │ ├── async_writer.py │ │ ├── async_writer_test.py │ │ ├── interface.py │ │ ├── logging_writer.py │ │ ├── logging_writer_test.py │ │ ├── multi_writer.py │ │ ├── multi_writer_test.py │ │ ├── summary_writer.py │ │ ├── tf/ │ │ │ ├── __init__.py │ │ │ ├── summary_writer.py │ │ │ └── summary_writer_test.py │ │ ├── torch_tensorboard_writer.py │ │ ├── torch_tensorboard_writer_test.py │ │ ├── utils.py │ │ └── utils_test.py │ ├── metrics.py │ ├── metrics_test.py │ ├── parameter_overview.py │ ├── parameter_overview_test.py │ ├── periodic_actions.py │ ├── periodic_actions_test.py │ ├── platform/ │ │ ├── __init__.py │ │ ├── interface.py │ │ └── local.py │ ├── preprocess_spec.py │ ├── preprocess_spec_test.py │ ├── profiler.py │ ├── run_pytest.google.sh │ └── values.py ├── clu_synopsis.ipynb └── setup.py
SYMBOL INDEX (535 symbols across 36 files)
FILE: clu/asynclib.py
class AsyncError (line 27) | class AsyncError(Exception):
class Pool (line 31) | class Pool:
method __init__ (line 50) | def __init__(self, thread_name_prefix: str = "",
method _reraise (line 69) | def _reraise(self) -> None:
method close (line 76) | def close(self) -> None:
method join (line 81) | def join(self) -> None:
method queue_length (line 98) | def queue_length(self) -> int:
method has_errors (line 103) | def has_errors(self) -> bool:
method clear_errors (line 107) | def clear_errors(self) -> List[Exception]:
method __call__ (line 113) | def __call__(self, fn: Callable): # pylint: disable=g-bare-generic
FILE: clu/asynclib_test.py
class AsyncWriterTest (line 23) | class AsyncWriterTest(absltest.TestCase):
method test_async_execution (line 25) | def test_async_execution(self):
method test_reraise (line 39) | def test_reraise(self):
method test_queue_length (line 65) | def test_queue_length(self, executor_mock):
method test_flush (line 95) | def test_flush(self, executor_mock):
FILE: clu/checkpoint.py
function safe_normpath (line 77) | def safe_normpath(path: str) -> str:
function load_state_dict (line 83) | def load_state_dict(base_directory) -> Dict[str, Any]:
class CheckpointInfo (line 106) | class CheckpointInfo(
method initialize (line 113) | def initialize(cls, base_directory, checkpoint_name: str) -> "Checkpoi...
method from_path (line 118) | def from_path(cls, checkpoint: str) -> "CheckpointInfo":
method increment (line 134) | def increment(self) -> "CheckpointInfo":
method __str__ (line 138) | def __str__(self):
class Checkpoint (line 143) | class Checkpoint:
method __init__ (line 162) | def __init__(self,
method get_latest_checkpoint_to_restore_from (line 194) | def get_latest_checkpoint_to_restore_from(self):
method latest_checkpoint (line 207) | def latest_checkpoint(self) -> Optional[str]:
method current_checkpoint (line 220) | def current_checkpoint(self) -> Optional[str]:
method _flax_path (line 240) | def _flax_path(self, checkpoint: str) -> str:
method _next_checkpoint (line 243) | def _next_checkpoint(self, checkpoint: Optional[str]) -> str:
method _checkpoint_number (line 249) | def _checkpoint_number(self, checkpoint: Optional[str]) -> Optional[int]:
method _delete_future_checkpoints (line 254) | def _delete_future_checkpoints(self):
method save (line 273) | def save(self, state) -> str:
method restore_or_initialize (line 330) | def restore_or_initialize(self, state: T) -> T:
method restore_dict (line 350) | def restore_dict(self, checkpoint: Optional[str] = None) -> Dict[str, ...
method _checkpoint_or_latest (line 371) | def _checkpoint_or_latest(self, checkpoint: Optional[str] = None) -> str:
method load_state (line 378) | def load_state(self,
method restore (line 410) | def restore(self,
class MultihostCheckpoint (line 448) | class MultihostCheckpoint(Checkpoint):
method __init__ (line 463) | def __init__(self,
method get_latest_checkpoint_to_restore_from (line 503) | def get_latest_checkpoint_to_restore_from(self) -> Optional[str]:
FILE: clu/checkpoint_test.py
function _make_dataset (line 26) | def _make_dataset():
class TrainState (line 34) | class TrainState:
class TrainStateExtended (line 39) | class TrainStateExtended:
class NotTrainState (line 44) | class NotTrainState:
function _checkpoint_number (line 48) | def _checkpoint_number(path):
class CheckpointTest (line 54) | class CheckpointTest(tf.test.TestCase):
method test_safe_normpath (line 56) | def test_safe_normpath(self):
method test_initialize_mkdir (line 63) | def test_initialize_mkdir(self):
method test_restores_flax_state (line 75) | def test_restores_flax_state(self):
method test_load_state_dict (line 102) | def test_load_state_dict(self):
method test_fails_when_restoring_subset (line 114) | def test_fails_when_restoring_subset(self):
method test_fails_when_restoring_superset (line 125) | def test_fails_when_restoring_superset(self):
method test_restores_tf_state (line 136) | def test_restores_tf_state(self):
method test_restore_flax_alone (line 172) | def test_restore_flax_alone(self):
method test_restore_dict (line 185) | def test_restore_dict(self):
method test_ignores_incomplete_checkpoint (line 213) | def test_ignores_incomplete_checkpoint(self):
method test_max_to_keep (line 249) | def test_max_to_keep(self):
method test_checkpoint_name (line 262) | def test_checkpoint_name(self):
method test_fails_if_not_registered (line 269) | def test_fails_if_not_registered(self):
method test_overwrite (line 276) | def test_overwrite(self):
class MultihostCheckpoint (line 312) | class MultihostCheckpoint(tf.test.TestCase):
method test_initialize_mkdir (line 315) | def test_initialize_mkdir(self, process_index_mock):
method test_synchronize_multiple_hosts (line 328) | def test_synchronize_multiple_hosts(self, process_index_mock):
method test_preemption (line 360) | def test_preemption(self):
FILE: clu/data/dataset_iterator.py
class ArraySpec (line 52) | class ArraySpec:
method __repr__ (line 57) | def __repr__(self):
method __str__ (line 60) | def __str__(self):
class DatasetIterator (line 77) | class DatasetIterator(collections.abc.Iterator): # pytype: disable=igno...
method get_next (line 92) | def get_next(self) -> Element:
method reset (line 99) | def reset(self):
method element_spec (line 105) | def element_spec(self) -> ElementSpec:
method save (line 109) | def save(self, filename: epath.Path):
method restore (line 119) | def restore(self, filename: epath.Path):
method load (line 129) | def load(self, filename: epath.Path):
class TfDatasetIterator (line 134) | class TfDatasetIterator(DatasetIterator):
method __init__ (line 137) | def __init__(self, dataset, *, checkpoint: bool):
method get_next (line 174) | def get_next(self) -> Element:
method __next__ (line 177) | def __next__(self) -> Element:
method reset (line 180) | def reset(self):
method element_spec (line 185) | def element_spec(self) -> ElementSpec:
method save (line 202) | def save(self, filename: epath.Path):
method restore (line 206) | def restore(self, filename: epath.Path):
class PeekableDatasetIterator (line 211) | class PeekableDatasetIterator(DatasetIterator):
method __init__ (line 230) | def __init__(self, it: DatasetIterator):
method __next__ (line 238) | def __next__(self) -> Element:
method reset (line 246) | def reset(self):
method element_spec (line 254) | def element_spec(self) -> ElementSpec:
method peek (line 257) | def peek(self) -> Element:
method peek_async (line 270) | def peek_async(self) -> concurrent.futures.Future[Element]:
method save (line 286) | def save(self, filename: epath.Path):
method restore (line 290) | def restore(self, filename: epath.Path):
FILE: clu/data/dataset_iterator_test.py
class DatasetIteratorTest (line 28) | class DatasetIteratorTest(parameterized.TestCase, tf.test.TestCase):
method _create_iterator (line 30) | def _create_iterator(self, start_index: int, checkpoint: bool = True):
method test_tf_iterator (line 40) | def test_tf_iterator(self):
method test_tf_iterator_save_and_load (line 53) | def test_tf_iterator_save_and_load(self):
method test_tf_iterator_save_and_load_no_checkpoint (line 70) | def test_tf_iterator_save_and_load_no_checkpoint(self):
method test_peekable_dataset_iterator (line 84) | def test_peekable_dataset_iterator(self):
method test_peekable_dataset_iterator_async (line 92) | def test_peekable_dataset_iterator_async(self, wait: bool, peek_first:...
FILE: clu/deterministic_data.py
class DatasetBuilder (line 82) | class DatasetBuilder(typing_extensions.Protocol):
method as_dataset (line 85) | def as_dataset(
class RemainderOptions (line 92) | class RemainderOptions(enum.Enum):
function _shard_read_instruction (line 111) | def _shard_read_instruction(
function get_read_instruction_for_host (line 175) | def get_read_instruction_for_host(
function _preprocess_with_per_example_rng (line 272) | def _preprocess_with_per_example_rng(ds: tf.data.Dataset,
function pad_dataset (line 303) | def pad_dataset(dataset: tf.data.Dataset,
function create_dataset (line 362) | def create_dataset(dataset_builder: DatasetBuilder,
function create_distributed_dataset (line 484) | def create_distributed_dataset(
FILE: clu/deterministic_data_test.py
class MyDatasetBuilder (line 35) | class MyDatasetBuilder:
method as_dataset (line 39) | def as_dataset(self, split: tfds.core.ReadInstruction, shuffle_files: ...
class FakeDatasetInfo (line 57) | class FakeDatasetInfo:
method splits (line 62) | def splits(self):
class DeterministicDataTest (line 69) | class DeterministicDataTest(tf.test.TestCase, parameterized.TestCase):
method test_get_read_instruction_for_host_deprecated (line 86) | def test_get_read_instruction_for_host_deprecated(self, num_examples: ...
method test_get_read_instruction_for_host (line 140) | def test_get_read_instruction_for_host(self, host_id: int, host_count:...
method test_get_read_instruction_balance_remainder (line 168) | def test_get_read_instruction_balance_remainder(self, host_id: int,
method test_get_read_instruction_for_host_fails (line 191) | def test_get_read_instruction_for_host_fails(self, host_id: int,
method test_preprocess_with_per_example_rng (line 197) | def test_preprocess_with_per_example_rng(self):
method test_create_dataset_padding (line 223) | def test_create_dataset_padding(self, pad_up_to_batches, cardinality):
method test_create_dataset_padding_raises_error_cardinality (line 258) | def test_create_dataset_padding_raises_error_cardinality(self):
method test_pad_dataset (line 278) | def test_pad_dataset(self):
method test_pad_nested_dataset (line 292) | def test_pad_nested_dataset(self):
method test_same_cardinality_on_all_hosts (line 309) | def test_same_cardinality_on_all_hosts(self, num_examples: int,
method test_same_cardinality_on_all_hosts_with_pad (line 326) | def test_same_cardinality_on_all_hosts_with_pad(self, num_examples: int,
FILE: clu/internal/utils.py
function log_activity (line 30) | def log_activity(activity_name: str):
function logged_with (line 47) | def logged_with(activity_name: str):
function check_param (line 57) | def check_param(value, *, ndim=None, dtype=jnp.float32):
function flatten_dict (line 77) | def flatten_dict(
FILE: clu/internal/utils_test.py
class TestError (line 23) | class TestError(BaseException):
class HelpersTest (line 28) | class HelpersTest(absltest.TestCase):
method test_log_activity (line 30) | def test_log_activity(
method test_log_activity_fails (line 41) | def test_log_activity_fails(
method test_logged_with (line 53) | def test_logged_with(self):
method test_logged_with_fails (line 66) | def test_logged_with_fails(self):
method test_check_param (line 80) | def test_check_param(self):
method test_flatten_dict (line 91) | def test_flatten_dict(self):
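`flatten_dict` in clu/internal/utils.py turns a nested mapping into a flat one with path-like keys. A minimal sketch of the pattern (the separator and exact signature here are assumptions, not CLU's precise API):

```python
def flatten_dict(d: dict, prefix: str = "", sep: str = "/") -> dict:
    """Flattens nested dicts, joining key paths with `sep`."""
    out = {}
    for key, value in d.items():
        path = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            out.update(flatten_dict(value, prefix=path, sep=sep))
        else:
            out[path] = value
    return out
```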
FILE: clu/metric_writers/async_writer.py
function _wrap_exceptions (line 39) | def _wrap_exceptions(wrapped, instance, args, kwargs):
class AsyncWriter (line 51) | class AsyncWriter(interface.MetricWriter):
method __init__ (line 64) | def __init__(self,
method write_summaries (line 78) | def write_summaries(
method write_scalars (line 86) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
method write_images (line 90) | def write_images(self, step: int, images: Mapping[str, Array]):
method write_videos (line 94) | def write_videos(self, step: int, videos: Mapping[str, Array]):
method write_audios (line 98) | def write_audios(
method write_texts (line 104) | def write_texts(self, step: int, texts: Mapping[str, str]):
method write_histograms (line 108) | def write_histograms(self,
method write_pointcloud (line 116) | def write_pointcloud(
method write_hparams (line 132) | def write_hparams(self, hparams: Mapping[str, Any]):
method flush (line 135) | def flush(self):
method close (line 141) | def close(self):
class AsyncMultiWriter (line 148) | class AsyncMultiWriter(multi_writer.MultiWriter):
method __init__ (line 151) | def __init__(self,
function ensure_flushes (line 159) | def ensure_flushes(*writers: interface.MetricWriter):
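`ensure_flushes` is a context manager guaranteeing that all writers are flushed when the enclosing block exits, even on error — important for `AsyncWriter`, where writes are queued on a background pool. A sketch of the idea (simplified; the real version wraps `interface.MetricWriter` instances):

```python
import contextlib

@contextlib.contextmanager
def ensure_flushes(*writers):
    """Yields the writers and flushes each of them on exit, even on error."""
    try:
        yield writers
    finally:
        for writer in writers:
            writer.flush()
```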
FILE: clu/metric_writers/async_writer_test.py
class AsyncWriterTest (line 27) | class AsyncWriterTest(tf.test.TestCase):
method setUp (line 29) | def setUp(self):
method test_write_summaries_async (line 34) | def test_write_summaries_async(self):
method test_write_scalars_async (line 46) | def test_write_scalars_async(self):
method test_write_images (line 61) | def test_write_images(self):
method test_write_videos (line 68) | def test_write_videos(self):
method test_write_pointcloud (line 75) | def test_write_pointcloud(self):
method test_write_texts (line 96) | def test_write_texts(self):
method test_ensure_flushes (line 101) | def test_ensure_flushes(self):
method test_ensure_flushes_with_multiple_writers (line 117) | def test_ensure_flushes_with_multiple_writers(self):
method test_flush_before_close (line 142) | def test_flush_before_close(self):
method test_reraises_exception (line 147) | def test_reraises_exception(self):
FILE: clu/metric_writers/interface.py
class MetricWriter (line 33) | class MetricWriter(abc.ABC):
method write_summaries (line 37) | def write_summaries(
method write_scalars (line 53) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
method write_images (line 65) | def write_images(self, step: int, images: Mapping[str, Array]):
method write_videos (line 83) | def write_videos(self, step: int, videos: Mapping[str, Array]):
method write_audios (line 104) | def write_audios(
method write_texts (line 126) | def write_texts(self, step: int, texts: Mapping[str, str]):
method write_histograms (line 137) | def write_histograms(self,
method write_pointcloud (line 156) | def write_pointcloud(
method write_hparams (line 177) | def write_hparams(self, hparams: Mapping[str, Any]):
method flush (line 187) | def flush(self):
method close (line 191) | def close(self):
FILE: clu/metric_writers/logging_writer.py
class LoggingWriter (line 28) | class LoggingWriter(interface.MetricWriter):
method __init__ (line 31) | def __init__(self, collection: Optional[str] = None):
method write_summaries (line 37) | def write_summaries(
method write_scalars (line 44) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
method write_images (line 51) | def write_images(self, step: int, images: Mapping[str, Array]):
method write_videos (line 55) | def write_videos(self, step: int, videos: Mapping[str, Array]):
method write_audios (line 59) | def write_audios(
method write_texts (line 64) | def write_texts(self, step: int, texts: Mapping[str, str]):
method write_histograms (line 67) | def write_histograms(self,
method write_pointcloud (line 80) | def write_pointcloud(
method write_hparams (line 101) | def write_hparams(self, hparams: Mapping[str, Any]):
method flush (line 104) | def flush(self):
method close (line 107) | def close(self):
function _compute_histogram_as_tf (line 111) | def _compute_histogram_as_tf(
function _get_histogram_as_string (line 152) | def _get_histogram_as_string(histo: np.ndarray, bins: np.ndarray):
FILE: clu/metric_writers/logging_writer_test.py
class LoggingWriterTest (line 22) | class LoggingWriterTest(tf.test.TestCase):
method setUp (line 24) | def setUp(self):
method test_write_scalars (line 28) | def test_write_scalars(self):
method test_write_images (line 36) | def test_write_images(self):
method test_write_videos (line 44) | def test_write_videos(self):
method test_write_texts (line 52) | def test_write_texts(self):
method test_write_histogram (line 59) | def test_write_histogram(self):
method test_write_pointcloud (line 83) | def test_write_pointcloud(self):
method test_write_hparams (line 106) | def test_write_hparams(self):
method test_collection (line 113) | def test_collection(self):
FILE: clu/metric_writers/multi_writer.py
class MultiWriter (line 26) | class MultiWriter(interface.MetricWriter):
method __init__ (line 29) | def __init__(self, writers: Sequence[interface.MetricWriter]):
method write_summaries (line 32) | def write_summaries(
method write_scalars (line 39) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
method write_images (line 43) | def write_images(self, step: int, images: Mapping[str, Array]):
method write_videos (line 47) | def write_videos(self, step: int, videos: Mapping[str, Array]):
method write_audios (line 51) | def write_audios(
method write_texts (line 56) | def write_texts(self, step: int, texts: Mapping[str, str]):
method write_histograms (line 60) | def write_histograms(self,
method write_pointcloud (line 67) | def write_pointcloud(
method write_hparams (line 80) | def write_hparams(self, hparams: Mapping[str, Any]):
method flush (line 84) | def flush(self):
method close (line 88) | def close(self):
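`MultiWriter` simply fans every call out to a sequence of child writers, which is how CLU logs the same metrics to, say, a `LoggingWriter` and a `SummaryWriter` at once. A dependency-free sketch of the pattern, showing just two of the interface methods:

```python
class MultiWriter:
    """Forwards each write_*/flush call to every wrapped writer."""

    def __init__(self, writers):
        self._writers = tuple(writers)

    def write_scalars(self, step, scalars):
        for writer in self._writers:
            writer.write_scalars(step, scalars)

    def flush(self):
        for writer in self._writers:
            writer.flush()
```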
FILE: clu/metric_writers/multi_writer_test.py
class MultiWriterTest (line 25) | class MultiWriterTest(tf.test.TestCase):
method setUp (line 27) | def setUp(self):
method test_write_scalars (line 35) | def test_write_scalars(self):
method test_write_pointcloud (line 52) | def test_write_pointcloud(self):
FILE: clu/metric_writers/tf/summary_writer.py
class SummaryWriter (line 42) | class SummaryWriter(interface.MetricWriter):
method __init__ (line 45) | def __init__(self, logdir: str):
method write_summaries (line 50) | def write_summaries(
method write_scalars (line 61) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
method write_images (line 66) | def write_images(self, step: int, images: Mapping[str, Array]):
method write_videos (line 73) | def write_videos(self, step: int, videos: Mapping[str, Array]):
method write_audios (line 78) | def write_audios(
method write_texts (line 85) | def write_texts(self, step: int, texts: Mapping[str, str]):
method write_histograms (line 90) | def write_histograms(
method write_pointcloud (line 101) | def write_pointcloud(
method write_hparams (line 121) | def write_hparams(self, hparams: Mapping[str, Any]):
method flush (line 125) | def flush(self):
method close (line 128) | def close(self):
FILE: clu/metric_writers/tf/summary_writer_test.py
function _load_summaries_data (line 27) | def _load_summaries_data(logdir):
function _load_histograms_data (line 41) | def _load_histograms_data(logdir):
function _load_scalars_data (line 60) | def _load_scalars_data(logdir: str):
function _load_pointcloud_data (line 72) | def _load_pointcloud_data(logdir: str):
function _load_hparams (line 87) | def _load_hparams(logdir: str):
class SummaryWriterTest (line 101) | class SummaryWriterTest(tf.test.TestCase):
method setUp (line 103) | def setUp(self):
method test_write_summaries (line 108) | def test_write_summaries(self):
method test_write_scalar (line 121) | def test_write_scalar(self):
method test_write_histograms (line 129) | def test_write_histograms(self):
method test_write_pointcloud (line 160) | def test_write_pointcloud(self):
method test_hparams (line 178) | def test_hparams(self):
method test_hparams_nested (line 187) | def test_hparams_nested(self):
FILE: clu/metric_writers/torch_tensorboard_writer.py
class TorchTensorboardWriter (line 32) | class TorchTensorboardWriter(interface.MetricWriter):
method __init__ (line 35) | def __init__(self, logdir: str):
method write_summaries (line 40) | def write_summaries(
method write_scalars (line 48) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
method write_images (line 52) | def write_images(self, step: int, images: Mapping[str, Array]):
method write_videos (line 56) | def write_videos(self, step: int, videos: Mapping[str, Array]):
method write_audios (line 61) | def write_audios(
method write_texts (line 67) | def write_texts(self, step: int, texts: Mapping[str, str]):
method write_histograms (line 72) | def write_histograms(self,
method write_pointcloud (line 81) | def write_pointcloud(
method write_hparams (line 95) | def write_hparams(self, hparams: Mapping[str, Any]):
method flush (line 98) | def flush(self):
method close (line 101) | def close(self):
FILE: clu/metric_writers/torch_tensorboard_writer_test.py
function _load_scalars_data (line 26) | def _load_scalars_data(logdir: str):
function _load_histograms_data (line 38) | def _load_histograms_data(logdir: str) -> Dict[int, Dict[str, Any]]:
class TorchTensorboardWriterTest (line 62) | class TorchTensorboardWriterTest(tf.test.TestCase):
method setUp (line 64) | def setUp(self):
method test_write_scalar (line 69) | def test_write_scalar(self):
method test_write_histograms (line 77) | def test_write_histograms(self):
FILE: clu/metric_writers/utils.py
function _is_scalar (line 46) | def _is_scalar(value: Any) -> bool:
function write_values (line 56) | def write_values(
function create_default_writer (line 113) | def create_default_writer(
FILE: clu/metric_writers/utils_test.py
class HistogramMetric (line 39) | class HistogramMetric(clu.metrics.Metric):
method compute_value (line 43) | def compute_value(self):
class ImageMetric (line 48) | class ImageMetric(clu.metrics.Metric):
method compute_value (line 51) | def compute_value(self):
class AudioMetric (line 56) | class AudioMetric(clu.metrics.Metric):
method compute_value (line 60) | def compute_value(self):
class TextMetric (line 65) | class TextMetric(clu.metrics.Metric):
method compute_value (line 68) | def compute_value(self):
class HyperParamMetric (line 73) | class HyperParamMetric(clu.metrics.Metric):
method compute_value (line 76) | def compute_value(self):
class SummaryMetric (line 81) | class SummaryMetric(clu.metrics.Metric):
method compute_value (line 85) | def compute_value(self):
function _to_summary (line 89) | def _to_summary(metrics):
function _to_list_of_dicts (line 93) | def _to_list_of_dicts(d):
class ONEOF (line 97) | class ONEOF(object):
method __init__ (line 100) | def __init__(self, container):
method __eq__ (line 107) | def __eq__(self, o):
method __ne__ (line 110) | def __ne__(self, o):
method __repr__ (line 113) | def __repr__(self):
class MetricWriterTest (line 117) | class MetricWriterTest(tf.test.TestCase, parameterized.TestCase):
method test_write (line 119) | def test_write(self):
method test_create_default_writer_summary_writer_is_added (line 198) | def test_create_default_writer_summary_writer_is_added(self):
FILE: clu/metrics.py
class FromFunCallable (line 76) | class FromFunCallable(Protocol):
method __call__ (line 79) | def __call__(self, **kwargs: ArrayLike) -> Array | Mapping[str, Array]:
function _assert_same_shape (line 88) | def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray):
class Metric (line 97) | class Metric:
method from_model_output (line 133) | def from_model_output(cls: type[M], *args, **kwargs) -> M:
method merge (line 137) | def merge(self: M, other: M) -> M:
method _reduce_merge (line 160) | def _reduce_merge(self: M, other: M) -> M:
method compute (line 163) | def compute(self) -> jnp.ndarray:
method empty (line 168) | def empty(cls: type[M]) -> M:
method compute_value (line 172) | def compute_value(self) -> clu.values.Value:
method reduce (line 176) | def reduce(self: M) -> M:
method from_fun (line 235) | def from_fun(cls, fun: FromFunCallable): # No way to annotate return ...
method from_output (line 314) | def from_output(cls, name: str): # No way to annotate return type
class CollectingMetric (line 358) | class CollectingMetric(Metric):
method empty (line 416) | def empty(cls) -> CollectingMetric:
method merge (line 419) | def merge(self, other: CollectingMetric) -> CollectingMetric:
method reduce (line 434) | def reduce(self) -> CollectingMetric:
method compute (line 440) | def compute(self): # No return type annotation, so subclasses can ove...
method from_outputs (line 444) | def from_outputs(cls, names: Sequence[str]) -> type[CollectingMetric]:
class _ReductionCounter (line 465) | class _ReductionCounter(Metric):
method empty (line 471) | def empty(cls) -> _ReductionCounter:
method merge (line 474) | def merge(self, other: _ReductionCounter) -> _ReductionCounter:
function _check_reduction_counter_ndim (line 478) | def _check_reduction_counter_ndim(reduction_counter: _ReductionCounter):
class Collection (line 490) | class Collection:
method create (line 512) | def create(cls, **metrics: type[Metric]) -> type[Collection]:
method create_collection (line 538) | def create_collection(cls, **metrics: Metric) -> Collection:
method empty (line 566) | def empty(cls: type[C]) -> C:
method _from_model_output (line 576) | def _from_model_output(cls: type[C], **kwargs) -> C:
method single_from_model_output (line 587) | def single_from_model_output(cls: type[C], **kwargs) -> C:
method gather_from_model_output (line 602) | def gather_from_model_output(cls: type[C], axis_name="batch", **kwargs...
method merge (line 617) | def merge(self: C, other: C) -> C:
method reduce (line 624) | def reduce(self: C) -> C:
method compute (line 656) | def compute(self) -> dict[str, jnp.ndarray]:
method compute_values (line 665) | def compute_values(self) -> dict[str, clu.values.Value]:
method unreplicate (line 674) | def unreplicate(self: C) -> C:
class LastValue (line 692) | class LastValue(Metric):
method __init__ (line 707) | def __init__( # pytype: disable=missing-parameter # jnp-array
method empty (line 743) | def empty(cls) -> LastValue:
method from_model_output (line 747) | def from_model_output(
method merge (line 757) | def merge(self, other: LastValue) -> LastValue:
method _reduce_merge (line 761) | def _reduce_merge(self, other: LastValue) -> LastValue:
method value (line 770) | def value(self) -> jnp.ndarray:
method compute (line 775) | def compute(self) -> Any:
function _broadcast_masks (line 779) | def _broadcast_masks(values: jnp.ndarray, mask: jnp.ndarray | None):
class Average (line 801) | class Average(Metric):
method empty (line 820) | def empty(cls) -> Average:
method from_model_output (line 824) | def from_model_output(
method merge (line 837) | def merge(self, other: Average) -> Average:
method compute (line 844) | def compute(self) -> Any:
class Std (line 849) | class Std(Metric):
method empty (line 861) | def empty(cls) -> Std:
method from_model_output (line 868) | def from_model_output(
method merge (line 882) | def merge(self, other: Std) -> Std:
method compute (line 890) | def compute(self) -> Any:
class Accuracy (line 906) | class Accuracy(Average):
method from_model_output (line 916) | def from_model_output(
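The `Average` metric listed above keeps a running `(total, count)` pair so that partial results from many batches or devices can be merged before the final division. A jax-free sketch of the accumulate/merge/compute contract (CLU's real `Metric` is a flax dataclass operating on `jnp` arrays with optional masks; plain floats stand in here):

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class Average:
    total: float = 0.0
    count: float = 0.0

    @classmethod
    def from_model_output(cls, values):
        # One batch's contribution: sum of values and number of values.
        return cls(total=sum(values), count=len(values))

    def merge(self, other):
        # Merging is associative, so partial metrics combine in any order.
        return Average(self.total + other.total, self.count + other.count)

    def compute(self):
        return self.total / self.count
```

Deferring the division to `compute()` is what makes cross-batch and cross-host aggregation exact rather than an average of averages.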
FILE: clu/metrics_test.py
class CollectingMetricAccuracy (line 32) | class CollectingMetricAccuracy(
method compute (line 35) | def compute(self):
class Collection (line 45) | class Collection(metrics.Collection):
class CollectionMixed (line 51) | class CollectionMixed(metrics.Collection):
class MetricsTest (line 56) | class MetricsTest(parameterized.TestCase):
method setUp (line 58) | def setUp(self):
method make_compute_metric (line 117) | def make_compute_metric(self, metric_class, reduce, jit=True):
method test_metric_last_value_reduce (line 151) | def test_metric_last_value_reduce(self):
method test_metric_last_value (line 177) | def test_metric_last_value(self):
method test_metric_last_value_legacy_kwarg_value (line 192) | def test_metric_last_value_legacy_kwarg_value(self):
method test_metric_last_value_tree_manipulation (line 198) | def test_metric_last_value_tree_manipulation(self):
method test_from_fun_with_single_output (line 213) | def test_from_fun_with_single_output(self):
method test_from_fun_with_mapping_output (line 229) | def test_from_fun_with_mapping_output(self):
method test_average_masked (line 262) | def test_average_masked(self, values, mask, expected_result):
method test_merge_asserts_shape (line 284) | def test_merge_asserts_shape(self, metric_cls):
method test_accuracy (line 296) | def test_accuracy(self, reduce):
method test_last_value_asserts_shape (line 301) | def test_last_value_asserts_shape(self):
method test_loss_average (line 313) | def test_loss_average(self, reduce):
method test_loss_std (line 328) | def test_loss_std(self, reduce):
method test_collection_create (line 341) | def test_collection_create(self):
method test_collection_create_custom_mask (line 350) | def test_collection_create_custom_mask(self):
method test_collection_create_collection (line 374) | def test_collection_create_collection(self):
method test_collection_single (line 394) | def test_collection_single(self, masked):
method test_collection_gather (line 418) | def test_collection_gather(self, masked, all_gather_mock):
method test_collection_gather_pmap (line 442) | def test_collection_gather_pmap(self, masked):
method test_collection_asserts_replication (line 455) | def test_collection_asserts_replication(self):
method test_collecting_metric (line 466) | def test_collecting_metric(self):
method test_collecting_metric_reduce (line 480) | def test_collecting_metric_reduce(self):
method test_collecting_metric_async (line 486) | def test_collecting_metric_async(self):
method test_collecting_metric_tracer (line 505) | def test_collecting_metric_tracer(self):
method test_collection_mixed_async (line 512) | def test_collection_mixed_async(self):
method test_metric_empty_types_doesnt_cause_retrace (line 530) | def test_metric_empty_types_doesnt_cause_retrace(self):
method test_tensor_aggregation_metrics_with_masks (line 569) | def test_tensor_aggregation_metrics_with_masks(
FILE: clu/parameter_overview.py
class _ParamRow (line 32) | class _ParamRow:
class _ParamRowWithSharding (line 40) | class _ParamRowWithSharding(_ParamRow):
class _ParamRowWithStats (line 45) | class _ParamRowWithStats(_ParamRow):
class _ParamRowWithStatsAndSharding (line 51) | class _ParamRowWithStatsAndSharding(_ParamRowWithStats):
function _mean_std_jit (line 56) | def _mean_std_jit(x):
function _mean_std (line 60) | def _mean_std(x):
function flatten_dict (line 66) | def flatten_dict(
function _count_parameters (line 82) | def _count_parameters(params: _ParamsContainer) -> int:
function _parameters_size (line 88) | def _parameters_size(params: _ParamsContainer) -> int:
function count_parameters (line 98) | def count_parameters(params: _ParamsContainer) -> int:
function _make_row (line 104) | def _make_row(name, value) -> _ParamRow:
function _make_row_with_sharding (line 120) | def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
function _make_row_with_stats (line 132) | def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
function _make_row_with_stats_and_sharding (line 143) | def _make_row_with_stats_and_sharding(
function _get_parameter_rows (line 154) | def _get_parameter_rows(
function _default_table_value_formatter (line 209) | def _default_table_value_formatter(value):
function make_table (line 221) | def make_table(
function _get_parameter_overview (line 286) | def _get_parameter_overview(
function get_parameter_overview (line 310) | def get_parameter_overview(
function _log_parameter_overview (line 347) | def _log_parameter_overview(
function log_parameter_overview (line 368) | def log_parameter_overview(
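`count_parameters` in clu/parameter_overview.py reports the total number of weights in a (possibly nested) parameter tree; the core computation is flatten-then-multiply-shapes. A sketch using plain shape tuples as stand-ins for arrays (the real function walks a pytree of `jnp` arrays):

```python
import math

def count_parameters(params) -> int:
    """Sums prod(shape) over all leaves; leaves here are shape tuples."""
    if isinstance(params, dict):
        return sum(count_parameters(v) for v in params.values())
    return math.prod(params)  # e.g. (3, 3, 1, 32) -> 288
```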
FILE: clu/parameter_overview_test.py
class CNN (line 72) | class CNN(nn.Module):
method __call__ (line 75) | def __call__(self, x):
class JaxParameterOverviewTest (line 79) | class JaxParameterOverviewTest(absltest.TestCase):
method test_count_parameters_empty (line 81) | def test_count_parameters_empty(self):
method test_count_parameters (line 84) | def test_count_parameters(self):
method test_get_parameter_overview_empty (line 92) | def test_get_parameter_overview_empty(self):
method test_get_parameter_overview (line 98) | def test_get_parameter_overview(self):
method test_get_parameter_overview_shape_dtype_struct (line 126) | def test_get_parameter_overview_shape_dtype_struct(self):
method test_printing_bool (line 134) | def test_printing_bool(self):
FILE: clu/periodic_actions.py
function _squareit (line 44) | def _squareit(x):
function _format_secs (line 49) | def _format_secs(secs: float):
class PeriodicAction (line 65) | class PeriodicAction(abc.ABC):
method __init__ (line 74) | def __init__(self,
method _init_and_check (line 99) | def _init_and_check(self, step: int, t: float):
method _should_trigger (line 112) | def _should_trigger(self, step: int, t: float) -> bool:
method _after_apply (line 123) | def _after_apply(self, step: int, t: float):
method __call__ (line 128) | def __call__(self, step: int, t: Optional[float] = None) -> bool:
method _apply (line 150) | def _apply(self, step: int, t: float):
class ReportProgress (line 154) | class ReportProgress(PeriodicAction):
method __init__ (line 157) | def __init__(self,
method set_persistent_notes (line 197) | def set_persistent_notes(self, message: str):
method _should_trigger (line 201) | def _should_trigger(self, step: int, t: float) -> bool:
method _apply (line 205) | def _apply(self, step: int, t: float):
method timed (line 227) | def timed(self, name: str, wait_jax_async_dispatch: bool = True):
class Profile (line 304) | class Profile(PeriodicAction):
method __init__ (line 309) | def __init__(
method _should_trigger (line 350) | def _should_trigger(self, step: int, t: float) -> bool:
method _apply (line 364) | def _apply(self, step: int, t: float):
method _start_session (line 368) | def _start_session(self):
method _end_session (line 376) | def _end_session(self, url: Optional[str]):
class ProfileAllHosts (line 385) | class ProfileAllHosts(PeriodicAction):
method __init__ (line 390) | def __init__(self,
method _should_trigger (line 419) | def _should_trigger(self, step: int, t: float) -> bool:
method _apply (line 422) | def _apply(self, step: int, t: float):
method _start_session (line 426) | def _start_session(self):
method _end_session (line 435) | def _end_session(self, url: Optional[str], *, step: int):
class PeriodicCallback (line 443) | class PeriodicCallback(PeriodicAction):
method __init__ (line 446) | def __init__(self,
method __call__ (line 476) | def __call__(self, step: int, t: Optional[float] = None, **kwargs) -> ...
method get_last_callback_result (line 488) | def get_last_callback_result(self):
method _apply (line 492) | def _apply(self, step, t, **kwargs):
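`PeriodicAction._should_trigger` decides whether an action fires for a given `(step, t)`. The usual rule combines an every-N-steps condition with an every-K-seconds condition; a hypothetical standalone version of that predicate (the real class also tracks the previous step/time internally and supports `on_steps`):

```python
def should_trigger(step: int, t: float, last_t: float,
                   every_steps=None, every_secs=None) -> bool:
    """True if `step` hits a multiple of every_steps, or if at least
    every_secs seconds have elapsed since the last trigger at last_t."""
    if every_steps is not None and step % every_steps == 0:
        return True
    if every_secs is not None and t - last_t >= every_secs:
        return True
    return False
```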
FILE: clu/periodic_actions_test.py
class ReportProgressTest (line 26) | class ReportProgressTest(parameterized.TestCase):
method test_every_steps (line 28) | def test_every_steps(self):
method test_every_secs (line 50) | def test_every_secs(self):
method test_without_num_train_steps (line 72) | def test_without_num_train_steps(self):
method test_with_persistent_notes (line 83) | def test_with_persistent_notes(self):
method test_unknown_cardinality (line 96) | def test_unknown_cardinality(self):
method test_called_every_step (line 107) | def test_called_every_step(self):
method test_named (line 121) | def test_named(self, wait_jax_async_dispatch, mock_time):
method test_write_metrics (line 156) | def test_write_metrics(self, time_mock):
class DummyProfilerSession (line 175) | class DummyProfilerSession:
method __init__ (line 178) | def __init__(self):
method start_session (line 183) | def start_session(self):
method end_session_and_get_url (line 186) | def end_session_and_get_url(self, tag):
class ProfileTest (line 191) | class ProfileTest(absltest.TestCase):
method test_every_steps (line 195) | def test_every_steps(self, mock_time, mock_profiler):
class ProfileAllHostsTest (line 224) | class ProfileAllHostsTest(absltest.TestCase):
method test_every_steps (line 227) | def test_every_steps(self, mock_profiler):
class PeriodicCallbackTest (line 247) | class PeriodicCallbackTest(absltest.TestCase):
method test_every_steps (line 249) | def test_every_steps(self):
method test_every_secs (line 267) | def test_every_secs(self, mock_time):
method test_on_steps (line 281) | def test_on_steps(self):
method test_async_execution (line 290) | def test_async_execution(self):
method test_error_async_is_forwarded (line 309) | def test_error_async_is_forwarded(self):
method test_function_without_step_and_time (line 325) | def test_function_without_step_and_time(self):
FILE: clu/platform/__init__.py
function work_unit (line 35) | def work_unit() -> WorkUnit:
FILE: clu/platform/interface.py
class ArtifactType (line 22) | class ArtifactType(enum.Enum):
class WorkUnit (line 31) | class WorkUnit(abc.ABC):
method experiment_id (line 42) | def experiment_id(self):
method id (line 47) | def id(self):
method name (line 51) | def name(self):
method set_notes (line 63) | def set_notes(self, msg: str):
method set_task_status (line 67) | def set_task_status(self, msg: str):
method create_artifact (line 71) | def create_artifact(self, artifact_type: ArtifactType, artifact: Any,
FILE: clu/platform/local.py
class LocalWorkUnit (line 26) | class LocalWorkUnit(WorkUnit):
method experiment_id (line 30) | def experiment_id(self):
method id (line 35) | def id(self):
method set_notes (line 39) | def set_notes(self, msg: str):
method set_task_status (line 43) | def set_task_status(self, msg: str):
method create_artifact (line 47) | def create_artifact(self, artifact_type: ArtifactType, artifact: Any,
FILE: clu/preprocess_spec.py
class PreprocessOp (line 77) | class PreprocessOp(Protocol):
method __call__ (line 90) | def __call__(self, features: Features) -> Features:
class MapTransform (line 95) | class MapTransform(abc.ABC):
method __new__ (line 109) | def __new__(cls, *args, **kwargs):
method __call__ (line 121) | def __call__(self, features: D) -> D:
method _transform (line 130) | def _transform(self, features: FlatFeatures) -> FlatFeatures:
class RandomMapTransform (line 135) | class RandomMapTransform(MapTransform, abc.ABC):
method __call__ (line 149) | def __call__(self, features: D) -> D:
method _transform (line 162) | def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatF...
class FilterTransform (line 167) | class FilterTransform(abc.ABC):
method __call__ (line 169) | def __call__(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
method _predicate (line 175) | def _predicate(self, features: FlatFeatures) -> tf.Tensor:
function get_all_ops (line 179) | def get_all_ops(module_name: str) -> List[Tuple[str, Type[PreprocessOp]]]:
function _jax_supported_tf_types (line 204) | def _jax_supported_tf_types():
class OnlyJaxTypes (line 214) | class OnlyJaxTypes:
method __call__ (line 228) | def __call__(self, features: Features) -> Features:
class PreprocessFn (line 252) | class PreprocessFn:
method __call__ (line 265) | def __call__(self, features: Features) -> Features:
method __add__ (line 280) | def __add__(self, other: "PreprocessFn") -> "PreprocessFn":
method __getitem__ (line 289) | def __getitem__(self, op_index: Union[int, slice]) -> "PreprocessFn":
function _get_op_class (line 298) | def _get_op_class(
function _parse_single_preprocess_op (line 315) | def _parse_single_preprocess_op(
function parse (line 359) | def parse(spec: str,
function _describe_features (line 374) | def _describe_features(features: Features) -> str:
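`parse` in clu/preprocess_spec.py turns a spec string such as `"to_float|rescale(scale=255)"` into a `PreprocessFn`, with `|` separating ops and parentheses carrying arguments. A toy splitter showing the first stage of that parsing (the argument grammar is simplified; the real parser also resolves op classes and evaluates the arguments):

```python
import re

def split_spec(spec: str):
    """Splits 'op1|op2(a=1)' into [('op1', ''), ('op2', 'a=1')]."""
    ops = []
    for part in filter(None, (p.strip() for p in spec.split("|"))):
        match = re.fullmatch(r"([a-z0-9_]+)(?:\((.*)\))?", part)
        if match is None:
            raise ValueError(f"Cannot parse op spec: {part!r}")
        ops.append((match.group(1), match.group(2) or ""))
    return ops
```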
FILE: clu/preprocess_spec_test.py
class ToFloat (line 27) | class ToFloat:
method __call__ (line 29) | def __call__(self, features: Features) -> Features:
class Rescale (line 34) | class Rescale:
method __call__ (line 38) | def __call__(self, features: Features) -> Features:
class AddRandomInteger (line 45) | class AddRandomInteger(preprocess_spec.RandomMapTransform):
method _transform (line 47) | def _transform(self, features, seed):
class PreprocessSpecTest (line 55) | class PreprocessSpecTest(parameterized.TestCase, tf.test.TestCase):
method test_no_arguments (line 58) | def test_no_arguments(self):
method test_positional_argument (line 63) | def test_positional_argument(self):
method test_keyword_argument (line 69) | def test_keyword_argument(self):
method test_invalid_op_name (line 75) | def test_invalid_op_name(self):
method test_invalid_spec (line 82) | def test_invalid_spec(self):
method test_pos_and_kw_arg (line 87) | def test_pos_and_kw_arg(self):
method test_parsing_empty_string (line 95) | def test_parsing_empty_string(self):
method test_multi_op_spec (line 100) | def test_multi_op_spec(self):
method test_two_tensors (line 105) | def test_two_tensors(self):
method test_only_jax_types (line 114) | def test_only_jax_types(self):
method test_only_jax_types_nested_inputs (line 128) | def test_only_jax_types_nested_inputs(self):
method test_not_only_jax_types (line 139) | def test_not_only_jax_types(self):
method test_add_preprocess_fn (line 145) | def test_add_preprocess_fn(self):
method test_slice_preprocess_fn (line 156) | def test_slice_preprocess_fn(self):
method test_random_map_transform (line 166) | def test_random_map_transform(self):
FILE: clu/profiler.py
function start (line 29) | def start(logdir: str, options=None):
function stop (line 41) | def stop() -> Optional[str]:
function collect (line 49) | def collect(logdir: str,
FILE: clu/values.py
class Value (line 31) | class Value(Protocol):
class Summary (line 41) | class Summary(Value):
class Scalar (line 47) | class Scalar(Value):
class Image (line 52) | class Image(Value):
class Audio (line 65) | class Audio(Value):
class Text (line 81) | class Text(Value):
class Histogram (line 86) | class Histogram(Value):
class HyperParam (line 93) | class HyperParam(Value):
Condensed preview — 52 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,228K chars).
[
{
"path": ".github/workflows/build.yml",
"chars": 1049,
"preview": "# This workflow will install Python dependencies, run tests and lint.\n# For more information see: https://help.github.co"
},
{
"path": ".github/workflows/python-publish.yml",
"chars": 846,
"preview": "# This workflows will upload a Python Package using Twine when a release is created\n# For more information see: https://"
},
{
"path": "AUTHORS",
"chars": 307,
"preview": "# This is the list of Common Loop Utils significant contributors.\n#\n# This does not necessarily list everyone who has co"
},
{
"path": "CHANGELOG.md",
"chars": 3390,
"preview": "# Changelog\n\n## v0.0.1-alpha.1\n\nInitial PyPi Release\n\nCurrent list of modules:\n\n- `clu.checkpoint`\n- `clu.determinis"
},
{
"path": "CONTRIBUTING.md",
"chars": 923,
"preview": "# How to Contribute\n\nAt this time we are focused on supporting research done by Google Research and\nare not accepting pa"
},
{
"path": "LICENSE",
"chars": 11358,
"preview": "\n Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 761,
"preview": "# CLU - Common Loop Utils\n\nThis repository contains common functionality for writing ML training loops. The\ngoal is to m"
},
{
"path": "clu/__init__.py",
"chars": 581,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/asynclib.py",
"chars": 5155,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/asynclib_test.py",
"chars": 3341,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/checkpoint.py",
"chars": 21322,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/checkpoint_test.py",
"chars": 14741,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/data/__init__.py",
"chars": 892,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/data/dataset_iterator.py",
"chars": 9756,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/data/dataset_iterator_test.py",
"chars": 4146,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/deterministic_data.py",
"chars": 24040,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/deterministic_data_test.py",
"chars": 13529,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/internal/__init__.py",
"chars": 581,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/internal/utils.py",
"chars": 3249,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/internal/utils_test.py",
"chars": 3391,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/__init__.py",
"chars": 2411,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/async_writer.py",
"chars": 5496,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/async_writer_test.py",
"chars": 5076,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/interface.py",
"chars": 7012,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/logging_writer.py",
"chars": 5536,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/logging_writer_test.py",
"chars": 5045,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/multi_writer.py",
"chars": 2814,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/multi_writer_test.py",
"chars": 2289,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/summary_writer.py",
"chars": 711,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/tf/__init__.py",
"chars": 610,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/tf/summary_writer.py",
"chars": 4285,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/tf/summary_writer_test.py",
"chars": 7805,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/torch_tensorboard_writer.py",
"chars": 3285,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/torch_tensorboard_writer_test.py",
"chars": 3341,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/utils.py",
"chars": 5308,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metric_writers/utils_test.py",
"chars": 5967,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metrics.py",
"chars": 32656,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/metrics_test.py",
"chars": 20817,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/parameter_overview.py",
"chars": 12445,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/parameter_overview_test.py",
"chars": 6122,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/periodic_actions.py",
"chars": 18561,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/periodic_actions_test.py",
"chars": 9895,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/platform/__init__.py",
"chars": 1344,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/platform/interface.py",
"chars": 2075,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/platform/local.py",
"chars": 1634,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/preprocess_spec.py",
"chars": 14204,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/preprocess_spec_test.py",
"chars": 6550,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/profiler.py",
"chars": 1826,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu/run_pytest.google.sh",
"chars": 373,
"preview": "#!/bin/bash\n\nset -e -x\n\nCLU_DST=\"${CLU_DST:-/tmp/clu}\"\nCLU_ENV=\"${CLU_ENV:-/tmp/clu_env}\"\n\ncopybara third_party/py/clu/c"
},
{
"path": "clu/values.py",
"chars": 2772,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
},
{
"path": "clu_synopsis.ipynb",
"chars": 882835,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {\n \"id\": \"4CldxEhqQac_\"\n },\n \"sou"
},
{
"path": "setup.py",
"chars": 2052,
"preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
}
]
About this extraction
This page contains the full source code of the google/CommonLoopUtils GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 52 files (1.2 MB, approximately 657.2k tokens) and a symbol index with 535 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.