Full Code of google/CommonLoopUtils for AI

Repository: google/CommonLoopUtils
Branch: main
Commit: 5150136b245e
Files: 52
Total size: 1.2 MB

Directory structure:
gitextract_g9bj2g2j/

├── .github/
│   └── workflows/
│       ├── build.yml
│       └── python-publish.yml
├── AUTHORS
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── clu/
│   ├── __init__.py
│   ├── asynclib.py
│   ├── asynclib_test.py
│   ├── checkpoint.py
│   ├── checkpoint_test.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── dataset_iterator.py
│   │   └── dataset_iterator_test.py
│   ├── deterministic_data.py
│   ├── deterministic_data_test.py
│   ├── internal/
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   └── utils_test.py
│   ├── metric_writers/
│   │   ├── __init__.py
│   │   ├── async_writer.py
│   │   ├── async_writer_test.py
│   │   ├── interface.py
│   │   ├── logging_writer.py
│   │   ├── logging_writer_test.py
│   │   ├── multi_writer.py
│   │   ├── multi_writer_test.py
│   │   ├── summary_writer.py
│   │   ├── tf/
│   │   │   ├── __init__.py
│   │   │   ├── summary_writer.py
│   │   │   └── summary_writer_test.py
│   │   ├── torch_tensorboard_writer.py
│   │   ├── torch_tensorboard_writer_test.py
│   │   ├── utils.py
│   │   └── utils_test.py
│   ├── metrics.py
│   ├── metrics_test.py
│   ├── parameter_overview.py
│   ├── parameter_overview_test.py
│   ├── periodic_actions.py
│   ├── periodic_actions_test.py
│   ├── platform/
│   │   ├── __init__.py
│   │   ├── interface.py
│   │   └── local.py
│   ├── preprocess_spec.py
│   ├── preprocess_spec_test.py
│   ├── profiler.py
│   ├── run_pytest.google.sh
│   └── values.py
├── clu_synopsis.ipynb
└── setup.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/build.yml
================================================
# This workflow will install Python dependencies, run tests and lint.
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Build

on:
  push:
    branches:
      - main
      - 'test_*'
  pull_request:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.10', '3.11']
    steps:
    - name: Cancel previous
      uses: styfle/cancel-workflow-action@0.8.0
      with:
        access_token: ${{ github.token }}
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        pip install .
        pip install .[test]
    - name: Test with pytest and generate coverage report
      run: |
        pytest .
    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@v1
      with:
        file: ./coverage.xml


================================================
FILE: .github/workflows/python-publish.yml
================================================
# This workflow will upload a Python package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v4
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: __token__
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*


================================================
FILE: AUTHORS
================================================
# This is the list of Common Loop Utils significant contributors.
#
# This does not necessarily list everyone who has contributed code,
# especially since many employees of one corporation may be contributing.
# To see the full list of contributors, see the revision history in
# source control.
Google LLC


================================================
FILE: CHANGELOG.md
================================================
# Changelog

## v0.0.1-alpha.1

Initial PyPI release.

Current list of modules:

-   `clu.checkpoint`
-   `clu.deterministic_training`
-   `clu.metric_writers`
-   `clu.periodic_actions`
-   `clu.platform`
-   `clu.profiler`

## v0.0.1-alpha.2

-   Adds `metrics` module and some minor changes.

## v0.0.1a3

-   Added `metric_writers.TorchTensorboardWriter`

## v0.0.2

-   Added preprocess_spec.
-   Improvements to periodic_actions.

## v0.0.3

-   `metric_writers`: Lets `SummaryWriter` write nested dictionaries.
-   `internal`: Adds `async.Pool`.
-   `preprocess_spec`: Support nested dictionaries.
-   `profiler`: Use JAX profiler APIs instead of TF profiler APIs.

## v0.0.4

`deterministic_data`

-   Support non-positive input value for pad_up_to_batches.
-   Support padding dataset when data dimension is unknown.
-   Support TFDS specs in get_read_instruction_for_host.
-   Allow changing drop_remainder for batching.
-   Add RemainderOptions in deterministic_data.

`metric_writers`

-   Support multiple writers in metric_writers.ensure_flushes.

`metrics`

-   Makes internal.flatten_dict() work with ConfigDicts.
-   Forwards mask model output to metrics created via `Metric.from_output()`.
-   Forwards mask model output to metrics created via `Metric.from_fun()`.
-   Added `Collections.unreplicate()`, `Collections.create()`.
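Writing nested dictionaries (the `SummaryWriter` change in v0.0.3) and the `flatten_dict()` fix listed above both come down to turning nested mappings into flat, path-like keys. The following is a minimal stdlib sketch of that idea, not CLU's implementation: the name `flatten_dict`, the `/` separator, and recursing into any `collections.abc.Mapping` are assumptions made for illustration.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_dict(d: Mapping, prefix: str = "", sep: str = "/") -> Dict[str, Any]:
  """Flattens nested mappings into one dict whose keys are joined paths.

  Hypothetical sketch: any `Mapping` subclass (not only `dict`) is recursed
  into; every other value becomes a leaf.
  """
  flat = {}
  for key, value in d.items():
    path = f"{prefix}{sep}{key}" if prefix else str(key)
    if isinstance(value, Mapping):
      # Recurse, extending the key path with the current key.
      flat.update(flatten_dict(value, prefix=path, sep=sep))
    else:
      flat[path] = value
  return flat


metrics = {"train": {"loss": 0.25, "acc": {"top1": 0.9}}, "step": 100}
print(flatten_dict(metrics))
# {'train/loss': 0.25, 'train/acc/top1': 0.9, 'step': 100}
```

Flat keys like `train/acc/top1` are convenient for writers that only accept a single level of scalar names.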

`periodic_actions`

-   Formats long time strings in '{days}d{hours}h{mins}m' format.
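As an illustration of the `'{days}d{hours}h{mins}m'` format, here is a minimal sketch of such a formatter. The function name `format_duration` and the choice to drop leading zero units (no `0d` prefix for durations under a day) are assumptions; `periodic_actions` may differ in those details.

```python
def format_duration(secs: float) -> str:
  """Formats a duration in seconds as '{days}d{hours}h{mins}m'.

  Hypothetical sketch: leading zero units are omitted and leftover seconds
  are truncated, so 59 seconds formats as '0m'.
  """
  total_mins = int(secs // 60)
  days, rem_mins = divmod(total_mins, 24 * 60)
  hours, mins = divmod(rem_mins, 60)
  out = ""
  if days:
    out += f"{days}d"
  if days or hours:
    # Keep the hours field once a larger unit has been printed.
    out += f"{hours}h"
  return out + f"{mins}m"


print(format_duration(123456.7))  # 1d10h17m
print(format_duration(3600))      # 1h0m
```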

`preprocess_spec`

-   Make feature description of features in PreprocessFn more compact.
-   Better type check in `preprocess_spec.get_all_ops()`.

Documentation:

-   Added `clu_synopsis.ipynb` Colab

## v0.0.5

-   Log error instead of failing when `profiler.start()` raises an exception.
-   Makes `periodic_actions.ProgressUpdate` show total number of steps.
-   Makes `AsyncWriter` non-blocking wrt JAX async computations.
-   Adds `clu_synopsis.ipynb` Colab as initial documentation.
-   Restore Checkpoint without providing the state
-   Makes `PreprocessFn` addable.
-   Allow n-dimensional arrays (and masks) to be passed to Metrics.Average().
-   Support slicing `PreprocessFn`.
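Making `PreprocessFn` addable and sliceable treats a preprocessing function as an ordered sequence of ops. The sketch below shows that idea with plain Python callables; the class name `OpChain` and its methods are hypothetical and are not meant to mirror the actual `preprocess_spec` API.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class OpChain:
  """Hypothetical sequence of ops, applied left to right."""

  ops: Tuple[Callable[[Any], Any], ...]

  def __call__(self, x: Any) -> Any:
    for op in self.ops:
      x = op(x)
    return x

  def __add__(self, other: "OpChain") -> "OpChain":
    # Adding two chains concatenates their ops.
    return OpChain(self.ops + other.ops)

  def __getitem__(self, idx) -> "OpChain":
    # Slicing (or indexing) returns a new chain with a subset of the ops.
    if isinstance(idx, slice):
      return OpChain(self.ops[idx])
    return OpChain((self.ops[idx],))


chain = OpChain((lambda x: x + 1, lambda x: x * 2)) + OpChain((lambda x: -x,))
print(chain(3))      # -8
print(chain[:2](3))  # 8
```

Keeping the chain immutable (`frozen=True`) means `+` and slicing always return new objects, so a shared pipeline cannot be modified by accident.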

## v0.0.6

-   Makes `deterministic_data` work with `tfds>4.4.0` and `tfds<=4.4.0`.

This will be the last release supporting Python 3.6.

## v0.0.7

-   Moves `clu.internal.asynclib` to `clu.asynclib`.
-   Adds methods for writing raw tensors and audio to `MetricWriter`.
-   Adds `clu.values` to annotate arrays with a modality.
-   Adds `clu.data.DatasetIterator` - a generic interface between input
    pipelines and training loops.
-   Fixes various issues with `clu.metrics`.

This will be the last release supporting Python 3.7.
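One property that makes an interface between input pipelines and training loops useful is that the iterator's position can be saved and restored together with the model (compare the `ds_iter` passed to `MultihostCheckpoint` in `clu/checkpoint.py`). The following stdlib sketch illustrates only that idea. `ListIterator`, `get_state`, and `set_state` are hypothetical names, not the `clu.data.DatasetIterator` API.

```python
import json
from typing import Any, Iterator, List


class ListIterator(Iterator[Any]):
  """Hypothetical checkpointable iterator over an in-memory list."""

  def __init__(self, elements: List[Any]):
    self._elements = elements
    self._position = 0

  def __next__(self) -> Any:
    if self._position >= len(self._elements):
      raise StopIteration
    element = self._elements[self._position]
    self._position += 1
    return element

  def get_state(self) -> bytes:
    # Serialize just enough to resume: the index of the next element.
    return json.dumps({"position": self._position}).encode("utf-8")

  def set_state(self, state: bytes) -> None:
    self._position = json.loads(state.decode("utf-8"))["position"]


it = ListIterator(["a", "b", "c"])
next(it)
saved = it.get_state()  # Checkpoint after consuming one element.
restored = ListIterator(["a", "b", "c"])
restored.set_state(saved)
print(list(restored))  # ['b', 'c']
```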

## v0.0.9

-   Fix pytype failures related to teaching pytype about NumPy scalar types.
-   Fix a couple of docstring typos.
-   Updates README and `clu_synopsis.ipynb`.

Last release before dropping support for Python 3.8 and 3.9.

## v0.0.10

-   `clu.parameter_overview` now supports JAX global arrays.
-   Various small fixes in `clu.metrics` module.
-   Removed some tensorflow dependencies.

## v0.0.11

-   Removes numpy version pin
-   Adds sharding annotations, dtype, total bytes to `parameter_overview`
-   Makes `clu.metrics.Std` support same shapes as `clu.metrics.Average`

## v0.0.12

-   Switch from `jax.tree_map` (deprecated since JAX 0.4.26) to
    `jax.tree_util.tree_map`.
-   Improvements to parameter overview.


================================================
FILE: CONTRIBUTING.md
================================================
# How to Contribute

At this time we are focused on supporting research done by Google Research and
are not accepting patches.

You are, however, free to start a fork of the project for your own purposes, as
permitted by the license.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement (CLA). You (or your employer) retain the copyright to your
contribution; this simply gives us permission to use and redistribute your
contributions as part of the project. Head over to
<https://cla.developers.google.com/> to see your current agreements on file or
to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Community Guidelines

This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).


================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# CLU - Common Loop Utils

This repository contains common functionality for writing ML training loops. The
goal is to make training loops short and readable (by moving common tasks to
small libraries) without removing the flexibility required for research.

To get started, check out this Colab:

https://colab.research.google.com/github/google/CommonLoopUtils/blob/main/clu_synopsis.ipynb

If you're looking for usage examples, see:

https://github.com/google/flax/tree/main/examples

You can also find answers to common questions about CLU on the Flax GitHub
Discussions page:

https://github.com/google/flax/discussions

Note: At this point we are not accepting contributions. Please fork the
repository if you want to extend the libraries for your use case.


================================================
FILE: clu/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.



================================================
FILE: clu/asynclib.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for async function calls."""

import collections
import concurrent.futures
import functools
import sys
import threading
from typing import Callable, List, Optional

from absl import logging


class AsyncError(Exception):
  """An exception that wraps another exception that occurred asynchronously."""


class Pool:
  """Pool for wrapping functions to be executed asynchronously.

  Synopsis:

    import time

    from clu import asynclib

    pool = asynclib.Pool()
    @pool
    def fn():
      time.sleep(1)

    future = fn()
    print(future.result())
    fn()  # This could re-raise an exception from the first execution.
    print(pool.queue_length)  # Prints 1: the second call is still in flight.
    pool.join()  # This could re-raise an exception from the second execution.
  """

  def __init__(self, thread_name_prefix: str = "",
               max_workers: Optional[int] = None):
    """Creates a new pool that decorates functions for async execution.

    Args:
      thread_name_prefix: See documentation of `ThreadPoolExecutor`.
      max_workers: See documentation of `ThreadPoolExecutor`. The default `None`
        optimizes for parallelizability based on the number of CPU cores. If you
        specify `max_workers=1`, the async calls are executed in the same order
        in which they were scheduled.
    """
    self._pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=max_workers, thread_name_prefix=thread_name_prefix)
    self._max_workers = max_workers
    self._thread_name_prefix = thread_name_prefix
    self._errors = collections.deque()
    self._errors_mutex = threading.Lock()
    self._queue_length = 0

  def _reraise(self) -> None:
    if self._errors:
      with self._errors_mutex:
        exc_info = self._errors.popleft()
      exc = exc_info[1].with_traceback(exc_info[2])
      raise AsyncError(f"Error '{exc}' occurred ASYNCHRONOUSLY.") from exc

  def close(self) -> None:
    """Closes this pool and re-raises a pending exception (if any)."""
    self._pool.shutdown(wait=True)
    self._reraise()

  def join(self) -> None:
    """Blocks until all functions are processed.

    The pool can be used to schedule more functions after calling this function,
    but further exceptions may still be pending, since only one is reraised.

    Side-effect:
      If any of the functions raised an exception, then the first of these
      exceptions is reraised.
    """
    self._pool.shutdown(wait=True)
    self._pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=self._max_workers,
        thread_name_prefix=self._thread_name_prefix)
    self._reraise()

  @property
  def queue_length(self) -> int:
    """Returns the number of functions that have not returned yet."""
    return self._queue_length

  @property
  def has_errors(self) -> bool:
    """Returns True if there are any pending errors."""
    return bool(self._errors)

  def clear_errors(self) -> List[Exception]:
    """Clears all pending errors and returns them as a (possibly empty) list."""
    with self._errors_mutex:
      errors, self._errors = self._errors, collections.deque()
    return list(errors)

  def __call__(self, fn: Callable):  # pylint: disable=g-bare-generic
    """Returns an async version of fn.

    The function will be executed by this class's ThreadPoolExecutor. Any errors
    are stored and re-raised the next time any function wrapped by this pool is
    called.

    Note that even if there was a previous error, the function is still
    scheduled upon re-execution of the wrapper returned by this function.

    Args:
      fn: Function to be wrapped.

    Returns:
      An async version of `fn`. The return value of that async version will be
      a future (unless an exception was re-raised).
    """

    def inner(*args, **kwargs):

      def trap_errors(*args, **kwargs):
        try:
          return fn(*args, **kwargs)
        except Exception as e:
          with self._errors_mutex:
            self._errors.append(sys.exc_info())
          logging.exception("Error in producer thread for %s",
                            self._thread_name_prefix)
          raise e
        finally:
          self._queue_length -= 1

      self._queue_length += 1
      if not self.has_errors:
        return self._pool.submit(trap_errors, *args, **kwargs)
      self._pool.submit(trap_errors, *args, **kwargs)
      self._reraise()

    if isinstance(fn.__name__, str):
      # Regular function.
      return functools.wraps(fn)(inner)
    # Mock or another weird function that fails with functools.wraps().
    return inner
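The error-handling contract described in `Pool.__call__` (errors from worker threads are stored and re-raised on a later call or on `join()`) can be seen in isolation in the following simplified, stdlib-only sketch. It mirrors the structure of `Pool` above, but `MiniPool` is a hypothetical re-implementation: it stores exception objects instead of `sys.exc_info()` tuples, reads the error list under the lock, and omits `queue_length`, logging, and thread name prefixes.

```python
import concurrent.futures
import threading


class AsyncError(Exception):
  """Wraps an exception that was raised in a worker thread."""


class MiniPool:
  """Hypothetical, simplified version of the `Pool` pattern."""

  def __init__(self, max_workers: int = 1):
    self._max_workers = max_workers
    self._executor = concurrent.futures.ThreadPoolExecutor(max_workers)
    self._errors = []
    self._lock = threading.Lock()

  @property
  def has_errors(self) -> bool:
    with self._lock:
      return bool(self._errors)

  def _reraise(self) -> None:
    # Pop the oldest stored error (if any) and raise it in the caller.
    with self._lock:
      if not self._errors:
        return
      exc = self._errors.pop(0)
    raise AsyncError(f"Error '{exc}' occurred ASYNCHRONOUSLY.") from exc

  def __call__(self, fn):
    def trap_errors(*args, **kwargs):
      try:
        return fn(*args, **kwargs)
      except Exception as e:
        with self._lock:
          self._errors.append(e)  # Keep it for the caller thread.
        raise

    def inner(*args, **kwargs):
      had_errors = self.has_errors
      # The new call is scheduled even if an earlier one failed.
      future = self._executor.submit(trap_errors, *args, **kwargs)
      if had_errors:
        self._reraise()  # Surfaces an error from an earlier call.
      return future

    return inner

  def join(self) -> None:
    # Wait for all scheduled work, then start a fresh executor.
    self._executor.shutdown(wait=True)
    self._executor = concurrent.futures.ThreadPoolExecutor(self._max_workers)
    self._reraise()


pool = MiniPool()


@pool
def fail():
  raise ValueError("boom")


@pool
def answer():
  return 42


concurrent.futures.wait([fail()])  # The error is now stored in the pool.
try:
  answer()  # Still scheduled, but raises the stored error first.
except AsyncError as e:
  print(repr(e.__cause__))  # ValueError('boom')
print(answer().result())  # 42
pool.join()
```

Re-raising on the next call, rather than only on `join()`, means a failing background task is noticed on the following step instead of at the end of training.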


================================================
FILE: clu/asynclib_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for clu.asynclib."""

from unittest import mock

from absl.testing import absltest
from clu import asynclib


class AsyncWriterTest(absltest.TestCase):

  def test_async_execution(self):
    pool = asynclib.Pool()
    counter = 0

    @pool
    def fn(counter_increment, return_value):
      nonlocal counter
      counter += counter_increment
      return return_value

    future = fn(1, return_value=2)
    self.assertEqual(counter, 1)
    self.assertEqual(future.result(), 2)

  def test_reraise(self):
    pool = asynclib.Pool()

    @pool
    def error():
      raise ValueError("test")

    error()
    self.assertTrue(pool.has_errors)
    with self.assertRaisesRegex(asynclib.AsyncError, "test"):
      pool.join()
    self.assertFalse(pool.has_errors)

    @pool
    def noop():
      ...

    error()
    self.assertTrue(pool.has_errors)
    with self.assertRaisesRegex(asynclib.AsyncError, "test"):
      noop()
    self.assertFalse(pool.has_errors)

    pool.join()

  @mock.patch("concurrent.futures.ThreadPoolExecutor")
  def test_queue_length(self, executor_mock):
    pool_mock = mock.Mock()
    in_flight = []

    def execute_one():
      in_flight.pop(0)()

    def submit(fn, *args, **kwargs):
      in_flight.append(lambda: fn(*args, **kwargs))

    pool_mock.submit = submit
    executor_mock.return_value = pool_mock

    pool = asynclib.Pool()

    @pool
    def noop():
      ...

    self.assertEqual(pool.queue_length, 0)
    noop()
    self.assertEqual(pool.queue_length, 1)
    noop()
    self.assertEqual(pool.queue_length, 2)
    execute_one()
    self.assertEqual(pool.queue_length, 1)
    execute_one()
    self.assertEqual(pool.queue_length, 0)

  @mock.patch("concurrent.futures.ThreadPoolExecutor")
  def test_flush(self, executor_mock):
    pool_mock = mock.Mock()
    pool_mock._in_flight = None

    def execute_one():
      pool_mock._in_flight.pop(0)()

    def submit(fn, *args, **kwargs):
      pool_mock._in_flight.append(lambda: fn(*args, **kwargs))

    def create_pool(max_workers, thread_name_prefix):
      del max_workers
      del thread_name_prefix
      pool_mock._in_flight = []
      return pool_mock

    def shutdown(wait=False):
      if wait:
        while pool_mock._in_flight:
          execute_one()
      pool_mock._in_flight = None

    pool_mock.submit = submit
    executor_mock.side_effect = create_pool
    pool_mock.shutdown.side_effect = shutdown

    pool = asynclib.Pool()

    @pool
    def noop():
      ...

    self.assertEqual(pool.queue_length, 0)
    noop()
    self.assertEqual(pool.queue_length, 1)
    noop()
    pool.join()
    self.assertEqual(pool.queue_length, 0)
    noop()
    self.assertEqual(pool.queue_length, 1)


if __name__ == "__main__":
  absltest.main()


================================================
FILE: clu/checkpoint.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple checkpointing library for TF2/Flax.

The class `Checkpoint` is a simple wrapper around `tf.train.Checkpoint` that
also stores a `flax.struct.dataclass` instance in the same directory.

Synopsis:

  from clu import checkpoint
  import flax

  @flax.struct.dataclass
  class TrainState:
    optimizer: flax.optim.Optimizer
    step: int

  ds = load_tf_dataset()
  ds_iter = iter(ds)
  ckpt = checkpoint.MultihostCheckpoint(base_directory, dict(ds_iter=ds_iter))
  optimizer = create_flax_optimizer()
  state = TrainState(optimizer=optimizer, step=0)
  state = ckpt.restore_or_initialize(state)  # Also restores `ds_iter`.
  initial_step = int(state.step) + 1
  # Need to replicate all data when training with multiple accelerators.
  state = flax.jax_utils.replicate(state)

  for step in range(initial_step, steps + 1):
    state = update_step(state, next(ds_iter))
    ckpt.save(flax.jax_utils.unreplicate(state))
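
Because the synopsis uses `MultihostCheckpoint`, `restore_or_initialize()` does
not simply restore this host's newest checkpoint. It lists the checkpoint
numbers in every host directory (`{multihost_base_directory}-{host_id}`) and
restores the highest number that is present on all of them. The stdlib-only
sketch below models only that selection rule, assuming the per-host checkpoint
numbers have already been listed. `latest_common_checkpoint()` is a
hypothetical helper, not part of this module, and it ignores the case where
this host's own directory does not exist yet.

```python
from typing import Dict, Optional, Set


def latest_common_checkpoint(
    numbers_per_host: Dict[int, Set[int]]) -> Optional[int]:
  # Hypothetical helper: returns the highest checkpoint number present on
  # every host, or None if the hosts share no checkpoint at all.
  common: Optional[Set[int]] = None
  for numbers in numbers_per_host.values():
    # Intersect the numbers seen so far with this host's numbers.
    common = set(numbers) if common is None else common & numbers
  if not common:
    return None
  return max(common)


# Host 1 finished writing checkpoint 3 before the pre-emption; host 0 did not.
print(latest_common_checkpoint({0: {1, 2}, 1: {1, 2, 3}}))  # 2
```

In the example call both hosts resume from checkpoint 2, and host 1's next
`save()` replaces its unmatched checkpoint 3.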

Loading the model, e.g. in a Colab:

  from clu import checkpoint
  import flax
  from . import mnist_lib

  state_dict = checkpoint.load_state_dict(base_directory)
  params = state_dict['optimizer']['target']['params']
  module = mnist_lib.MyArchitecture.partial(num_classes=10)
  model = flax.deprecated.nn.Model(module, params)
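
Restoring an older checkpoint explicitly (e.g. via
`ckpt.restore(state, checkpoint)`) and then calling `save()` rewinds the
numbering. `save()` first deletes every consecutive checkpoint after the
restored one (see `_delete_future_checkpoints()`) and then writes the next
number. The sketch below models that bookkeeping with plain integers.
`plan_save()` is a hypothetical helper, not part of this module, and it
ignores the pruning of old checkpoints controlled by `max_to_keep`.

```python
from typing import List, Set, Tuple


def plan_save(existing: Set[int], current: int) -> Tuple[int, List[int]]:
  # Hypothetical model of `Checkpoint.save()` after restoring `current`:
  # walk forward from `current + 1` and delete every consecutive checkpoint
  # number that still exists, stopping at the first gap, then write the new
  # checkpoint as `current + 1`.
  deleted: List[int] = []
  number = current + 1
  while number in existing:
    deleted.append(number)
    number += 1
  return current + 1, deleted


# Checkpoints 1, 2 and 3 exist and checkpoint 1 was restored.
print(plan_save({1, 2, 3}, current=1))  # (2, [2, 3])
```

Checkpoint 2 is removed and then written again, which is what `save()`
describes as overwriting `2` and deleting `3`.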
"""

import collections
import os
import re
from typing import Any, Dict, Optional, TypeVar

from absl import logging

from clu.internal import utils
import flax
import jax
import tensorflow as tf

# TODO(b/200953513): Migrate away from logging imports (on module level)
#                    to logging the actual usage. See b/200953513.



T = TypeVar("T")
SCHEME_RE = re.compile("^(?P<scheme>[a-z][a-z0-9.+-]+://)?(?P<path>.*)", re.I)


def safe_normpath(path: str) -> str:
  """Normalizes path safely to get around `gfile.glob()` limitations."""
  d = SCHEME_RE.match(path).groupdict()  # pytype: disable=attribute-error  # re-none
  return (d["scheme"] or "") + os.path.normpath(d["path"])


def load_state_dict(base_directory) -> Dict[str, Any]:
  """Restores `state` as dictionary from the latest checkpoint.

  Synopsis:

    data = checkpoint.load_state_dict(base_directory)
    params = data['optimizer']['target']['params']
    module = mnist_lib.MyArchitecture.partial(num_classes=10)
    model = flax.deprecated.nn.Model(module, params)

  Args:
    base_directory: Directory from which the checkpoints should be restored. See
      `Checkpoint.__init__()`.

  Returns:
    The deserialized Flax data, as a dictionary.

  Raises:
    FileNotFoundError: If there is no checkpoint to restore.
  """
  return Checkpoint(base_directory).load_state(state=None)


class CheckpointInfo(
    collections.namedtuple("CheckpointInfo", ("prefix", "number"))):
  """Helper class to parse a TensorFlow checkpoint path."""

  CHECKPOINT_REGEX = r"^(?P<prefix>.*)-(?P<number>\d+)"

  @classmethod
  def initialize(cls, base_directory, checkpoint_name: str) -> "CheckpointInfo":
    """Creates the first CheckpointInfo (number=1)."""
    return cls(f"{base_directory}/{checkpoint_name}", 1)

  @classmethod
  def from_path(cls, checkpoint: str) -> "CheckpointInfo":
    """Parses a checkpoint.

    Args:
      checkpoint: A checkpoint prefix, as can be found in the
        `.latest_checkpoint` property of a `tf.train.CheckpointManager`.

    Returns:
      An instance of `CheckpointInfo` that represents `checkpoint`.
    """
    m = re.match(cls.CHECKPOINT_REGEX, checkpoint)
    if m is None:
      raise RuntimeError(f"Invalid checkpoint format: {checkpoint}")
    d = m.groupdict()  # pytype: disable=attribute-error
    return cls(d["prefix"], int(d["number"]))

  def increment(self) -> "CheckpointInfo":
    """Returns a new CheckpointInfo with `number` increased by one."""
    return CheckpointInfo(self.prefix, self.number + 1)

  def __str__(self):
    """Does the opposite of `.from_path()`."""
    return f"{self.prefix}-{self.number}"


class Checkpoint:
  """A utility class for storing and loading TF2/Flax checkpoints.

  Both the state of a `tf.data.Dataset` iterator and a `flax.struct.dataclass`
  are stored on disk in the following files:

  - {directory}/checkpoint
  - {directory}/ckpt-{number}.index
  - {directory}/ckpt-{number}.data@*
  - {directory}/ckpt-{number}.flax

  Where {number} starts at 1 and is then incremented by 1 for every new
  checkpoint. The last file is the `flax.struct.dataclass`, serialized in
  MessagePack format. The other files are explained in more detail in the
  TensorFlow documentation:

  https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint
  """

  def __init__(self,
               base_directory: str,
               tf_state: Optional[Dict[str, Any]] = None,
               *,
               max_to_keep: int = 5,
               checkpoint_name: str = "ckpt"):
    """Initializes a Checkpoint with a dictionary of TensorFlow Trackables.

    Args:
      base_directory: Directory under which the checkpoints will be stored. Use
        a different base_directory in every task.
      tf_state: A dictionary of TensorFlow `Trackable` to be serialized, for
        example a dataset iterator.
      max_to_keep: Number of checkpoints to keep in the directory. If there are
        more checkpoints than specified by this number, then the oldest
        checkpoints are removed.
      checkpoint_name: Prefix of the checkpoint files (before `-{number}`).
    """
    if tf_state is None:
      tf_state = dict()
    base_directory = safe_normpath(base_directory)
    self.base_directory = base_directory
    self.max_to_keep = max_to_keep
    self.checkpoint_name = checkpoint_name
    self.tf_checkpoint = tf.train.Checkpoint(**tf_state)
    self.tf_checkpoint_manager = tf.train.CheckpointManager(
        self.tf_checkpoint,
        base_directory,
        max_to_keep=max_to_keep,
        checkpoint_name=checkpoint_name)
    self.restored_from = None

  def get_latest_checkpoint_to_restore_from(self):
    """Returns the latest checkpoint to restore from.

    In the current implementation, this method simply returns the attribute
    `latest_checkpoint`.

    Subclasses can override this method to provide an alternative checkpoint to
    restore from, for example for synchronization across multiple checkpoint
    directories.
    """
    return self.latest_checkpoint

  @property
  def latest_checkpoint(self) -> Optional[str]:
    """Latest checkpoint, see `tf.train.CheckpointManager.latest_checkpoint`.

    Returns:
      A string to the latest checkpoint. Note that this string is path-like but
      it does not really describe a file, but rather a set of files that are
      constructed from this string, by appending different file extensions. The
      returned value is `None` if there is no previously stored checkpoint in
      `base_directory` specified to `__init__()`.
    """
    return self.tf_checkpoint_manager.latest_checkpoint

  @property
  def current_checkpoint(self) -> Optional[str]:
    """Returns current checkpoint.

    This is `None` while `base_directory` contains no checkpoint. If the
    directory already contains checkpoints, then right after instance creation
    (before any restore) this points to "ckpt-0", which does not actually
    exist. After the first save into an empty directory (either via `.save()`
    or via `.restore_or_initialize()`) it points to "ckpt-1". When a specific
    checkpoint is restored (via `.restore(state, checkpoint)`), this property
    can differ from `.latest_checkpoint`.

    Returns:
      A string referring to the current checkpoint. See `.latest_checkpoint`
      for a description of the format.
    """
    latest_checkpoint = self.latest_checkpoint
    if latest_checkpoint is None:
      return None
    checkpoint_info = CheckpointInfo.from_path(latest_checkpoint)
    number = self.tf_checkpoint.save_counter.numpy()
    return str(checkpoint_info._replace(number=number))

  def _flax_path(self, checkpoint: str) -> str:
    return "{}.flax".format(checkpoint)

  def _next_checkpoint(self, checkpoint: Optional[str]) -> str:
    if checkpoint is None:
      return str(
          CheckpointInfo.initialize(self.base_directory, self.checkpoint_name))
    return str(CheckpointInfo.from_path(checkpoint).increment())

  def _checkpoint_number(self, checkpoint: Optional[str]) -> Optional[int]:
    if checkpoint is None:
      return None
    return CheckpointInfo.from_path(checkpoint).number

  def _delete_future_checkpoints(self):
    """Deletes checkpoints that are newer than the currently loaded checkpoint.

    This happens when the checkpoint was initialized from a checkpoint that was
    not the latest checkpoint (e.g. when recovering from a pre-emption in a
    `MultihostCheckpoint` where some workers finished writing their checkpoints
    and others didn't).
    """
    checkpoint = self.current_checkpoint
    while True:
      checkpoint = self._next_checkpoint(checkpoint)
      paths = tf.io.gfile.glob(f"{checkpoint}.*")
      if not paths:
        break
      for path in paths:
        logging.info("Cleaning up future checkpoint file '%s'", path)
        tf.io.gfile.remove(path)

  @utils.logged_with("Checkpoint.save()")
  def save(self, state) -> str:
    """Saves a new checkpoint in the directory.

    Note that if the checkpoint was restored from an earlier checkpoint than
    the latest available one, then saving will overwrite and/or delete any
    checkpoints later than the restored one.

    For example, if there are checkpoints `(1, 2, 3)` and then checkpoint `1`
    is restored, then calling `.save()` on that restored checkpoint will result
    in `2` being overwritten and `3` being deleted.

    This overwriting/deleting behavior allows for seamless integration with
    `MultihostCheckpoint` after pre-emption (i.e. one of the workers might have
    stored one more checkpoint, but that checkpoint is only available on that
    one worker and must be overwritten when the training continues).

    After such an overwrite, the attributes `.current_checkpoint` and
    `.latest_checkpoint` will point to the newly written checkpoint (in the
    above case `2`), but the list `.tf_checkpoint_manager.checkpoints` might be
    out of sync and should not be used.

    Args:
      state: Flax checkpoint to be stored.

    Returns:
      The checkpoint identifier ({base_directory}/ckpt-{number}).
    """
    self._delete_future_checkpoints()

    next_checkpoint = self._next_checkpoint(self.current_checkpoint)
    flax_path = self._flax_path(next_checkpoint)
    logging.info("Storing next checkpoint '%s'", next_checkpoint)

    if not tf.io.gfile.exists(self.base_directory):
      tf.io.gfile.makedirs(self.base_directory)
    with tf.io.gfile.GFile(flax_path, "wb") as f:
      f.write(flax.serialization.to_bytes(state))

    checkpoints_before_save = set(self.tf_checkpoint_manager.checkpoints)
    # Write TensorFlow data last. This way the TensorFlow checkpoint logic
    # only commits a checkpoint if it completes successfully. A previously
    # written `flax_path` would then simply be overwritten next time.
    self.tf_checkpoint_manager.save()
    # Clean up stale Flax files. TensorFlow automatically removes checkpoints
    # beyond the `max_to_keep` most recent ones, so we do the same for the
    # Flax checkpoints.
    stale_checkpoints = checkpoints_before_save - set(
        self.tf_checkpoint_manager.checkpoints)
    for checkpoint in stale_checkpoints:
      if tf.io.gfile.exists(self._flax_path(checkpoint)):
        tf.io.gfile.remove(self._flax_path(checkpoint))
    assert self.current_checkpoint == next_checkpoint, (
        "Expected next_checkpoint to match .current_checkpoint: "
        f"{next_checkpoint} != {self.current_checkpoint}")
    return self.current_checkpoint

  @utils.logged_with("Checkpoint.restore_or_initialize()")
  def restore_or_initialize(self, state: T) -> T:
    """Restores from the latest checkpoint, or creates a first checkpoint.

    Args:
      state: A data structure to be stored or to serve as a template. If the
        checkpoint is restored (and not initialized), then the fields of `state`
        must match the data previously stored. See
        `flax.serialization.from_state_dict()` for details.

    Returns:
      The restored `state` object. Note that all TensorFlow `Trackable`s in
      `tf_state` (see `__init__()`) are also updated.
    """
    checkpoint = self.get_latest_checkpoint_to_restore_from()
    if checkpoint is not None:
      return self.restore(state, checkpoint)
    logging.info("Storing initial version.")
    self.save(state)
    return state

  def restore_dict(self, checkpoint: Optional[str] = None) -> Dict[str, Any]:
    """Restores a checkpoint (the latest by default) and returns it as a dict.

    This is equivalent to `.restore(state=None, checkpoint=checkpoint)`; the
    only difference is the return type annotation.

    Args:
      checkpoint: Checkpoint name that should be restored. Defaults to latest
        available checkpoint. See `.latest_checkpoint` for a description of the
        format of this string.

    Returns:
      The restored state as a dictionary. Note that all TensorFlow `Trackable`s
      in `tf_state` (see `__init__()`) are also updated.

    Raises:
      FileNotFoundError: If specified checkpoint does not exist, or if there
      is no checkpoint to restore in case no checkpoint was specified.
    """
    return self.restore(state=None, checkpoint=checkpoint)

  def _checkpoint_or_latest(self, checkpoint: Optional[str] = None) -> str:
    if checkpoint is None:
      checkpoint = self.get_latest_checkpoint_to_restore_from()
      if checkpoint is None:
        raise FileNotFoundError(f"No checkpoint found at {self.base_directory}")
    return checkpoint

  def load_state(self,
                 state: Optional[T],
                 checkpoint: Optional[str] = None) -> T:
    """Restores Flax state from a checkpoint (the latest by default).

    As opposed to `.restore()`, this function only reads the Flax checkpoint
    and does not read the (potentially very large) TensorFlow state.

    Args:
      state: Data structure that will serve as a template for the
        returned state. If the loaded data does not match that template, then an
        exception is raised. It's also possible to specify `state=None`, in
        which case a dictionary will be returned. See
        `flax.serialization.from_state_dict()` for details.
      checkpoint: Checkpoint name that should be restored. Defaults to latest
        available checkpoint. See `.latest_checkpoint` for a description of the
        format of this string.

    Returns:
      The restored `state` object. Unlike `.restore()`, this does not update
      the TensorFlow `Trackable`s in `tf_state` (see `__init__()`).

    Raises:
      FileNotFoundError: If specified checkpoint does not exist, or if there
      is no checkpoint to restore in case no checkpoint was specified.
    """
    checkpoint = self._checkpoint_or_latest(checkpoint)
    flax_path = self._flax_path(checkpoint)
    if not tf.io.gfile.exists(flax_path):
      raise FileNotFoundError(f"Checkpoint {checkpoint} does not exist")
    with tf.io.gfile.GFile(flax_path, "rb") as f:
      return flax.serialization.from_bytes(state, f.read())

  def restore(self,
              state: Optional[T],
              checkpoint: Optional[str] = None) -> T:
    """Restores from the latest checkpoint.

    Similar to `restore_or_initialize()`, but raises a `FileNotFoundError` if
    there is no checkpoint.

    Args:
      state: Data structure that will serve as a template for the
        returned state. If the loaded data does not match that template, then an
        exception is raised. It's also possible to specify `state=None`, in
        which case a dictionary will be returned. See
        `flax.serialization.from_state_dict()` for details.
      checkpoint: Checkpoint name that should be restored. Defaults to latest
        available checkpoint. See `.latest_checkpoint` for a description of the
        format of this string.

    Returns:
      The restored `state` object. Note that all TensorFlow `Trackable`s in
      `tf_state` (see `__init__()`) are also updated.

    Raises:
      FileNotFoundError: If specified checkpoint does not exist, or if there
      is no checkpoint to restore in case no checkpoint was specified.
    """
    checkpoint = self._checkpoint_or_latest(checkpoint)
    logging.info("Restoring checkpoint: %s", checkpoint)
    state = self.load_state(state, checkpoint)
    self.tf_checkpoint.restore(checkpoint)

    logging.info("Restored save_counter=%d restored_checkpoint=%s",
                 self.tf_checkpoint.save_counter.numpy(),
                 checkpoint)
    self.restored_from = checkpoint
    return state


class MultihostCheckpoint(Checkpoint):
  """A subclass of `Checkpoint` that synchronizes between multiple JAX hosts.

  If training is split across multiple hosts, then the following race
  condition can occur: if a host is pre-empted while writing a checkpoint,
  then the other hosts will only be restarted with a small delay, and at that
  point they have probably already finished writing their checkpoint. Upon
  restart, the host that was interrupted while writing the checkpoint will
  load the latest fully written checkpoint, which will be out of sync with the
  other hosts that successfully wrote one more checkpoint. To avoid this,
  checkpoints are only restored from the highest checkpoint number that is
  available on all hosts.

  This class also allows specifying a `multihost_base_directory` that is
  identical for all hosts and is used to derive a host-specific directory.
  """

  def __init__(self,
               multihost_base_directory: str,
               tf_state: Optional[Dict[str, Any]] = None,
               *,
               host_id: Optional[int] = None,
               max_to_keep: int = 5,
               checkpoint_name: str = "ckpt"):
    """Initializes a MultihostCheckpoint with a dict of TensorFlow Trackables.

    Args:
      multihost_base_directory: Directory that will be used to construct a
        host-specific `base_directory` under which the checkpoints will be
        stored. Usually a directory *within* the work unit's working directory
        (e.g. `f"{workdir}/checkpoints"`). One directory per host will be
        created at the same level as this base directory, labeled
        `f"{multihost_base_directory}-{host_id}"`.
      tf_state: A dictionary of TensorFlow `Trackable` to be serialized, for
        example a dataset iterator.
      host_id: Host ID used to construct the `base_directory`. Taken from
        `jax.process_index()` if not specified.
      max_to_keep: Number of checkpoints to keep in the directory. If there are
        more checkpoints than specified by this number, then the oldest
        checkpoints are removed.
      checkpoint_name: Prefix of the checkpoint files (before `-{number}`).
    """
    if max_to_keep < 2:
      raise ValueError("Requires multiple checkpoints (max_to_keep>=2).")
    multihost_base_directory = multihost_base_directory.rstrip("/")
    self.multihost_base_directory = multihost_base_directory
    if host_id is None:
      host_id = jax.process_index()
    base_directory = f"{multihost_base_directory}-{host_id}"
    super().__init__(
        base_directory,
        tf_state,
        max_to_keep=max_to_keep,
        checkpoint_name=checkpoint_name)

  @utils.logged_with(
      "MultihostCheckpoint.get_latest_checkpoint_to_restore_from()")
  def get_latest_checkpoint_to_restore_from(self) -> Optional[str]:
    """Returns the latest checkpoint available on all hosts."""
    base_directory_glob = f"{self.multihost_base_directory}-*"
    base_directories = tf.io.gfile.glob(base_directory_glob)
    if self.base_directory not in base_directories:
      logging.info("%s not in %s", self.base_directory, base_directories)
      return None
    checkpoints = {}
    common_numbers = None
    all_numbers = set()
    for base_directory in base_directories:
      checkpoint_manager = tf.train.CheckpointManager(
          tf.train.Checkpoint(),
          base_directory,
          max_to_keep=self.max_to_keep,
          checkpoint_name=self.checkpoint_name)
      numbers = [
          CheckpointInfo.from_path(checkpoint).number
          for checkpoint in checkpoint_manager.checkpoints
      ]
      checkpoints[base_directory] = dict(
          zip(numbers, checkpoint_manager.checkpoints))
      numbers = set(numbers)
      if common_numbers is None:
        common_numbers = numbers
      else:
        common_numbers &= numbers
      all_numbers |= numbers
    logging.info(
        "Checked checkpoint base_directories: %s - common_numbers=%s "
        "- exclusive_numbers=%s", base_directories, common_numbers,
        all_numbers.difference(common_numbers))
    if not common_numbers:
      return None
    highest_number = sorted(common_numbers)[-1]
    return checkpoints[self.base_directory][highest_number]


================================================
FILE: clu/checkpoint_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for clu.checkpoint."""

import os
import tempfile
from unittest import mock

from clu import checkpoint
import flax
import tensorflow as tf


def _make_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  features = dict(x=inputs, y=labels)
  return tf.data.Dataset.from_tensor_slices(features).repeat().batch(2)


@flax.struct.dataclass
class TrainState:
  step: int


@flax.struct.dataclass
class TrainStateExtended:
  step: int
  name: str


class NotTrainState:
  pass


def _checkpoint_number(path):
  *parts, number = path.split("-")
  del parts
  return int(number)


class CheckpointTest(tf.test.TestCase):

  def test_safe_normpath(self):
    self.assertEqual(checkpoint.safe_normpath("./test_dir"), "test_dir")
    self.assertEqual(checkpoint.safe_normpath(".//test_dir"), "test_dir")
    self.assertEqual(checkpoint.safe_normpath("gs://test_dir"), "gs://test_dir")
    self.assertEqual(
        checkpoint.safe_normpath("gs://test_dir/"), "gs://test_dir")

  def test_initialize_mkdir(self):
    base_dir = os.path.join(tempfile.mkdtemp(), "test")
    state = TrainState(step=1)
    ckpt = checkpoint.Checkpoint(base_dir)
    self.assertIsNone(ckpt.current_checkpoint)
    self.assertIsNone(ckpt.latest_checkpoint)
    self.assertFalse(os.path.isdir(base_dir))
    state = ckpt.restore_or_initialize(state)
    self.assertIsNotNone(ckpt.latest_checkpoint)
    self.assertEqual(ckpt.latest_checkpoint, ckpt.current_checkpoint)
    self.assertTrue(os.path.isdir(base_dir))

  def test_restores_flax_state(self):
    base_dir = tempfile.mkdtemp()
    state = TrainState(step=1)
    ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=2)
    # Initializes.
    state = ckpt.restore_or_initialize(state)
    state = TrainState(step=0)
    # Restores step=1.
    state = ckpt.restore_or_initialize(state)
    self.assertEqual(state.step, 1)
    state = TrainState(step=2)
    # Stores step=2.
    path = ckpt.save(state)
    self.assertEqual(_checkpoint_number(path), 2)
    state = TrainState(step=0)
    # Restores step=2.
    state = ckpt.restore(state)
    self.assertEqual(state.step, 2)
    state = TrainState(step=3)
    # Stores step=3.
    path2 = ckpt.save(state)
    self.assertEqual(_checkpoint_number(path2), 3)
    state = TrainState(step=0)
    # Restores step=2.
    state = ckpt.restore(state, path)
    self.assertEqual(state.step, 2)

  def test_load_state_dict(self):
    base_dir = tempfile.mkdtemp()
    state = TrainState(step=1)
    ckpt = checkpoint.Checkpoint(base_dir)
    # Initializes.
    state = ckpt.restore_or_initialize(state)
    # Load via load_state_dict().
    flax_dict = checkpoint.load_state_dict(base_dir)
    self.assertEqual(flax_dict, dict(step=1))
    with self.assertRaisesRegex(FileNotFoundError, r"^No checkpoint found"):
      checkpoint.load_state_dict(tempfile.mkdtemp())

  def test_fails_when_restoring_subset(self):
    base_dir = tempfile.mkdtemp()
    state = TrainStateExtended(step=1, name="test")
    ckpt = checkpoint.Checkpoint(base_dir)
    # Initializes with TrainStateExtended.
    state = ckpt.restore_or_initialize(state)
    state = TrainState(step=0)
    # Restores with TrainState.
    with self.assertRaisesRegex(ValueError, r"^Unknown field"):
      state = ckpt.restore_or_initialize(state)

  def test_fails_when_restoring_superset(self):
    base_dir = tempfile.mkdtemp()
    ckpt = checkpoint.Checkpoint(base_dir)
    state = TrainState(step=0)
    # Initializes with TrainState.
    state = ckpt.restore_or_initialize(state)
    state = TrainStateExtended(step=1, name="test")
    # Restores with TrainStateExtended.
    with self.assertRaisesRegex(ValueError, r"^Missing field"):
      state = ckpt.restore_or_initialize(state)

  def test_restores_tf_state(self):
    base_dir = tempfile.mkdtemp()
    ds_iter = iter(_make_dataset())
    ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
    features0 = next(ds_iter)  # Advance iterator by one.
    del features0
    state = TrainState(step=1)
    # Initialize at features1.
    state = ckpt.restore_or_initialize(state)
    features1 = next(ds_iter)
    features2 = next(ds_iter)
    self.assertNotAllEqual(features1["x"], features2["x"])
    self.assertNotAllEqual(features1["y"], features2["y"])
    # Restore at features1.
    state = ckpt.restore_or_initialize(state)
    features1_restored = next(ds_iter)
    self.assertAllEqual(features1["x"], features1_restored["x"])
    self.assertAllEqual(features1["y"], features1_restored["y"])
    # Save at features2.
    path = ckpt.save(state)
    self.assertEqual(_checkpoint_number(path), 2)
    features2 = next(ds_iter)
    features3 = next(ds_iter)
    self.assertNotAllEqual(features2["x"], features3["x"])
    self.assertNotAllEqual(features2["y"], features3["y"])
    # Restore at features2.
    state = ckpt.restore_or_initialize(state)
    features2_restored = next(ds_iter)
    self.assertAllEqual(features2["x"], features2_restored["x"])
    self.assertAllEqual(features2["y"], features2_restored["y"])
    # Restore at features2 as dictionary.
    state = ckpt.restore_dict()
    features2_restored = next(ds_iter)
    self.assertAllEqual(features2["x"], features2_restored["x"])
    self.assertAllEqual(features2["y"], features2_restored["y"])

  def test_restore_flax_alone(self):
    base_dir = tempfile.mkdtemp()
    ds_iter = iter(_make_dataset())
    ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
    state = TrainState(step=1)
    # Initializes.
    state = ckpt.restore_or_initialize(state)
    state = TrainState(step=0)
    ckpt = checkpoint.Checkpoint(base_dir)
    # Restores step=1.
    state = ckpt.restore_or_initialize(state)
    self.assertEqual(state.step, 1)

  def test_restore_dict(self):
    base_dir = tempfile.mkdtemp()
    ds_iter = iter(_make_dataset())
    ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
    with self.assertRaisesRegex(FileNotFoundError, r"No checkpoint found at"):
      ckpt.restore_dict()
    with self.assertRaisesRegex(FileNotFoundError,
                                r"Checkpoint invalid does not exist"):
      ckpt.restore_dict(checkpoint="invalid")

    state = TrainState(step=1)
    ckpt.save(state)

    state_dict = ckpt.restore_dict()
    self.assertEqual(state_dict, dict(step=1))
    first_checkpoint = ckpt.latest_checkpoint

    new_state = TrainState(step=2)
    ckpt.save(new_state)

    self.assertEqual(
        ckpt.restore_dict(checkpoint=first_checkpoint),
        dict(step=1))
    self.assertEqual(ckpt.restore_dict(), dict(step=2))
    self.assertEqual(
        ckpt.restore_dict(checkpoint=ckpt.latest_checkpoint),
        dict(step=2))

  def test_ignores_incomplete_checkpoint(self):
    base_dir = tempfile.mkdtemp()
    state = TrainState(step=1)
    ckpt = checkpoint.Checkpoint(base_dir)
    # Initializes.
    state = ckpt.restore_or_initialize(state)
    state = TrainState(step=0)
    # Restores step=1.
    state = ckpt.restore_or_initialize(state)
    self.assertEqual(state.step, 1)
    state = TrainState(step=2)
    # Failed save: the Flax state for step=2 is stored, but the TensorFlow
    # checkpoint fails.
    ckpt.tf_checkpoint_manager.save = None
    with self.assertRaisesRegex(TypeError,
                                r"'NoneType' object is not callable"):
      ckpt.save(state)
    files = os.listdir(base_dir)
    self.assertIn("ckpt-2.flax", files)
    self.assertNotIn("ckpt-2.index", files)
    ckpt = checkpoint.Checkpoint(base_dir)
    state = TrainState(step=0)
    # Restores step=1.
    state = ckpt.restore_or_initialize(state)
    self.assertEqual(state.step, 1)
    # Stores step=2.
    state = TrainState(step=2)
    path = ckpt.save(state)
    self.assertEqual(_checkpoint_number(path), 2)
    files = os.listdir(base_dir)
    self.assertIn("ckpt-2.flax", files)
    self.assertIn("ckpt-2.index", files)
    state = TrainState(step=0)
    # Restores step=2.
    state = ckpt.restore_or_initialize(state)
    self.assertEqual(state.step, 2)

  def test_max_to_keep(self):
    base_dir = tempfile.mkdtemp()
    state = TrainState(step=1)
    ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=1)
    state = ckpt.restore_or_initialize(state)
    files1 = os.listdir(base_dir)
    state = TrainState(step=2)
    path = ckpt.save(state)
    self.assertEqual(_checkpoint_number(path), 2)
    files2 = os.listdir(base_dir)
    self.assertEqual(len(files1), len(files2))
    self.assertNotEqual(files1, files2)

  def test_checkpoint_name(self):
    base_dir = tempfile.mkdtemp()
    state = TrainState(step=1)
    ckpt = checkpoint.Checkpoint(base_dir, checkpoint_name="test")
    path = ckpt.save(state)
    self.assertIn("test", path)

  def test_fails_if_not_registered(self):
    base_dir = tempfile.mkdtemp()
    not_state = NotTrainState()
    ckpt = checkpoint.Checkpoint(base_dir)
    with self.assertRaisesRegex(TypeError, r"serialize"):
      ckpt.restore_or_initialize(not_state)

  def test_overwrite(self):
    base_dir = tempfile.mkdtemp()
    tf_step = tf.Variable(1)
    state = TrainState(step=1)
    ckpt = checkpoint.Checkpoint(base_dir, dict(step=tf_step))
    # Initialize step=1.
    state = ckpt.restore_or_initialize(state)
    self.assertEqual(state.step, 1)
    self.assertEqual(tf_step.numpy(), 1)
    checkpoint_info = checkpoint.CheckpointInfo.from_path(
        ckpt.current_checkpoint)
    # Stores steps 2, 3, 4, 5
    for _ in range(4):
      tf_step.assign_add(1)
      state = state.replace(step=state.step + 1)
      ckpt.save(state)
    latest_checkpoint = str(checkpoint_info._replace(number=5))
    self.assertEqual(ckpt.current_checkpoint, latest_checkpoint)
    self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
    # Restores at step=1
    ckpt = checkpoint.Checkpoint(base_dir, dict(step=tf_step))
    state = ckpt.restore(state, checkpoint=str(checkpoint_info))
    self.assertEqual(state.step, 1)
    self.assertEqual(tf_step.numpy(), 1)
    self.assertNotEqual(ckpt.current_checkpoint, ckpt.latest_checkpoint)
    self.assertEqual(ckpt.current_checkpoint, str(checkpoint_info))
    self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
    # Overwrites step=2, deletes 3, 4, 5.
    tf_step.assign_add(1)
    state = state.replace(step=state.step + 1)
    ckpt.save(state)
    latest_checkpoint = str(checkpoint_info._replace(number=2))
    self.assertEqual(ckpt.current_checkpoint, latest_checkpoint)
    self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)


class MultihostCheckpointTest(tf.test.TestCase):

  @mock.patch("jax.process_index")
  def test_initialize_mkdir(self, process_index_mock):
    multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
    state = TrainState(step=1)
    process_index_mock.return_value = 0
    base_dir = f"{multihost_base_dir}-0"
    ckpt = checkpoint.MultihostCheckpoint(multihost_base_dir)
    self.assertIsNone(ckpt.latest_checkpoint)
    self.assertFalse(os.path.isdir(base_dir))
    state = ckpt.restore_or_initialize(state)
    self.assertIsNotNone(ckpt.latest_checkpoint)
    self.assertTrue(os.path.isdir(base_dir))

  @mock.patch("jax.process_index")
  def test_synchronize_multiple_hosts(self, process_index_mock):
    multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
    state = TrainState(step=1)
    process_index_mock.return_value = 0
    ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir)
    process_index_mock.return_value = 1
    ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir)
    # Initialize both at step=1.
    state_0 = ckpt_0.restore_or_initialize(state)
    state_1 = ckpt_1.restore_or_initialize(state)
    # Update both at step=2.
    state_0 = state_0.replace(step=2)
    ckpt_0.save(state_0)
    state_1 = state_1.replace(step=2)
    ckpt_1.save(state_1)
    # Update ckpt_1 at step=3.
    state_1 = state_1.replace(step=3)
    ckpt_1.save(state_1)
    # Reload both at step=2.
    process_index_mock.return_value = 0
    ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir)
    process_index_mock.return_value = 1
    ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir)
    self.assertEqual(ckpt_0.latest_checkpoint,
                     ckpt_0.get_latest_checkpoint_to_restore_from())
    self.assertNotEqual(ckpt_1.latest_checkpoint,
                        ckpt_1.get_latest_checkpoint_to_restore_from())
    state_0 = ckpt_0.restore_or_initialize(state)
    state_1 = ckpt_1.restore_or_initialize(state)
    self.assertEqual(state_0.step, 2)
    self.assertEqual(state_1.step, 2)

  def test_preemption(self):
    multihost_base_dir = os.path.join(tempfile.mkdtemp(), "test")
    state = TrainState(step=1)
    state0 = state.replace(step=0)
    ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
    ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
    # Initialize both at step=1.
    state_0 = ckpt_0.restore_or_initialize(state)
    state_1 = ckpt_1.restore_or_initialize(state)
    self.assertEqual(state_0.step, 1)
    self.assertEqual(state_1.step, 1)
    # Restore both at step=1.
    state_0 = ckpt_0.restore_or_initialize(state0)
    state_1 = ckpt_1.restore_or_initialize(state0)
    self.assertEqual(state_0.step, 1)
    self.assertEqual(state_1.step, 1)
    # Update only ckpt_0 to step=2.
    state_0 = state_0.replace(step=2)
    ckpt_0.save(state_0)
    # Load both checkpoints at last common step=1.
    ckpt_0 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=0)
    ckpt_1 = checkpoint.MultihostCheckpoint(multihost_base_dir, host_id=1)
    state_0 = ckpt_0.restore_or_initialize(state)
    state_1 = ckpt_1.restore_or_initialize(state)
    self.assertEqual(state_0.step, 1)
    self.assertEqual(state_1.step, 1)
    # Store both at step=2.
    state_0 = state_0.replace(step=2)
    state_1 = state_1.replace(step=2)
    ckpt_0.save(state_0)
    ckpt_1.save(state_1)
    # Restore both at step=2.
    state_0 = ckpt_0.restore_or_initialize(state0)
    state_1 = ckpt_1.restore_or_initialize(state0)
    self.assertEqual(state_0.step, 2)
    self.assertEqual(state_1.step, 2)

if __name__ == "__main__":
  tf.test.main()


================================================
FILE: clu/data/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DatasetIterator is an interface for input pipelines."""
# pylint: disable=g-multiple-import
# pylint: disable=unused-import

from clu.data.dataset_iterator import (
    Array,
    ArraySpec,
    DatasetIterator,
    Element,
    ElementSpec,
    TfDatasetIterator,
    PeekableDatasetIterator,
    PyTree,
)


================================================
FILE: clu/data/dataset_iterator.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Interface for dataset iterators.

This module provides the DatasetIterator interface. The intention is that
several frameworks providing datasets can implement this interface without
knowing anything about the framework used for the model and the training loop.
Likewise, training loops can assume they get a DatasetIterator object and do
not need to care about the specifics of the input pipelines.

This module does not depend on TensorFlow. The interface is generic and users
don't have to use `tf.data` to construct a DatasetIterator. However, if they
use `tf.data` they can simply wrap their `tf.data.Dataset` object with
`TfDatasetIterator` to satisfy the interface.
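
As a rough illustration of the interface contract, here is a minimal sketch of
a NumPy-only iterator. It deliberately does not subclass `DatasetIterator`, so
it runs without CLU installed; the class name `ArrayIterator` and the JSON
checkpoint format are hypothetical choices for this sketch, not part of CLU.
Its only state is a position, so `save()` and `restore()` just persist that
integer.

```python
import json
import os
import tempfile

import numpy as np


class ArrayIterator:
  # Hypothetical sketch: iterates over the rows of in-memory NumPy arrays.
  # Mirrors the DatasetIterator methods (__next__, reset, element_spec,
  # save, restore) without depending on CLU.

  def __init__(self, features):
    self._features = features  # Dict of arrays with equal first dimension.
    self._num_rows = len(next(iter(features.values())))
    self._position = 0

  def __iter__(self):
    return self

  def __next__(self):
    if self._position >= self._num_rows:
      raise StopIteration
    element = {k: v[self._position] for k, v in self._features.items()}
    self._position += 1
    return element

  def reset(self):
    self._position = 0

  @property
  def element_spec(self):
    # (dtype name, per-element shape) pairs instead of CLU's ArraySpec.
    return {k: (v.dtype.name, v.shape[1:]) for k, v in self._features.items()}

  def save(self, filename):
    with open(filename, "w") as f:
      json.dump({"position": self._position}, f)

  def restore(self, filename):
    with open(filename) as f:
      self._position = json.load(f)["position"]


it = ArrayIterator({"x": np.arange(6).reshape(3, 2)})
first = next(it)
ckpt_path = os.path.join(tempfile.mkdtemp(), "iterator.json")
it.save(ckpt_path)
second = next(it)
it.restore(ckpt_path)  # Rewinds to the saved position.
assert (next(it)["x"] == second["x"]).all()
```

Because the position is the whole state, an iterator like this could also skip
checkpoints entirely and be restarted from the training step number.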
"""
from __future__ import annotations

import abc
import collections.abc
import concurrent.futures
import dataclasses
import os
import threading
import typing
from typing import Any, Mapping, Optional, Sequence, Tuple, TypeVar, Union

from absl import logging
from clu import asynclib
from etils import epath
import jax.numpy as jnp  # Just for type checking.
import numpy as np
import numpy.typing as npt

Array = Union[np.ndarray, jnp.ndarray]
# Sizes of dimensions, None means the dimension size is unknown.
Shape = Tuple[Optional[int], ...]


@dataclasses.dataclass(frozen=True)
class ArraySpec:
  """Describes an array via it's dtype and shape."""
  dtype: npt.DTypeLike
  shape: Shape

  def __repr__(self):
    return f"ArraySpec(dtype={np.dtype(self.dtype).name}, shape={self.shape})"

  def __str__(self):
    return f"{np.dtype(self.dtype).name}{list(self.shape)}"


# Elements are PyTrees with NumPy/JAX arrays.

# Anything can be a PyTree (it's either a container or leaf). We define
# PyTree[T] as a PyTree where all leaves are of type T.
# See https://jax.readthedocs.io/en/latest/pytrees.html.
L = TypeVar("L")  # pylint: disable=invalid-name

PyTree = Union[L, Sequence["PyTree[L]"], Mapping[str, "PyTree[L]"]]

Element = PyTree[Array]
ElementSpec = PyTree[ArraySpec]


class DatasetIterator(collections.abc.Iterator):  # pytype: disable=ignored-abstractmethod
  """Generic interface for iterating over a dataset.

  This does not support __getitem__ since it cannot be implemented efficiently
  for many datasets. However, datasets should allow starting the iterator from
  an arbitrary position.

  The element_spec property helps consumers to validate the input without
  reading data. This is similar to `tf.data.Dataset.element_spec`.

  Subclasses may decide not to read/write checkpoints if their state is
  sufficiently tracked externally (e.g. input pipelines that can be correctly
  restarted from the step number).
  """

  def get_next(self) -> Element:
    """Returns the next element."""
    logging.error(
        "DatasetIterator.get_next() is deprecated. Please use next().")
    # Subclasses should implement __next__() and remove calls to get_next().
    return next(self)

  def reset(self):
    """Resets the iterator back to the beginning."""
    raise NotImplementedError

  @property
  @abc.abstractmethod
  def element_spec(self) -> ElementSpec:
    """Returns the spec elements."""
    raise NotImplementedError()

  def save(self, filename: epath.Path):
    """Saves the state of the iterator to a file.

    This should only handle this iterator - not iterators in other processes.

    Args:
      filename: Name of the checkpoint.
    """
    raise NotImplementedError()

  def restore(self, filename: epath.Path):
    """Restores the iterator from a file (if available).

    This should only handle this iterator - not iterators in other processes.

    Args:
      filename: Name of the checkpoint.
    """
    raise NotImplementedError()

  def load(self, filename: epath.Path):
    """Deprecated alias for restore()."""
    logging.error("DatasetIterator.load() is deprecated. Please use restore().")
    return self.restore(filename)


class TfDatasetIterator(DatasetIterator):
  """DatasetIterator for wrapping a `tf.data.Dataset`."""

  def __init__(self, dataset, *, checkpoint: bool):
    """Wraps `tf.data.Dataset` object into the `DatasetIterator` interface.

    Warning: Do not wrap this iterator to do asynchronous prefetching if you
    use `checkpoint=True`. tf.data iterators must be saved synchronously.

    Args:
      dataset: The dataset to wrap. Elements are converted to NumPy arrays but
        no additional prefetching is done. tf.data should automatically prefetch
        elements (to CPU memory).
      checkpoint: Whether to checkpoint the dataset iterator object.
        Checkpointing dataset iterators is required for handling job
        pre-emptions but depending on your input pipeline can result in very
        large checkpoints. If set to False, save() and restore() are no-ops.
    """
    try:
      # Since this is the only class in this module using TF we only import
      # tensorflow if needed.
      if typing.TYPE_CHECKING:
        tf = Any
      else:
        import tensorflow as tf  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      raise RuntimeError("When using TfDatasetIterator your binary must "
                         "depend on //third_party/py/tensorflow.") from e
    self._tf = tf

    if not isinstance(dataset, tf.data.Dataset):
      raise ValueError("`dataset` must be an instance of `tf.data.Dataset` "
                       f"but got {type(dataset)}.")
    self._dataset = dataset
    self._checkpoint = checkpoint
    assert self.element_spec  # Verify element spec.
    self.iterator = iter(dataset)
    self._ckpt = tf.train.Checkpoint(ds=self.iterator)

  def get_next(self) -> Element:
    return next(self)

  def __next__(self) -> Element:
    return {k: np.asarray(v) for k, v in next(self.iterator).items()}

  def reset(self):
    self.iterator = iter(self._dataset)
    self._ckpt = self._tf.train.Checkpoint(ds=self.iterator)

  @property
  def element_spec(self) -> ElementSpec:
    element_spec = self._dataset.element_spec
    if not isinstance(element_spec, dict):
      raise ValueError("Dataset elements must be flat dictionaries but got "
                       f"{element_spec}.")
    invalid_features = [
        k for k, v in element_spec.items()
        if not isinstance(v, self._tf.TensorSpec)
    ]
    if invalid_features:
      raise ValueError(f"Features {invalid_features} are not tensors. Dataset "
                       "elements must be flat dictionaries of tensors.")
    return {
        k: ArraySpec(dtype=v.dtype.as_numpy_dtype, shape=tuple(v.shape))
        for k, v in element_spec.items()
    }

  def save(self, filename: epath.Path):
    if self._checkpoint:
      self._ckpt.write(os.fspath(filename))

  def restore(self, filename: epath.Path):
    if self._checkpoint:
      self._ckpt.read(os.fspath(filename)).assert_consumed()


class PeekableDatasetIterator(DatasetIterator):
  """Wraps a DatasetIterator to provide a peek() method.

  This allows looking at the next element, which can be useful in 2 scenarios:
  a) Get the structure of elements if the element_spec property is not
     supported.
  b) Request the next element without consuming it. This is especially handy to
     trigger reading of the first element while the model is being initialized.

  Example use case:
  >>> pool = clu.asynclib.Pool()
  >>> @pool
  ... def warmup_input_pipeline():
  ...   train_iter.peek()
  >>> first_batch_ready = warmup_input_pipeline()
  >>> # Do other stuff...
  >>> first_batch_ready.result()  # Wait for the input pipeline to be ready.
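
  The buffering contract behind peek() can be sketched with a plain Python
  iterator. This is a simplified, single-threaded illustration, not CLU's
  implementation: the hypothetical `Peekable` class omits the mutex, and the
  warm-up uses `concurrent.futures.ThreadPoolExecutor` in place of
  `clu.asynclib.Pool`.

```python
import concurrent.futures


class Peekable:
  # Hypothetical, simplified peekable wrapper (no locking).

  def __init__(self, it):
    self._it = it
    self._peeked = None

  def __iter__(self):
    return self

  def __next__(self):
    # Hand out the buffered element first, if there is one.
    if self._peeked is not None:
      element, self._peeked = self._peeked, None
      return element
    return next(self._it)

  def peek(self):
    # Fetch the next element once and keep it for the next __next__().
    if self._peeked is None:
      self._peeked = next(self._it)
    return self._peeked


it = Peekable(iter([1, 2, 3]))
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
  warmup = pool.submit(it.peek)  # Warm up in the background.
  # ... initialize the model here ...
  first = warmup.result()
assert first == 1
assert next(it) == 1  # The peeked element is not lost.
assert next(it) == 2
```

  Like the class below, the sketch uses None as the "nothing buffered" marker,
  so it assumes elements are never None.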
  """

  def __init__(self, it: DatasetIterator):
    self._it = it
    # Mutex for self._it.
    self._mutex = threading.Lock()
    self._peek: Optional[Element] = None
    self._pool = None
    self._peek_future = None

  def __next__(self) -> Element:
    with self._mutex:
      if self._peek is None:
        return next(self._it)
      peek = self._peek
      self._peek = None
      return peek

  def reset(self):
    with self._mutex:
      self._it.reset()
      self._peek = None
      self._pool = None
      self._peek_future = None

  @property
  def element_spec(self) -> ElementSpec:
    return self._it.element_spec

  def peek(self) -> Element:
    """Returns the next element without consuming it.

    This will get the next element from the underlying iterator. The element
    is stored and returned on the next call of __next__().

    Returns:
      The next element.
    """
    if self._peek is None:
      self._peek = next(self)
    return self._peek

  def peek_async(self) -> concurrent.futures.Future[Element]:
    """Same as peek() but returns the Future of the element.

    Users can call this to warm up the iterator.

    Returns:
      Future with the next element. The element is also kept and returned on the
      next call of __next__().
    """
    with self._mutex:
      if self._peek_future is None:
        if self._pool is None:
          self._pool = asynclib.Pool(max_workers=1)
        self._peek_future = self._pool(self.peek)()
      return self._peek_future

  def save(self, filename: epath.Path):
    with self._mutex:
      self._it.save(filename)

  def restore(self, filename: epath.Path):
    with self._mutex:
      self._it.restore(filename)


================================================
FILE: clu/data/dataset_iterator_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for dataset_iterator."""
import itertools
import pathlib
import tempfile

from absl.testing import parameterized
from clu.data import dataset_iterator
import numpy as np
import tensorflow as tf

INDEX = "_index"


class DatasetIteratorTest(parameterized.TestCase, tf.test.TestCase):

  def _create_iterator(self, start_index: int, checkpoint: bool = True):
    """Create an iterator over some prime numbers with index."""
    primes = tf.constant([2, 3, 5, 7, 11, 13, 17, 19, 23, 29])
    ds = tf.data.Dataset.range(start_index, 10)
    ds = ds.map(lambda i: {INDEX: i, "prime": primes[i]})
    # Remove indices 1 and 3.
    ds = ds.filter(lambda x: tf.logical_and(x["prime"] != 3, x["prime"] != 7))
    ds = ds.batch(2, drop_remainder=True)
    return dataset_iterator.TfDatasetIterator(ds, checkpoint=checkpoint)

  def test_tf_iterator(self):
    it = self._create_iterator(0)
    self.assertEqual(
        it.element_spec, {
            INDEX: dataset_iterator.ArraySpec(np.int64, (2,)),
            "prime": dataset_iterator.ArraySpec(np.int32, (2,))
        })
    self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
    self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
    it.reset()
    # Iterator starts from the beginning.
    self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})

  def test_tf_iterator_save_and_load(self):
    it = self._create_iterator(0)
    next(it)
    next(it)
    next(it)
    work_dir = pathlib.Path(tempfile.mkdtemp())
    filename = work_dir / "ckpt"
    it.save(filename)
    self.assertTrue((work_dir / "ckpt.index").exists())

    it = self._create_iterator(0)
    # Iterator is at the beginning (batch 1).
    self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
    it.load(filename)
    # Iterator is at the end (batch 4).
    self.assertEqual(next(it), {INDEX: [8, 9], "prime": [23, 29]})

  def test_tf_iterator_save_and_load_no_checkpoint(self):
    it = self._create_iterator(0, checkpoint=False)
    self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
    self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})
    work_dir = pathlib.Path(tempfile.mkdtemp())
    filename = work_dir / "ckpt"
    it.save(filename)  # Should be a no-op and not create a checkpoint.
    self.assertFalse((work_dir / "ckpt.index").exists())

    it = self._create_iterator(0, checkpoint=False)
    self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
    it.restore(filename)  # Should be a no-op, iterator just continues.
    self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})

  def test_peekable_dataset_iterator(self):
    it = self._create_iterator(0)
    it = dataset_iterator.PeekableDatasetIterator(it)
    self.assertEqual(it.peek(), {INDEX: [0, 2], "prime": [2, 5]})
    self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
    self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})

  @parameterized.parameters(itertools.product([True, False], [True, False]))
  def test_peekable_dataset_iterator_async(self, wait: bool, peek_first: bool):
    it = self._create_iterator(0)
    it = dataset_iterator.PeekableDatasetIterator(it)
    future = it.peek_async()
    self.assertIsNone(it._peek)
    if wait:
      future.result()
      self.assertIsNotNone(it._peek)
    if peek_first:
      self.assertEqual(it.peek(), {INDEX: [0, 2], "prime": [2, 5]})
    self.assertEqual(next(it), {INDEX: [0, 2], "prime": [2, 5]})
    self.assertEqual(next(it), {INDEX: [4, 5], "prime": [11, 13]})


if __name__ == "__main__":
  tf.test.main()


================================================
FILE: clu/deterministic_data.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Helper functions for building deterministic tf.data input pipelines.

The function `create_dataset()` makes it easy to build a `tf.data` based input
pipeline that allows for completely reproducible results based on a single
initial random seed. The caller must take care to create a unique initial seed
on every host that is then passed to `create_dataset()`, where further unique
random keys are derived for every batch. Within a single batch, this key is
exposed as the special feature "rng" and can be used to implement stateless
preprocessing functions.
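
The idea of deriving reproducible randomness per host and per example can be
illustrated with NumPy. This is a swapped-in analogue, not what CLU does: CLU
folds indices into JAX/TF RNG keys (`jax.random.fold_in`,
`tf.random.experimental.stateless_fold_in`), whereas this sketch seeds
`np.random.default_rng` with the tuple `(seed, host_id, example_index)`. The
function name `example_noise` is hypothetical.

```python
import numpy as np


def example_noise(seed, host_id, example_index, size=3):
  # Same (seed, host_id, example_index) -> same values, regardless of the
  # order in which examples are processed.
  rng = np.random.default_rng([seed, host_id, example_index])
  return rng.uniform(size=size)


a = example_noise(42, host_id=0, example_index=5)
b = example_noise(42, host_id=0, example_index=5)
c = example_noise(42, host_id=1, example_index=5)
assert np.array_equal(a, b)      # Reproducible.
assert not np.array_equal(a, c)  # Different host, different stream.
```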

The function `get_read_instruction_for_host()` makes it easy to split a dataset
evenly between multiple hosts in a SPMD setup with multiple machines. Within a
single host, every batch is usually distributed to all the attached accelerators
(the first value of the `batch_dims` argument to `create_dataset()`).
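
The per-host ranges can be sketched in plain Python. The hypothetical
`shard_range()` helper below reproduces the index arithmetic of the three
`RemainderOptions` for a single split of `num_examples` examples; the strings
"drop", "balance" and "first" stand in for `DROP`, `BALANCE_ON_PROCESSES` and
`ON_FIRST_PROCESS`, and the helper returns `[start, end)` bounds rather than a
`tfds.core.ReadInstruction`.

```python
def shard_range(num_examples, host_id, host_count, remainder="drop"):
  # Hypothetical helper: [start, end) of the examples read by `host_id`.
  per_host = num_examples // host_count
  start = per_host * host_id
  end = per_host * (host_id + 1)
  unused = num_examples % host_count
  if remainder == "balance":
    # Hosts with the lowest ids receive one extra example each.
    start += min(host_id, unused)
    end += min(host_id + 1, unused)
  elif remainder == "first":
    # Host 0 receives all remaining examples; later hosts shift by `unused`.
    end += unused
    if host_id > 0:
      start += unused
  elif remainder != "drop":
    raise ValueError(f"Unknown remainder option: {remainder}")
  return start, end


# 7 examples on 4 hosts.
print([shard_range(7, h, 4, "drop") for h in range(4)])
# [(0, 1), (1, 2), (2, 3), (3, 4)]  (examples 4-6 are dropped)
print([shard_range(7, h, 4, "balance") for h in range(4)])
# [(0, 2), (2, 4), (4, 6), (6, 7)]
print([shard_range(7, h, 4, "first") for h in range(4)])
# [(0, 4), (4, 5), (5, 6), (6, 7)]
```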

Finally, the function `create_distributed_dataset()` is intended to be used in
conjunction with a `tf.distribute.Strategy`.

Synopsis for deterministic training with multiple hosts:

  import jax
  from clu import deterministic_data
  import tensorflow_datasets as tfds

  rng = jax.random.PRNGKey(42)  # Global RNG (e.g. from config)
  rng = jax.random.fold_in(rng, jax.process_index())  # Derive RNG for this host.
  dataset_builder = tfds.builder(...)
  split = deterministic_data.get_read_instruction_for_host(
      "train", dataset_builder.info.splits["train"].num_examples)
  ds = deterministic_data.create_dataset(
      dataset_builder,
      split=split,
      rng=rng
  )
  ds_iter = iter(ds)
  for _ in range(num_train_steps):
    batch = jax.tree_util.tree_map(lambda x: x._numpy(), next(ds_iter))
    # (training step)
"""

import enum
import functools
import operator
from typing import Callable, Dict, Optional, Sequence, Union

from absl import logging

import jax
import jax.numpy as jnp
import numpy as np
from packaging import version
import tensorflow as tf
import tensorflow_datasets as tfds
import typing_extensions

# TODO(b/200953513): Migrate away from logging imports (on module level)
#                    to logging the actual usage. See b/200953513.


Tensor = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]
Features = Dict[str, Tensor]

AUTOTUNE = tf.data.experimental.AUTOTUNE

_use_split_info = version.parse("4.4.0") < version.parse(
    tfds.version.__version__)


class DatasetBuilder(typing_extensions.Protocol):
  """Protocol for dataset builders (subset of tfds.core.DatasetBuilder)."""

  def as_dataset(
      self, split: Union[str, tfds.core.ReadInstruction], shuffle_files: bool,
      read_config: tfds.ReadConfig,
      decoders: Optional[Dict[str, tfds.decode.Decoder]]) -> tf.data.Dataset:
    ...


class RemainderOptions(enum.Enum):
  """How to handle examples not divisible by number of processes.

  Possible values:

  - DROP: Examples not divisible by process count will be dropped. Every host
    receives the same number of examples.
  - BALANCE_ON_PROCESSES: Examples not divisible by process count will be
    distributed evenly across processes, starting from process 0. For example,
    if there are 4 processes and 7 examples, then processes 0, 1, 2 will have
    2 examples, and process 3 will have 1 example.
  - ON_FIRST_PROCESS: Examples not divisible by process count will be assigned
    to process 0.
  """
  DROP = 0
  BALANCE_ON_PROCESSES = 1
  ON_FIRST_PROCESS = 2


def _shard_read_instruction(
    absolute_instruction,
    *,
    split_infos: Dict[str, Union[int, tfds.core.SplitInfo]],
    host_id: int,
    host_count: int,
    remainder_options: RemainderOptions,
) -> tfds.core.ReadInstruction:
  """Shards a single ReadInstruction. See get_read_instruction_for_host()."""
  start = absolute_instruction.from_ or 0
  if _use_split_info:
    end = absolute_instruction.to or (
        split_infos[absolute_instruction.splitname].num_examples)  # pytype: disable=attribute-error
  else:
    end = absolute_instruction.to or split_infos[absolute_instruction.splitname]
  assert end >= start, f"start={start}, end={end}"
  num_examples = end - start

  examples_per_host = num_examples // host_count
  shard_start = start + examples_per_host * host_id
  shard_end = start + examples_per_host * (host_id + 1)

  # Handle remaining examples.
  num_unused_examples = num_examples % host_count
  assert num_unused_examples >= 0, num_unused_examples
  assert num_unused_examples < host_count, num_unused_examples
  if num_unused_examples > 0:
    if remainder_options == RemainderOptions.DROP:
      logging.warning("Dropping %d examples of %d examples (host count: %d).",
                      num_unused_examples, num_examples, host_count)
    elif remainder_options == RemainderOptions.BALANCE_ON_PROCESSES:
      shard_start += min(host_id, num_unused_examples)
      shard_end += min(host_id + 1, num_unused_examples)
    elif remainder_options == RemainderOptions.ON_FIRST_PROCESS:
      shard_end += num_unused_examples
      if host_id > 0:
        shard_start += num_unused_examples
    else:
      raise ValueError(f"Invalid remainder_options: {remainder_options}")

  return tfds.core.ReadInstruction(
      absolute_instruction.splitname,
      from_=shard_start,
      to=shard_end,
      unit="abs")


_DEPRECATE_MSG = """
`get_read_instruction_for_host` is DEPRECATED.

Migration instruction: Use `tfds.split_for_jax_process` which is simpler
and natively supported by TFDS.

```
split = tfds.split_for_jax_process('train[75%:]', drop_remainder=True)

ds = tfds.load('my_dataset', split=split)
```

See: https://www.tensorflow.org/datasets/splits#tfdseven_splits_multi-host_training

"""


def get_read_instruction_for_host(
    split: str,
    num_examples: Optional[int] = None,
    *,
    dataset_info: Optional[tfds.core.DatasetInfo] = None,
    host_id: Optional[int] = None,
    host_count: Optional[int] = None,
    drop_remainder: Optional[bool] = None,
    remainder_options: RemainderOptions = RemainderOptions.DROP
) -> tfds.core.ReadInstruction:
  """Returns a `ReadInstruction` of the data ranges for this host.

  `get_read_instruction_for_host` is DEPRECATED. Please use
  `tfds.split_for_jax_process` or `tfds.even_split`. See:
  https://www.tensorflow.org/datasets/splits#tfdseven_splits_multi-host_training

  In a distributed setting all hosts should get the same number of examples.
  With the default `RemainderOptions.DROP`, this can exclude a few
  (< host_count) examples.

  With `RemainderOptions.BALANCE_ON_PROCESSES`, the examples are distributed
  evenly across the hosts, and the remaining examples are distributed to the
  hosts with the lowest ids.

  Assuming a single epoch, the number of batches (e.g. for
  `create_dataset(pad_up_to_batches)`) can be computed by:

    batches = int(np.ceil(num_examples / global_batch_size))
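
  For instance, with hypothetical numbers (1,000 examples and a global batch
  size of 128), the last batch is partial and the formula rounds up:

```python
import numpy as np

num_examples = 1_000      # Hypothetical split size.
global_batch_size = 128   # Hypothetical global batch size.
batches = int(np.ceil(num_examples / global_batch_size))
print(batches)  # 8
print(batches * global_batch_size - num_examples)  # 24 padded examples
```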

  Args:
    split: Name of the dataset split to use or TFDS spec (e.g.
      `train[:800]+validation[:100]`). If you use the spec, you must pass
      `dataset_info`. For specs with multiple splits, each split is sharded
      independently of the other splits.
    num_examples: Deprecated - use dataset_info instead. Number of examples of
      the split.
    dataset_info: TFDS dataset info; used to get the number of examples per
      split.
    host_id: Optional, host index in [0, host_count). Defaults to
      `jax.process_index()`.
    host_count: Optional, number of hosts. Defaults to `jax.process_count()`.
    drop_remainder: Deprecated - use remainder_options instead.
    remainder_options: The options to handle the remaining examples.
  """
  logging.warning(_DEPRECATE_MSG)

  if num_examples is not None:
    logging.warning(
        "`num_examples` is deprecated. Please pass `dataset_info` instead.")
  if drop_remainder is not None:
    remainder_options = (
        RemainderOptions.DROP
        if drop_remainder else RemainderOptions.BALANCE_ON_PROCESSES)
    logging.warning(
        "`drop_remainder` is deprecated. Please pass `remainder_options` "
        "instead. `remainder_options` is reset with %s.", remainder_options)
  if dataset_info is None:
    if any(special in split for special in ["[", "]", "+"]):
      raise ValueError(
          f"Sharding split {split} requires passing `dataset_info`.")
  if host_id is None:
    host_id = jax.process_index()
  if host_count is None:
    host_count = jax.process_count()
  if host_id < 0 or host_id >= host_count or host_count < 1:
    raise ValueError(
        f"Invalid combination of host_id ({host_id}) and host_count "
        f"({host_count}).")

  if _use_split_info:
    if dataset_info is None:
      split_infos = {
          split: tfds.core.SplitInfo(
              name=split,
              shard_lengths=[num_examples],
              num_bytes=0,
          ),
      }
    else:
      split_infos = dataset_info.splits
  else:
    if dataset_info is None:
      split_infos = {split: num_examples}
    else:
      split_infos = {k: v.num_examples for k, v in dataset_info.splits.items()}

  read_instruction = tfds.core.ReadInstruction.from_spec(split)
  sharded_read_instructions = []
  for ri in read_instruction.to_absolute(split_infos):
    sharded_read_instructions.append(
        _shard_read_instruction(
            ri,
            split_infos=split_infos,
            host_id=host_id,
            host_count=host_count,
            remainder_options=remainder_options))
  return functools.reduce(operator.add, sharded_read_instructions)


def _preprocess_with_per_example_rng(ds: tf.data.Dataset,
                                     preprocess_fn: Callable[[Features],
                                                             Features], *,
                                     rng: jnp.ndarray) -> tf.data.Dataset:
  """Maps `ds` using the preprocess_fn and a deterministic RNG per example.

  Args:
    ds: Dataset of Python dictionaries with the features. The 'rng' feature
      should not exist.
    preprocess_fn: Preprocessing function that takes a Python dictionary of
      tensors and returns a Python dictionary of tensors. The function should be
      convertible into a TF graph.
    rng: Base RNG to use. Per example RNGs will be derived from this by folding
      in the example index.

  Returns:
    The dataset mapped by the `preprocess_fn`.
  """

  def _fn(example_index: int, features: Features) -> Features:
    example_index = tf.cast(example_index, tf.int32)
    features["rng"] = tf.random.experimental.stateless_fold_in(
        tf.cast(rng, tf.int64), example_index)
    processed = preprocess_fn(features)
    if isinstance(processed, dict) and "rng" in processed:
      del processed["rng"]
    return processed

  return ds.enumerate().map(_fn, num_parallel_calls=AUTOTUNE)


def pad_dataset(dataset: tf.data.Dataset,
                *,
                batch_dims: Sequence[int],
                pad_up_to_batches: Optional[int] = None,
                cardinality: Optional[int] = None):
  """Adds padding to a dataset.

  Args:
    dataset: The dataset to be padded.
    batch_dims: List of sizes of the batch dimensions. Multiple batch
      dimensions can be used to provide inputs for multiple devices. E.g.
      [jax.local_device_count(), batch_size // jax.device_count()].
    pad_up_to_batches: Set this option to process the entire dataset. When set,
      the dataset is first padded to the specified number of batches. A new
      feature called "mask" is added to every batch. This feature is set to
      `True` for every example that comes from `dataset`, and to `False` for
      every example that is padded to get to the specified number of batches.
      Note that `dataset` must not contain more than `pad_up_to_batches`
      (possibly partial) batches. If `None`, it is derived from `batch_dims`
      and `cardinality` as `ceil(cardinality / prod(batch_dims))`, i.e. the
      smallest number of batches that holds all examples. Note that
      `cardinality` refers to `dataset` as passed in, which is not necessarily
      the original full dataset size if you shard it per host.
    cardinality: Number of examples in the dataset. Only needed when the
      cardinality cannot be retrieved via `ds.cardinality()` (e.g. because of
      using `ds.filter()`).

  Returns:
    The padded dataset, with the added feature "mask" that is set to `True` for
    examples from the original `dataset` and to `False` for padded examples.
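
  The amount of padding can be checked with plain NumPy. This sketch assumes
  `batch_dims=[2, 4]` (e.g. 2 local devices with 4 examples each) and 10 real
  examples; it mirrors the arithmetic of `pad_dataset()` rather than calling
  it, and builds the resulting "mask" as a NumPy array.

```python
import numpy as np

batch_dims = [2, 4]  # Hypothetical: 2 local devices x 4 examples per device.
cardinality = 10     # Hypothetical number of real examples.

# Default when pad_up_to_batches is None: just enough batches for all examples.
pad_up_to_batches = int(np.ceil(cardinality / np.prod(batch_dims)))
padding = pad_up_to_batches * int(np.prod(batch_dims)) - cardinality
mask = np.array([True] * cardinality + [False] * padding)

print(pad_up_to_batches, padding)  # 2 6
# Real examples per batch: the second batch holds only 2 real examples.
print(mask.reshape(pad_up_to_batches, *batch_dims).sum(axis=(1, 2)))  # [8 2]
```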
  """
  if not isinstance(dataset.element_spec, dict):
    raise ValueError("The dataset must have dictionary elements.")
  if cardinality is None:
    cardinality = dataset.cardinality()
    if cardinality == tf.data.UNKNOWN_CARDINALITY:
      raise ValueError(
          "Cannot determine dataset cardinality. This can happen when you use "
          "a `.filter()` on the dataset. Please provide the cardinality as an "
          "argument to `create_dataset()`.")
  if "mask" in dataset.element_spec:
    raise ValueError("Dataset already contains a feature named \"mask\".")
  if pad_up_to_batches is None:
    pad_up_to_batches = int(np.ceil(cardinality / np.prod(batch_dims)))

  filler_element = tf.nest.map_structure(
      lambda spec: tf.zeros(spec.shape, spec.dtype)[None], dataset.element_spec)
  filler_element["mask"] = [False]
  filler_dataset = tf.data.Dataset.from_tensor_slices(filler_element)

  dataset = dataset.map(
      lambda features: dict(mask=True, **features), num_parallel_calls=AUTOTUNE)
  padding = pad_up_to_batches * np.prod(batch_dims) - int(cardinality)
  assert padding >= 0, (
      f"Invalid padding={padding} (batch_dims={batch_dims}, cardinality="
      f"{cardinality}, pad_up_to_batches={pad_up_to_batches})")
  return dataset.concatenate(filler_dataset.repeat(padding))
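

# Illustrative sketch (not part of the CLU API): the padding arithmetic used by
# `pad_dataset()` above, written in plain Python so it can be checked without
# TensorFlow. `_sketch_num_padding_examples` is a hypothetical helper name.
# When `pad_up_to_batches` is None it is derived as the smallest number of
# batches covering `cardinality`, i.e. ceil(cardinality / prod(batch_dims)).
# For example, 12 examples with batch_dims=(2, 5) give 2 batches and 8 filler
# examples, so the second batch holds 2 real and 8 masked-out examples.
import math


def _sketch_num_padding_examples(cardinality, batch_dims,
                                 pad_up_to_batches=None):
  """Returns (pad_up_to_batches, number of filler examples to append)."""
  examples_per_batch = math.prod(batch_dims)  # Empty batch_dims -> 1.
  if pad_up_to_batches is None:
    # Ceiling division on integers, avoiding floating point.
    pad_up_to_batches = -(-cardinality // examples_per_batch)
  padding = pad_up_to_batches * examples_per_batch - cardinality
  if padding < 0:
    # `pad_dataset()` itself uses an `assert` for this case.
    raise ValueError(f"Invalid padding={padding}: the dataset has more than "
                     f"pad_up_to_batches={pad_up_to_batches} batches.")
  return pad_up_to_batches, padding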


def create_dataset(dataset_builder: DatasetBuilder,
                   *,
                   split: Union[str, tfds.core.ReadInstruction],
                   batch_dims: Sequence[int] = (),
                   rng: Union[None, jnp.ndarray, tf.Tensor] = None,
                   filter_fn: Optional[Callable[[Features], bool]] = None,
                   preprocess_fn: Optional[Callable[[Features],
                                                    Features]] = None,
                   decoders: Optional[Dict[str, tfds.decode.Decoder]] = None,
                   cache: bool = False,
                   num_epochs: Optional[int] = None,
                   shuffle: bool = True,
                   shuffle_buffer_size: int = 10_000,
                   prefetch_size: int = 4,
                   pad_up_to_batches: Optional[Union[int, str]] = None,
                   cardinality: Optional[int] = None,
                   drop_remainder: bool = True) -> tf.data.Dataset:
  """Creates a standard input pipeline (shuffle, preprocess, batch).

  Args:
    dataset_builder: Dataset builder object with an `as_dataset()` method, e.g.
      an instance of `tfds.core.DatasetBuilder` as returned by
      `tfds.builder(...)`.
    split: Specifies which split of the data to load. Passed on to
      `tfds.DatasetBuilder.as_dataset()`. See also the
      [split API guide](https://www.tensorflow.org/datasets/splits). In a
      multi-host setup, this parameter can conveniently be generated by the
      function `get_read_instruction_for_host()`.
    batch_dims: List of batch dimension sizes. Multiple batch dimensions can be
      used to provide inputs for multiple devices, e.g.
      `[jax.local_device_count(), batch_size // jax.device_count()]`.
    rng: A JAX PRNG key or a `tf.Tensor` with a TF stateless seed, used for
      seeding shuffle operations and preprocessing ops. Must be set if
      `shuffle=True`.
    filter_fn: Optional function to filter the decoded examples. This happens
      before the preprocessing.
    preprocess_fn: Function for preprocessing individual examples (which should
      be a Python dictionary of tensors).
    decoders: Optional dictionary of decoders passed to `as_dataset()`.
    cache: Whether to cache the unprocessed dataset in memory.
    num_epochs: Number of epochs for which to repeat the dataset. None to repeat
      forever.
    shuffle: Whether to shuffle the dataset (both on file and example level).
    shuffle_buffer_size: Number of examples in the shuffle buffer.
    prefetch_size: The number of elements in the final dataset to prefetch in
      the background. This should be a small (say <10) positive integer or
      tf.data.experimental.AUTOTUNE.
    pad_up_to_batches: Set this option to process the entire dataset.
      - If set to an integer, the dataset is first padded to the specified
        number of batches. A new feature called "mask" is added to every
        example. This feature is set to `True` for every example that comes
        from `dataset_builder`, and to `False` for every example that is
        padded. Note that the specified `dataset_builder` and `split` must
        result in at most `pad_up_to_batches` (possibly partial) batches.
      - If set to "auto", it is derived from `batch_dims` and `cardinality` as
        the smallest number of batches such that
        `pad_up_to_batches * prod(batch_dims) >= cardinality`.
      - If `None`, the dataset won't be padded.
    cardinality: Number of examples in the dataset. Only needed when
      `pad_up_to_batches` is specified and the cardinality cannot be retrieved
      via `ds.cardinality()` (e.g. because of `ds.filter()`).
    drop_remainder: Whether to drop remainders when batching.

  Returns:
    The dataset with preprocessed and batched examples.
  """
  rng_available = rng is not None
  if not rng_available and shuffle:
    raise ValueError("Please set 'rng' when shuffling.")
  if rng_available:
    if isinstance(rng, tf.Tensor):
      rngs = [x.numpy() for x in tf.random.experimental.stateless_split(rng, 3)]
    else:
      rngs = list(jax.random.key_data(jax.random.split(rng, 3)))
  else:
    rngs = 3 * [[None, None]]

  dataset_options = tf.data.Options()
  dataset_options.experimental_optimization.map_parallelization = True
  dataset_options.threading.private_threadpool_size = 48
  dataset_options.threading.max_intra_op_parallelism = 1

  read_config = tfds.ReadConfig(
      shuffle_seed=rngs.pop()[0], options=dataset_options)
  ds = dataset_builder.as_dataset(
      split=split,
      shuffle_files=shuffle,
      read_config=read_config,
      decoders=decoders)

  if filter_fn is not None:
    ds = ds.filter(filter_fn)

  if cache:
    ds = ds.cache()

  if shuffle:
    ds = ds.shuffle(shuffle_buffer_size, seed=rngs.pop()[0])
  ds = ds.repeat(num_epochs)

  if preprocess_fn is not None:
    if rng_available:
      ds = _preprocess_with_per_example_rng(ds, preprocess_fn, rng=rngs.pop())
    else:
      ds = ds.map(preprocess_fn, num_parallel_calls=AUTOTUNE)

  if pad_up_to_batches is not None:
    assert isinstance(pad_up_to_batches, int) or pad_up_to_batches == "auto"
    ds = pad_dataset(
        ds,
        batch_dims=batch_dims,
        pad_up_to_batches=(None if pad_up_to_batches == "auto" else
                           pad_up_to_batches),
        cardinality=cardinality)

  if batch_dims:
    for batch_size in reversed(batch_dims):
      ds = ds.batch(batch_size, drop_remainder=drop_remainder)

  return ds.prefetch(prefetch_size)
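

# Illustrative sketch (not part of the CLU API): `create_dataset()` batches
# once per entry of `batch_dims`, innermost dimension first, so every element
# of the result has leading shape `batch_dims`. The hypothetical helper below
# mimics `ds.batch(size, drop_remainder=True)` applied for `reversed(batch_dims)`
# on a plain Python list; it is a sketch of the shape logic only.
def _sketch_nested_batches(examples, batch_dims):
  """Groups a flat list of examples into nested lists shaped by batch_dims."""
  batches = list(examples)
  for batch_size in reversed(batch_dims):
    # With drop_remainder=True an incomplete trailing group is discarded.
    num_full = len(batches) // batch_size
    batches = [batches[i * batch_size:(i + 1) * batch_size]
               for i in range(num_full)]
  return batches

# With batch_dims=(2, 5), 23 examples first form 4 groups of 5 (3 dropped),
# which then form 2 outer batches of shape (2, 5).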


StrOrReadInstruction = Union[str, tfds.core.ReadInstruction]


def create_distributed_dataset(
    dataset_builder,
    *,
    split: Union[StrOrReadInstruction, Callable[[int, int],
                                                StrOrReadInstruction]],
    global_batch_size: int,
    strategy: tf.distribute.Strategy,
    rng: Optional[tf.Tensor] = None,
    filter_fn: Optional[Callable[[Features], bool]] = None,
    preprocess_fn: Optional[Callable[[Features], Features]] = None,
    decoders: Optional[Dict[str, tfds.decode.Decoder]] = None,
    cache: bool = False,
    num_epochs: Optional[int] = None,
    shuffle: bool = True,
    shuffle_buffer_size: int = 10_000,
    prefetch_size: int = 4,
    pad_up_to_batches: Optional[int] = None,
    cardinality: Optional[int] = None,
    drop_remainder: bool = True) -> tf.data.Dataset:
  """Creates a distributed input pipeline (shuffle, preprocess, batch).

  Args:
    dataset_builder: Dataset builder object with an `as_dataset()` method, e.g.
      an instance of `tfds.core.DatasetBuilder` as returned by
      `tfds.builder(...)`.
    split: Split name to use, will be passed to `as_dataset()`. To read
      different data chunks on different input pipelines, pass a callable that
      accepts the input pipeline ID and the number of input pipelines and
      returns a split name.
    global_batch_size: Global batch size for all input pipelines together.
    strategy: Distribution strategy for distributing the dataset.
    rng: A tf.Tensor with a stateless random key to seed shuffle operations and
      preprocessing ops.
    filter_fn: Optional function to filter the decoded examples. This happens
      before the preprocessing.
    preprocess_fn: Function for preprocessing individual examples (which should
      be a Python dictionary of tensors).
    decoders: Optional dictionary of decoders passed to `as_dataset()`.
    cache: Whether to cache the unprocessed dataset in memory.
    num_epochs: Number of epochs for which to repeat the dataset. None to repeat
      forever.
    shuffle: Whether to shuffle the dataset (both on the file and example
      level).
    shuffle_buffer_size: Number of examples in the shuffle buffer.
    prefetch_size: The number of elements in the final dataset to prefetch in
      the background. This should be a small (say <10) positive integer or
      tf.data.experimental.AUTOTUNE.
    pad_up_to_batches: Set this option to process the entire dataset. When set,
      the dataset is first padded to the specified number of batches. A new
      feature called "mask" is added to every example. This feature is set to
      `True` for every example that comes from `dataset_builder`, and to `False`
      for every example that is padded to get to the specified number of
      batches. Note that the specified `dataset_builder` and `split` must
      provide at most `pad_up_to_batches` (possibly partial) batches.
    cardinality: Number of examples in the dataset. Only needed when
      `pad_up_to_batches` is specified and the cardinality cannot be retrieved
      via `ds.cardinality()` (e.g. because of `ds.filter()`).
    drop_remainder: Whether to drop remainders when batching.

  Returns:
    The dataset with preprocessed and batched examples.
  """

  def dataset_fn(input_context: tf.distribute.InputContext):
    """Returns the dataset for a single worker."""
    logging.info("dataset_fn(input_context=%s)", input_context)

    if rng is None:
      local_rng = None
    else:
      local_rng = tf.random.experimental.stateless_fold_in(
          rng, input_context.input_pipeline_id)

    if callable(split):
      local_split = split(input_context.input_pipeline_id,
                          input_context.num_input_pipelines)
    else:
      local_split = split

    per_replica_batch_size = input_context.get_per_replica_batch_size(
        global_batch_size)

    return create_dataset(
        dataset_builder=dataset_builder,
        split=local_split,
        batch_dims=[per_replica_batch_size],
        rng=local_rng,
        filter_fn=filter_fn,
        preprocess_fn=preprocess_fn,
        decoders=decoders,
        cache=cache,
        num_epochs=num_epochs,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        prefetch_size=prefetch_size,
        pad_up_to_batches=pad_up_to_batches,
        cardinality=cardinality,
        drop_remainder=drop_remainder)

  return strategy.distribute_datasets_from_function(dataset_fn)
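

# Illustrative sketch (not part of the CLU API): `create_distributed_dataset()`
# accepts a callable `split(input_pipeline_id, num_input_pipelines)`. The
# hypothetical helper below builds such a split string by assigning each input
# pipeline a contiguous range of examples. With `drop_remainder=False` the first
# `num_examples % num_pipelines` pipelines get one extra example each. It
# reproduces the expected specs listed in
# `test_get_read_instruction_for_host_deprecated`; in real code, prefer
# `get_read_instruction_for_host()`.
def _sketch_split_for_pipeline(pipeline_id, num_pipelines, *, split_name,
                               num_examples, drop_remainder=True):
  """Returns a TFDS-style split string such as "train[0:4]"."""
  if not 0 <= pipeline_id < num_pipelines:
    raise ValueError(f"Invalid pipeline_id={pipeline_id} for "
                     f"num_pipelines={num_pipelines}.")
  per_pipeline, remainder = divmod(num_examples, num_pipelines)
  if drop_remainder:
    # Every pipeline gets the same number of examples; the rest is dropped.
    start = pipeline_id * per_pipeline
    end = start + per_pipeline
  else:
    # Leading pipelines absorb the remainder, one extra example each.
    start = pipeline_id * per_pipeline + min(pipeline_id, remainder)
    end = start + per_pipeline + (1 if pipeline_id < remainder else 0)
  return f"{split_name}[{start}:{end}]"

# Usage (assumption): bind the static arguments with functools.partial, e.g.
#   split=functools.partial(_sketch_split_for_pipeline, split_name="train",
#                           num_examples=num_train_examples)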


================================================
FILE: clu/deterministic_data_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the deterministic_data module."""
import dataclasses
import itertools
import math

from typing import Dict
from unittest import mock

from absl.testing import parameterized
from clu import deterministic_data
import jax
from packaging import version
import tensorflow as tf
import tensorflow_datasets as tfds

_use_split_info = version.parse("4.4.0") < version.parse(
    tfds.version.__version__)


@dataclasses.dataclass
class MyDatasetBuilder:

  name2len: Dict[str, int]  # Number of examples per split.

  def as_dataset(self, split: tfds.core.ReadInstruction, shuffle_files: bool,
                 read_config: tfds.ReadConfig, decoders) -> tf.data.Dataset:
    del shuffle_files, read_config, decoders
    if _use_split_info:
      split_infos = {
          k: tfds.core.SplitInfo(name=k, shard_lengths=[v], num_bytes=0)
          for k, v in self.name2len.items()
      }
      instructions = split.to_absolute(split_infos)
    else:
      instructions = split.to_absolute(self.name2len)
    assert len(instructions) == 1
    from_ = instructions[0].from_ or 0
    to = instructions[0].to or self.name2len[instructions[0].splitname]
    return tf.data.Dataset.range(from_, to).map(lambda i: {"index": i})


@dataclasses.dataclass
class FakeDatasetInfo:
  train_size: int = 9
  test_size: int = 8

  @property
  def splits(self):
    return {
        "train": tfds.core.SplitInfo("train", [self.train_size], 0),
        "test": tfds.core.SplitInfo("test", [self.test_size], 0)
    }


class DeterministicDataTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for deterministic_data module."""

  @parameterized.parameters(
      (9, 0, 1, True, "test[0:9]"),
      (9, 0, 2, True, "test[0:4]"),
      (9, 1, 2, True, "test[4:8]"),  # Last example gets dropped.
      (9, 0, 3, True, "test[0:3]"),
      (9, 1, 3, True, "test[3:6]"),
      (9, 2, 3, True, "test[6:9]"),
      (9, 0, 1, False, "test[0:9]"),
      (9, 0, 2, False, "test[0:5]"),  # First host gets an extra example.
      (9, 1, 2, False, "test[5:9]"),
      (8, 0, 3, False, "test[0:3]"),  # First 2 hosts get 1 example each.
      (8, 1, 3, False, "test[3:6]"),
      (8, 2, 3, False, "test[6:8]"),
  )
  def test_get_read_instruction_for_host_deprecated(self, num_examples: int,
                                                    host_id: int,
                                                    host_count: int,
                                                    drop_remainder: bool,
                                                    expected_spec: str):
    expected = tfds.core.ReadInstruction.from_spec(expected_spec)
    actual = deterministic_data.get_read_instruction_for_host(
        "test",
        num_examples,
        host_id=host_id,
        host_count=host_count,
        drop_remainder=drop_remainder)
    if _use_split_info:
      split_infos = {
          "test": tfds.core.SplitInfo(
              name="test",
              shard_lengths=[9],
              num_bytes=0,
          )}
    else:
      split_infos = {"test": 9}
    self.assertEqual(
        expected.to_absolute(split_infos), actual.to_absolute(split_infos))

  @parameterized.parameters(
      # host_id, host_count, drop_remainder, spec, expected_spec_for_host
      # train split has 9 examples.
      (0, 1, True, "train", "train[0:9]"),
      (0, 2, True, "train", "train[0:4]"),
      (1, 2, True, "train", "train[4:8]"),  # Last example gets dropped.
      (0, 3, True, "train", "train[0:3]"),
      (1, 3, True, "train", "train[3:6]"),
      (2, 3, True, "train", "train[6:9]"),
      (0, 1, False, "train", "train[0:9]"),
      (0, 2, False, "train", "train[0:5]"),  # First host gets an extra example.
      (1, 2, False, "train", "train[5:9]"),
      # test split has 8 examples.
      (0, 3, False, "test", "test[0:3]"),  # First 2 hosts get 1 example each.
      (1, 3, False, "test", "test[3:6]"),
      (2, 3, False, "test", "test[6:8]"),
      # Subsplits.
      (0, 2, True, "train[:50%]", "train[0:2]"),
      (1, 2, True, "train[:50%]", "train[2:4]"),
      (0, 2, True, "train[3:7]", "train[3:5]"),
      (1, 2, True, "train[3:7]", "train[5:7]"),
      (0, 2, True, "train[3:8]", "train[3:5]"),  # Last example gets dropped.
      (1, 2, True, "train[3:8]", "train[5:7]"),
      # 2 splits.
      (0, 2, True, "train[3:7]+test", "train[3:5]+test[0:4]"),
      (1, 2, True, "train[3:7]+test", "train[5:7]+test[4:8]"),
      # First host gets an extra example.
      (0, 2, False, "train[3:8]+test[:5]", "train[3:6]+test[0:3]"),
      (1, 2, False, "train[3:8]+test[:5]", "train[6:8]+test[3:5]"),
  )
  def test_get_read_instruction_for_host(self, host_id: int, host_count: int,
                                         drop_remainder: bool, spec: str,
                                         expected_spec_for_host: str):

    actual_spec_for_host = deterministic_data.get_read_instruction_for_host(
        spec,
        dataset_info=FakeDatasetInfo(),
        host_id=host_id,
        host_count=host_count,
        drop_remainder=drop_remainder)
    expected_spec_for_host = tfds.core.ReadInstruction.from_spec(
        expected_spec_for_host)
    self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host))

  @parameterized.parameters(
      # host_id, host_count, balance_remainder, spec, expected_spec_for_host
      # test split has 10 examples.
      (0, 1, True, "test", "test[0:10]"),
      (0, 1, False, "test", "test[0:10]"),
      (0, 4, True, "test", "test[0:3]"),
      (1, 4, True, "test", "test[3:6]"),
      (2, 4, True, "test", "test[6:8]"),
      (3, 4, True, "test", "test[8:10]"),
      (0, 4, False, "test", "test[0:4]"),
      (1, 4, False, "test", "test[4:6]"),
      (2, 4, False, "test", "test[6:8]"),
      (3, 4, False, "test", "test[8:10]"),
  )
  def test_get_read_instruction_balance_remainder(self, host_id: int,
                                                  host_count: int,
                                                  balance_remainder: bool,
                                                  spec: str,
                                                  expected_spec_for_host: str):
    actual_spec_for_host = deterministic_data.get_read_instruction_for_host(
        spec,
        dataset_info=FakeDatasetInfo(test_size=10),
        host_id=host_id,
        host_count=host_count,
        remainder_options=deterministic_data.RemainderOptions
        .BALANCE_ON_PROCESSES if balance_remainder else
        deterministic_data.RemainderOptions.ON_FIRST_PROCESS)
    expected_spec_for_host = tfds.core.ReadInstruction.from_spec(
        expected_spec_for_host)
    self.assertEqual(str(actual_spec_for_host), str(expected_spec_for_host))

  @parameterized.parameters(
      (0, 0),  # No hosts.
      (1, 1),  # Only one host (host_id is zero-based).
      (-1, 1),  # Negative host_id.
      (5, 2),  # host_id bigger than number of hosts.
  )
  def test_get_read_instruction_for_host_fails(self, host_id: int,
                                               host_count: int):
    with self.assertRaises(ValueError):
      deterministic_data.get_read_instruction_for_host(
          "test", 11, host_id=host_id, host_count=host_count)

  def test_preprocess_with_per_example_rng(self):

    def preprocess_fn(features):
      features["b"] = tf.random.stateless_uniform([], features["rng"])
      return features

    rng = jax.random.PRNGKey(42)
    ds_in = tf.data.Dataset.from_tensor_slices({"a": [37.2, 31.2, 39.0]})
    ds_out = deterministic_data._preprocess_with_per_example_rng(
        ds_in, preprocess_fn, rng=rng)
    self.assertAllClose([
        {
            "a": 37.2,
            "b": 0.79542184
        },
        {
            "a": 31.2,
            "b": 0.45482683
        },
        {
            "a": 39.0,
            "b": 0.85335636
        },
    ], list(ds_out))

  @parameterized.parameters(*itertools.product([2, "auto"], [True, False]))
  def test_create_dataset_padding(self, pad_up_to_batches, cardinality):
    dataset_builder = mock.Mock()
    dataset = tf.data.Dataset.from_tensor_slices(
        dict(x=tf.ones((12, 10)), y=tf.ones(12)))
    dataset_builder.as_dataset.return_value = dataset
    batch_dims = (2, 5)
    ds = deterministic_data.create_dataset(
        dataset_builder,
        split="(ignored)",
        batch_dims=batch_dims,
        num_epochs=1,
        shuffle=False,
        pad_up_to_batches=pad_up_to_batches,
        cardinality=12 if cardinality else None,
    )
    ds_iter = iter(ds)
    self.assertAllClose(
        dict(
            x=tf.ones((2, 5, 10)),
            y=tf.ones((2, 5)),
            mask=tf.ones((2, 5), bool),
        ), next(ds_iter))
    self.assertAllClose(
        dict(
            x=tf.reshape(
                tf.concat([tf.ones(
                    (2, 10)), tf.zeros((8, 10))], axis=0), (2, 5, 10)),
            y=tf.reshape(tf.concat([tf.ones(2), tf.zeros(8)], axis=0), (2, 5)),
            mask=tf.reshape(
                tf.concat(
                    [tf.ones(2, bool), tf.zeros(8, bool)], axis=0), (2, 5)),
        ), next(ds_iter))
    with self.assertRaises(StopIteration):
      next(ds_iter)

  def test_create_dataset_padding_raises_error_cardinality(self):
    dataset_builder = mock.Mock()
    dataset = tf.data.Dataset.from_tensor_slices(
        dict(x=tf.ones((12, 10)), y=tf.ones(12)))
    dataset = dataset.filter(lambda x: True)
    dataset_builder.as_dataset.return_value = dataset
    batch_dims = (2, 5)
    with self.assertRaisesRegex(
        ValueError,
        r"^Cannot determine dataset cardinality."):
      deterministic_data.create_dataset(
          dataset_builder,
          split="(ignored)",
          batch_dims=batch_dims,
          num_epochs=1,
          shuffle=False,
          pad_up_to_batches=2,
          cardinality=None,
      )

  def test_pad_dataset(self):
    dataset = tf.data.Dataset.from_tensor_slices(
        dict(x=tf.ones((12, 10)), y=tf.ones(12)))
    padded_dataset = deterministic_data.pad_dataset(
        dataset, batch_dims=[20], pad_up_to_batches=2, cardinality=12)
    self.assertAllClose(
        dict(
            x=tf.concat([tf.ones(
                (12, 10)), tf.zeros((8, 10))], axis=0),
            y=tf.concat([tf.ones(12), tf.zeros(8)], axis=0),
            mask=tf.concat(
                [tf.ones(12, bool), tf.zeros(8, bool)], axis=0)),
        next(iter(padded_dataset.batch(20))))

  def test_pad_nested_dataset(self):
    dataset = tf.data.Dataset.from_tensor_slices(
        {"x": {"z": (tf.ones((12, 10)), tf.ones(12))},
         "y": tf.ones((12, 4))})

    def expected(*dims):
      return tf.concat([tf.ones((12,) + dims), tf.zeros((8,) + dims)], axis=0)

    padded_dataset = deterministic_data.pad_dataset(
        dataset, batch_dims=[20], pad_up_to_batches=2, cardinality=12)
    self.assertAllClose(
        {"x": {"z": (expected(10), expected())},
         "y": expected(4),
         "mask": tf.concat([tf.ones(12, bool), tf.zeros(8, bool)], axis=0)},
        next(iter(padded_dataset.batch(20))))

  @parameterized.parameters(*itertools.product(range(20), range(1, 4)))
  def test_same_cardinality_on_all_hosts(self, num_examples: int,
                                         host_count: int):
    builder = MyDatasetBuilder({"train": num_examples})
    cardinalities = []
    for host_id in range(host_count):
      split = deterministic_data.get_read_instruction_for_host(
          split="train",
          num_examples=num_examples,
          host_id=host_id,
          host_count=host_count,
          drop_remainder=True)
      ds = deterministic_data.create_dataset(
          builder, split=split, batch_dims=[2], shuffle=False, num_epochs=1)
      cardinalities.append(ds.cardinality().numpy().item())
    self.assertLen(set(cardinalities), 1)

  @parameterized.parameters(*itertools.product(range(20), range(1, 4)))
  def test_same_cardinality_on_all_hosts_with_pad(self, num_examples: int,
                                                  host_count: int):
    builder = MyDatasetBuilder({"train": num_examples})
    # All hosts should have the same number of batches.
    batch_size = 2
    pad_up_to_batches = int(math.ceil(num_examples / (batch_size * host_count)))
    assert pad_up_to_batches * batch_size * host_count >= num_examples
    cardinalities = []
    for host_id in range(host_count):
      split = deterministic_data.get_read_instruction_for_host(
          split="train",
          num_examples=num_examples,
          host_id=host_id,
          host_count=host_count,
          drop_remainder=False)
      ds = deterministic_data.create_dataset(
          builder,
          split=split,
          batch_dims=[batch_size],
          shuffle=False,
          num_epochs=1,
          pad_up_to_batches=pad_up_to_batches)
      cardinalities.append(ds.cardinality().numpy().item())
    self.assertLen(set(cardinalities), 1)


if __name__ == "__main__":
  tf.test.main()


================================================
FILE: clu/internal/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.



================================================
FILE: clu/internal/utils.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Small utilities used by CLU libraries."""

import contextlib
import sys
import time
from typing import Any, List, Mapping, Tuple, Union

from absl import logging

import jax.numpy as jnp
import numpy as np
import wrapt


@contextlib.contextmanager
def log_activity(activity_name: str):
  """Logs `activity_name` and timing information (or exception)."""
  t0 = time.time()
  logging.info("%s ...", activity_name)
  try:
    yield
  finally:
    dt = time.time() - t0
    exc, *_ = sys.exc_info()
    if exc is not None:
      logging.exception("%s FAILED after %.2fs with %s.", activity_name, dt,
                        exc.__name__)
    else:
      logging.info("%s finished after %.2fs.", activity_name, dt)



def logged_with(activity_name: str):
  """Returns a decorator wrapping a function with `log_activity()`."""
  @wrapt.decorator
  def decorator(wrapped, instance, args, kwargs):
    del instance  # not used
    with log_activity(activity_name):
      return wrapped(*args, **kwargs)
  return decorator


def check_param(value, *, ndim=None, dtype=jnp.float32):
  """Raises a `ValueError` if `value` does not match ndim/dtype.

  Args:
    value: Value to be tested.
    ndim: Expected number of dimensions.
    dtype: Expected dtype.

  Raises:
    ValueError: If `value` is not an instance of `np.ndarray` or
      `jnp.ndarray`, or if it does not match `ndim` or `dtype`.
  """
  if not isinstance(value, (np.ndarray, jnp.ndarray)):
    raise ValueError(f"Expected np.array or jnp.array, got type={type(value)}")
  if ndim is not None and value.ndim != ndim:
    raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}")
  if dtype is not None and value.dtype != dtype:
    raise ValueError(f"Expected dtype={dtype}, got dtype={value.dtype}")


def flatten_dict(
    d: Mapping[str, Any], prefix: Tuple[str, ...] = ()
) -> List[Tuple[str, Union[int, float, str]]]:
  """Returns a sequence of flattened (k, v) pairs for tfsummary.hparams().

  Args:
    d: A dict-like object that has an `.items()` method.
    prefix: Prefix to add to keys in `d`.

  Returns:
    Sequence of (k, v) pairs where k is the flattened key with individual
    subkeys separated by dots. `None` values are replaced by the empty string.
  """
  ret = []
  for k, v in d.items():
    # Note `ml_collections.ConfigDict` is not (yet) a `Mapping`.
    if isinstance(v, Mapping) or hasattr(v, "items"):
      ret += flatten_dict(v, prefix + (k,))
    elif isinstance(v, (list, tuple)):
      ret += flatten_dict({str(idx): value for idx, value in enumerate(v)},
                          prefix + (k,))
    else:
      ret.append((".".join(prefix + (k,)), v if v is not None else ""))
  return ret


================================================
FILE: clu/internal/utils_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

from absl.testing import absltest
from clu.internal import utils
import jax.numpy as jnp
import ml_collections


class TestError(BaseException):
  __test__ = False
  pass


class HelpersTest(absltest.TestCase):

  def test_log_activity(
      self,
  ):
    with self.assertLogs() as logs:
      with utils.log_activity("test_activity"):
        pass
    self.assertLen(logs.output, 2)
    self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
    self.assertRegex(logs.output[1],
                     r"^INFO:absl:test_activity finished after \d+.\d\ds.$")

  def test_log_activity_fails(
      self,
  ):
    with self.assertRaises(TestError):  # pylint: disable=g-error-prone-assert-raises, line-too-long
      with self.assertLogs() as logs:
        with utils.log_activity("test_activity"):
          raise TestError()
    self.assertLen(logs.output, 2)
    self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
    self.assertRegex(logs.output[1],
                     r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")

  def test_logged_with(self):

    @utils.logged_with("test_activity")
    def test():
      pass

    with self.assertLogs() as logs:
      test()
    self.assertLen(logs.output, 2)
    self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
    self.assertRegex(logs.output[1],
                     r"^INFO:absl:test_activity finished after \d+.\d\ds.$")

  def test_logged_with_fails(self):

    @utils.logged_with("test_activity")
    def test():
      raise TestError()

    with self.assertRaises(TestError):  # pylint: disable=g-error-prone-assert-raises, line-too-long
      with self.assertLogs() as logs:
        test()
    self.assertLen(logs.output, 2)
    self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
    self.assertRegex(logs.output[1],
                     r"^ERROR:absl:test_activity FAILED after \d+.\d\ds")

  def test_check_param(self):
    a = jnp.array(0.)
    with self.assertRaisesRegex(ValueError, r"^Expected np.array or jnp.array"):
      utils.check_param(None, ndim=1)
    with self.assertRaisesRegex(ValueError, r"^Expected ndim"):
      utils.check_param(a, ndim=1)
    with self.assertRaisesRegex(ValueError, r"^Expected dtype"):
      utils.check_param(a, ndim=0, dtype=jnp.int32)
    utils.check_param(a, ndim=0)  # should work
    utils.check_param(a, ndim=0, dtype=jnp.float32)  # should also work

  def test_flatten_dict(self):
    self.assertEqual(
        utils.flatten_dict(
            ml_collections.ConfigDict({
                "x": 1,
                "y": None,
                "z": ml_collections.ConfigDict({
                    "a": "bc",
                })
            })), [("x", 1), ("y", ""), ("z.a", "bc")])


if __name__ == "__main__":
  absltest.main()


================================================
FILE: clu/metric_writers/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Metric writers write ML model outputs during model training and evaluation.

This module introduces the MetricWriter interface. MetricWriters allow users
to write out metrics about ML models during training and evaluation (e.g. loss,
accuracy).
There is a MetricWriter implementation for each back end (e.g. TensorFlow
summaries) and classes that work on top of other MetricWriters to
write to multiple writers at once or to write asynchronously.

Note: The current interface might not contain write() methods for all possible
data types. We are open to extending the interface to other data types
(e.g. audio).

Usage:
  writer = MyMetricWriterImplementation()
  # Before training.
  writer.write_hparams({"learning_rate": 0.001, "batch_size": 64})
  # Start training loop.
  for step in range(num_train_steps):
    loss = train_step()
    if step % 50 == 0:
      writer.write_scalars(step, {"loss": loss})
      accuracy = evaluate()
      writer.write_scalars(step, {"accuracy": accuracy})
  # Make sure all values were written.
  writer.flush()  # or use metric_writers.ensure_flushes() context.
"""

# pylint: disable=unused-import
# pylint: disable=g-importing-member


from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.async_writer import AsyncWriter
from clu.metric_writers.async_writer import ensure_flushes
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
from clu.metric_writers.utils import create_default_writer
from clu.metric_writers.utils import write_values

# TODO(b/200953513): Migrate away from logging imports (on module level)
#                    to logging the actual usage. See b/200953513.



================================================
FILE: clu/metric_writers/async_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MetricWriter that writes metrics in a separate thread.

- The order of the write calls is preserved.
- Users need to call `flush()` or use the `ensure_flushes()` context to make
  sure that all metrics have been written.
- Errors while writing in the background thread will be re-raised in the main
  thread on the next write_*() call.
"""

from collections.abc import Mapping, Sequence
import contextlib
from typing import Any, Optional

from clu import asynclib

from clu.metric_writers import interface
from clu.metric_writers import multi_writer
import wrapt

Array = interface.Array
Scalar = interface.Scalar


@wrapt.decorator
def _wrap_exceptions(wrapped, instance, args, kwargs):
  del instance
  try:
    return wrapped(*args, **kwargs)
  except asynclib.AsyncError as e:
    raise asynclib.AsyncError(
        "Consider re-running the code without AsyncWriter (e.g. creating a "
        "writer using "
        "`clu.metric_writers.create_default_writer(asynchronous=False)`)"
    ) from e


class AsyncWriter(interface.MetricWriter):
  """MetricWriter that performs write operations in a separate thread.

  All write operations are executed in a background thread. If an exception
  occurs in the background thread, it is re-raised on the main thread on the
  next call to one of the write_* methods.

  Use num_workers > 1 at your own risk: if the underlying writer is not
  thread-safe or does not expect out-of-order events, this can cause problems.
  If num_workers is None, the thread pool will use `os.cpu_count()` threads.
  """

  def __init__(self,
               writer: interface.MetricWriter,
               *,
               num_workers: Optional[int] = 1):
    super().__init__()
    self._writer = writer
    # By default, we have a thread pool with a single worker to ensure that
    # calls to the function are run in order (but in a background thread).
    self._num_workers = num_workers
    self._pool = asynclib.Pool(
        thread_name_prefix="AsyncWriter", max_workers=num_workers)

  @_wrap_exceptions
  def write_summaries(
      self, step: int,
      values: Mapping[str, Array],
      metadata: Optional[Mapping[str, Any]] = None):
    self._pool(self._writer.write_summaries)(
        step=step, values=values, metadata=metadata)

  @_wrap_exceptions
  def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    self._pool(self._writer.write_scalars)(step=step, scalars=scalars)

  @_wrap_exceptions
  def write_images(self, step: int, images: Mapping[str, Array]):
    self._pool(self._writer.write_images)(step=step, images=images)

  @_wrap_exceptions
  def write_videos(self, step: int, videos: Mapping[str, Array]):
    self._pool(self._writer.write_videos)(step=step, videos=videos)

  @_wrap_exceptions
  def write_audios(
      self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
    self._pool(self._writer.write_audios)(
        step=step, audios=audios, sample_rate=sample_rate)

  @_wrap_exceptions
  def write_texts(self, step: int, texts: Mapping[str, str]):
    self._pool(self._writer.write_texts)(step=step, texts=texts)

  @_wrap_exceptions
  def write_histograms(self,
                       step: int,
                       arrays: Mapping[str, Array],
                       num_buckets: Optional[Mapping[str, int]] = None):
    self._pool(self._writer.write_histograms)(
        step=step, arrays=arrays, num_buckets=num_buckets)

  @_wrap_exceptions
  def write_pointcloud(
      self,
      step: int,
      point_clouds: Mapping[str, Array],
      *,
      point_colors: Mapping[str, Array] | None = None,
      configs: Mapping[str, str | float | bool | None] | None = None,
  ):
    self._pool(self._writer.write_pointcloud)(
        step=step,
        point_clouds=point_clouds,
        point_colors=point_colors,
        configs=configs,
    )

  @_wrap_exceptions
  def write_hparams(self, hparams: Mapping[str, Any]):
    self._pool(self._writer.write_hparams)(hparams=hparams)

  def flush(self):
    try:
      self._pool.join()
    finally:
      self._writer.flush()

  def close(self):
    try:
      self.flush()
    finally:
      self._writer.close()


class AsyncMultiWriter(multi_writer.MultiWriter):
  """MultiWriter that wraps each of its writers in an AsyncWriter."""

  def __init__(self,
               writers: Sequence[interface.MetricWriter],
               *,
               num_workers: Optional[int] = 1):
    super().__init__([AsyncWriter(w, num_workers=num_workers) for w in writers])


@contextlib.contextmanager
def ensure_flushes(*writers: interface.MetricWriter):
  """Context manager which ensures that one or more writers are flushed."""
  try:
    # The caller should not need to use the yielded value, but we yield
    # the first writer to stay backwards compatible for a single writer.
    yield writers[0]
  finally:
    for writer in writers:
      writer.flush()


================================================
FILE: clu/metric_writers/async_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for AsyncWriter."""

import time
from unittest import mock

from clu import asynclib
from clu.metric_writers import async_writer
from clu.metric_writers import interface
import numpy as np
import tensorflow as tf


class AsyncWriterTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self.sync_writer = mock.create_autospec(interface.MetricWriter)
    self.writer = async_writer.AsyncWriter(self.sync_writer)

  def test_write_summaries_async(self):
    self.writer.write_summaries(
        11,
        {"a": np.eye(3, dtype=np.uint8),
         "b": np.eye(2, dtype=np.float32)},
        {"a": np.ones((2, 3)).tobytes()})
    self.writer.flush()
    self.sync_writer.write_summaries.assert_called_with(
        step=11,
        values={"a": mock.ANY, "b": mock.ANY},
        metadata={"a": mock.ANY})

  def test_write_scalars_async(self):
    self.writer.write_scalars(0, {"a": 3, "b": 0.15})
    self.writer.write_scalars(2, {"a": 5, "b": 0.007})
    self.writer.flush()
    self.sync_writer.write_scalars.assert_has_calls([
        mock.call(step=0, scalars={
            "a": 3,
            "b": 0.15
        }),
        mock.call(step=2, scalars={
            "a": 5,
            "b": 0.007
        })
    ])

  def test_write_images(self):
    images = np.zeros((2, 28, 28, 3))
    self.writer.write_images(4, {"input_images": images})
    self.writer.flush()
    self.sync_writer.write_images.assert_called_with(4,
                                                     {"input_images": mock.ANY})

  def test_write_videos(self):
    videos = np.zeros((2, 4, 28, 28, 3))
    self.writer.write_videos(4, {"input_videos": videos})
    self.writer.flush()
    self.sync_writer.write_videos.assert_called_with(4,
                                                     {"input_videos": mock.ANY})

  def test_write_pointcloud(self):
    point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
    point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
    config = {
        "material": "PointCloudMaterial",
        "size": 0.09,
    }
    self.writer.write_pointcloud(
        step=0,
        point_clouds={"pcd": point_clouds},
        point_colors={"pcd": point_colors},
        configs={"config": config},
    )
    self.writer.flush()
    self.sync_writer.write_pointcloud.assert_called_with(
        step=0,
        point_clouds={"pcd": mock.ANY},
        point_colors={"pcd": mock.ANY},
        configs={"config": mock.ANY},
    )

  def test_write_texts(self):
    self.writer.write_texts(4, {"samples": "bla"})
    self.writer.flush()
    self.sync_writer.write_texts.assert_called_with(4, {"samples": "bla"})

  def test_ensure_flushes(self):
    with async_writer.ensure_flushes(self.writer) as writer:
      writer.write_scalars(0, {"a": 3, "b": 0.15})
      writer.write_scalars(2, {"a": 5, "b": 0.007})
    self.sync_writer.write_scalars.assert_has_calls([
        mock.call(step=0, scalars={
            "a": 3,
            "b": 0.15
        }),
        mock.call(step=2, scalars={
            "a": 5,
            "b": 0.007
        })
    ])
    self.sync_writer.flush.assert_called_once()

  def test_ensure_flushes_with_multiple_writers(self):
    sync_writer1 = mock.create_autospec(interface.MetricWriter)
    writer1 = async_writer.AsyncWriter(sync_writer1)
    sync_writer2 = mock.create_autospec(interface.MetricWriter)
    writer2 = async_writer.AsyncWriter(sync_writer2)

    with async_writer.ensure_flushes(writer1, writer2):
      writer1.write_scalars(0, {"a": 3, "b": 0.15})
      writer2.write_scalars(2, {"a": 5, "b": 0.007})

    sync_writer1.write_scalars.assert_has_calls(
        [mock.call(step=0, scalars={
            "a": 3,
            "b": 0.15
        })])

    sync_writer2.write_scalars.assert_has_calls(
        [mock.call(step=2, scalars={
            "a": 5,
            "b": 0.007
        })])

    sync_writer1.flush.assert_called_once()
    sync_writer2.flush.assert_called_once()

  def test_flush_before_close(self):
    self.writer.close()
    self.sync_writer.flush.assert_called()
    self.sync_writer.close.assert_called()

  def test_reraises_exception(self):
    self.sync_writer.write_scalars.side_effect = ValueError("foo")
    self.writer.write_scalars(0, {"a": 3, "b": 0.15})
    time.sleep(0.1)
    with self.assertRaisesRegex(asynclib.AsyncError, "Consider re-running"):
      self.writer.write_scalars(2, {"a": 5, "b": 0.007})


if __name__ == "__main__":
  tf.test.main()


================================================
FILE: clu/metric_writers/interface.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library for unifying the reporting of model metrics across logging formats.

This library provides a MetricWriter for each logging format (SummaryWriter,
LoggingWriter, etc.) and composing MetricWriters that add support for
asynchronous logging or writing to multiple formats.
"""

import abc
from collections.abc import Mapping
from typing import Any, Optional, Union

import jax.numpy as jnp
import numpy as np

Array = Union[np.ndarray, jnp.ndarray]
Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray]


class MetricWriter(abc.ABC):
  """MetricWriter interface."""

  @abc.abstractmethod
  def write_summaries(
      self, step: int,
      values: Mapping[str, Array],
      metadata: Optional[Mapping[str, Any]] = None):
    """Saves an arbitrary tensor summary.

    Useful when working with custom plugins or constructing a summary directly.

    Args:
      step: Step at which the values occurred.
      values: Mapping from tensor keys to tensors.
      metadata: Optional SummaryMetadata, as a proto or serialized bytes.
                Note that markdown formatting is rendered by tensorboard.
    """

  @abc.abstractmethod
  def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    """Write scalar values for the step.

    Consecutive calls to this method can provide different sets of scalars.
    Repeated writes for the same metric at the same step are not allowed.

    Args:
      step: Step at which the scalar values occurred.
      scalars: Mapping from metric name to value.
    """

  @abc.abstractmethod
  def write_images(self, step: int, images: Mapping[str, Array]):
    """Write images for the step.

    Consecutive calls to this method can provide different sets of images.
    Repeated writes for the same image key at the same step are not allowed.

    Warning: Not all MetricWriter implementations support writing images!

    Args:
      step: Step at which the images occurred.
      images: Mapping from image key to images. Images should have the shape [N,
        H, W, C] or [H, W, C], where H is the height, W is the width and C the
        number of channels (1 or 3). N is the number of images that will be
        written. Image dimensions can differ between different image keys but
        not between different steps for the same image key.
    """

  @abc.abstractmethod
  def write_videos(self, step: int, videos: Mapping[str, Array]):
    """Write videos for the step.

    Warning: Logging only.
    Not all MetricWriter implementations support writing videos!

    Consecutive calls to this method can provide different sets of videos.
    Repeated writes for the same video key at the same step are not allowed.

    Args:
      step: Step at which the videos occurred.
      videos: Mapping from video key to videos. Videos should have the shape
        [N, T, H, W, C] or [T, H, W, C], where T is time, H is the height,
        W is the width and C the number of channels (1 or 3). N is the number
        of videos that will be written. Video dimensions can differ between
        different video keys but not between different steps for the same
        video key.
    """

  @abc.abstractmethod
  def write_audios(
      self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
    """Write audios for the step.

    Consecutive calls to this method can provide different sets of audios.
    Repeated writes for the same audio key at the same step are not allowed.

    Warning: Not all MetricWriter implementations support writing audios!

    Args:
      step: Step at which the audios occurred.
      audios: Mapping from audio key to audios. Audios should have the shape
        [N, T, C], where T is the time length and C the number of channels
        (1 - mono, 2 - stereo, >= 3 - surround; not all writers support any
        number of channels). N is the number of audios that will be written.
        Audio dimensions can differ between different audio keys but not between
        different steps for the same audio key. Values should be floating-point
        values in [-1, +1].
      sample_rate: Sample rate for the audios.
    """

  @abc.abstractmethod
  def write_texts(self, step: int, texts: Mapping[str, str]):
    """Writes text snippets for the step.

    Warning: Not all MetricWriter implementations support writing text!

    Args:
      step: Step at which the text snippets occurred.
      texts: Mapping from name to text snippet.
    """

  @abc.abstractmethod
  def write_histograms(self,
                       step: int,
                       arrays: Mapping[str, Array],
                       num_buckets: Optional[Mapping[str, int]] = None):
    """Writes histograms for the step.

    Consecutive calls to this method can provide different sets of arrays.
    Repeated writes for the same metric at the same step are not allowed.

    Warning: Not all MetricWriter implementations support writing histograms!

    Args:
      step: Step at which the arrays were generated.
      arrays: Mapping from name to arrays to summarize.
      num_buckets: Number of buckets used to create the histogram of the arrays.
        The default number of buckets depends on the particular implementation
        of the MetricWriter.
    """

  def write_pointcloud(
      self,
      step: int,
      point_clouds: Mapping[str, Array],
      *,
      point_colors: Mapping[str, Array] | None = None,
      configs: Mapping[str, str | float | bool | None] | None = None,
  ):
    """Writes point cloud summaries.

    Args:
      step: Step at which the point cloud was generated.
      point_clouds: Mapping from point cloud key to an [N, 3] array of point
        coordinates.
      point_colors: Mapping from point cloud key to an [N, 3] array of
        per-point colors.
      configs: A dictionary of configuration options for the point cloud.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def write_hparams(self, hparams: Mapping[str, Any]):
    """Writes hyperparameters.

    Do not call twice.

    Args:
      hparams: Flat mapping from hyperparameter name to value.
    """

  @abc.abstractmethod
  def flush(self):
    """Tells the MetricWriter to write out any cached values."""

  @abc.abstractmethod
  def close(self):
    """Flushes and closes the MetricWriter.

    Calling any method on MetricWriter after MetricWriter.close()
    is undefined behavior.
    """


================================================
FILE: clu/metric_writers/logging_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MetricWriter that writes all values to INFO log."""

from collections.abc import Mapping
from typing import Any, Optional

from absl import logging
from clu.metric_writers import interface
import numpy as np

Array = interface.Array
Scalar = interface.Scalar


class LoggingWriter(interface.MetricWriter):
  """MetricWriter that writes all values to INFO log."""

  def __init__(self, collection: Optional[str] = None):
    if collection:
      self._collection_str = f" collection={collection}"
    else:
      self._collection_str = ""

  def write_summaries(
      self, step: int,
      values: Mapping[str, Array],
      metadata: Optional[Mapping[str, Any]] = None):
    logging.info("[%d]%s Got raw tensors: %s.", step, self._collection_str,
                 {k: v.shape for k, v in values.items()})

  def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    values = [
        f"{k}={v:.6g}" if isinstance(v, float) else f"{k}={v}"
        for k, v in sorted(scalars.items())
    ]
    logging.info("[%d]%s %s", step, self._collection_str, ", ".join(values))

  def write_images(self, step: int, images: Mapping[str, Array]):
    logging.info("[%d]%s Got images: %s.", step, self._collection_str,
                 {k: v.shape for k, v in images.items()})

  def write_videos(self, step: int, videos: Mapping[str, Array]):
    logging.info("[%d]%s Got videos: %s.", step, self._collection_str,
                 {k: v.shape for k, v in videos.items()})

  def write_audios(
      self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
    logging.info("[%d]%s Got audios: %s.", step, self._collection_str,
                 {k: v.shape for k, v in audios.items()})

  def write_texts(self, step: int, texts: Mapping[str, str]):
    logging.info("[%d]%s Got texts: %s.", step, self._collection_str, texts)

  def write_histograms(self,
                       step: int,
                       arrays: Mapping[str, Array],
                       num_buckets: Optional[Mapping[str, int]] = None):
    num_buckets = num_buckets or {}
    for key, value in arrays.items():
      histo, bins = _compute_histogram_as_tf(
          np.asarray(value), num_buckets=num_buckets.get(key))
      if histo is not None:
        logging.info("[%d]%s Histogram for %r = {%s}", step,
                     self._collection_str, key,
                     _get_histogram_as_string(histo, bins))

  def write_pointcloud(
      self,
      step: int,
      point_clouds: Mapping[str, Array],
      *,
      point_colors: Mapping[str, Array] | None = None,
      configs: Mapping[str, str | float | bool | None] | None = None,
  ):
    logging.info(
        "[%d]%s Got point clouds: %s, point_colors: %s, configs: %s.",
        step,
        self._collection_str,
        {k: v.shape for k, v in point_clouds.items()},
        (
            {k: v.shape for k, v in point_colors.items()}
            if point_colors is not None
            else None
        ),
        configs,
    )

  def write_hparams(self, hparams: Mapping[str, Any]):
    logging.info("[Hyperparameters]%s %s", self._collection_str, hparams)

  def flush(self):
    logging.flush()

  def close(self):
    self.flush()


def _compute_histogram_as_tf(
    array: np.ndarray,
    num_buckets: Optional[int] = None
) -> tuple[Optional[np.ndarray], Optional[np.ndarray]]:
  """Compute the histogram of the input array as TF would do.

  Args:
    array: Input data. The histogram is computed over the flattened array.
    num_buckets: The number of equal-width bins used to create the histogram.

  Returns:
    histo: A numpy array with the values of the histogram.
    bins: A numpy array with the bin edges (its length is len(histo) + 1).

    If the histogram cannot be built because the array is empty, returns
    (None, None).
  """
  # See DEFAULT_BUCKET_COUNT in tensorboard/plugins/histogram/summary_v2.py
  num_buckets = num_buckets or 30
  if num_buckets < 2:
    logging.log_first_n(logging.WARNING,
                        "num_buckets was automatically changed from %d to 2", 1,
                        num_buckets)
    num_buckets = 2

  if array.size == 0:
    return None, None

  range_max = np.max(array)
  range_min = np.min(array)
  if np.isclose(range_max, range_min, rtol=1e-5, atol=1e-8):
    histo = np.asarray([array.size], dtype=np.int64)
    bins = np.asarray([range_max - 0.5, range_max + 0.5], dtype=np.float64)
  else:
    histo, bins = np.histogram(
        array, bins=num_buckets, range=(range_min, range_max))
    bins = np.asarray(bins, dtype=np.float64)

  return histo, bins


def _get_histogram_as_string(histo: np.ndarray, bins: np.ndarray):
  # First items are right-open (i.e. [a, b)).
  items = [
      f"[{bins[i]:.3g}, {bins[i+1]:.3g}): {count}"
      for i, count in enumerate(histo[:-1])
  ]
  # Last item is right-closed (i.e. [a, b]).
  items.append(f"[{bins[-2]:.3g}, {bins[-1]:.3g}]: {histo[-1]}")
  return ", ".join(items)


================================================
FILE: clu/metric_writers/logging_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the LoggingWriter."""

from clu.metric_writers import logging_writer
import numpy as np
import tensorflow as tf


class LoggingWriterTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self.writer = logging_writer.LoggingWriter()

  def test_write_scalars(self):
    with self.assertLogs(level="INFO") as logs:
      self.writer.write_scalars(0, {"a": 3, "b": 0.15})
      self.writer.write_scalars(2, {"a": 0.0000005, "b": 0.007})
    self.assertEqual(
        logs.output,
        ["INFO:absl:[0] a=3, b=0.15", "INFO:absl:[2] a=5e-07, b=0.007"])

  def test_write_images(self):
    images = np.zeros((2, 28, 28, 3))
    with self.assertLogs(level="INFO") as logs:
      self.writer.write_images(4, {"input_images": images})
    self.assertEqual(
        logs.output,
        ["INFO:absl:[4] Got images: {'input_images': (2, 28, 28, 3)}."])

  def test_write_videos(self):
    videos = np.zeros((2, 4, 28, 28, 3))
    with self.assertLogs(level="INFO") as logs:
      self.writer.write_videos(4, {"input_videos": videos})
    self.assertEqual(
        logs.output,
        ["INFO:absl:[4] Got videos: {'input_videos': (2, 4, 28, 28, 3)}."])

  def test_write_texts(self):
    with self.assertLogs(level="INFO") as logs:
      self.writer.write_texts(4, {"samples": "bla"})
    self.assertEqual(
        logs.output,
        ["INFO:absl:[4] Got texts: {'samples': 'bla'}."])

  def test_write_histogram(self):
    with self.assertLogs(level="INFO") as logs:
      self.writer.write_histograms(
          step=4,
          arrays={
              "a": np.asarray([-0.1, 0.1, 0.3]),
              "b": np.arange(31),
              "c": np.asarray([0.1, 0.1, 0.1, 0.1, 0.1]),
          },
          num_buckets={
              "a": 2,
              "c": 1
          })
    # Note: There are 31 distinct values [0, 1, ..., 30], and 30 buckets by
    # default. Last bucket gets 2 values.
    expected_histo_b = ", ".join([f"[{i}, {i + 1}): 1" for i in range(29)] +
                                 ["[29, 30]: 2"])
    self.assertEqual(logs.output, [
        "INFO:absl:[4] Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}",
        f"INFO:absl:[4] Histogram for 'b' = {{{expected_histo_b}}}",
        "WARNING:absl:num_buckets was automatically changed from 1 to 2",
        "INFO:absl:[4] Histogram for 'c' = {[-0.4, 0.6]: 5}",
    ])

  def test_write_pointcloud(self):
    point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
    point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
    config = {
        "material": "PointCloudMaterial",
        "size": 0.09,
    }
    with self.assertLogs(level="INFO") as logs:
      self.writer.write_pointcloud(
          step=4,
          point_clouds={"pcd": point_clouds},
          point_colors={"pcd": point_colors},
          configs={"configs": config},
      )
    self.assertEqual(
        logs.output,
        [
            "INFO:absl:[4] Got point clouds: {'pcd': (1, 1024, 3)},"
            " point_colors: {'pcd': (1, 1024, 3)}, configs: {'configs':"
            " {'material': 'PointCloudMaterial', 'size': 0.09}}."
        ],
    )

  def test_write_hparams(self):
    with self.assertLogs(level="INFO") as logs:
      self.writer.write_hparams({"learning_rate": 0.1, "batch_size": 128})
    self.assertEqual(logs.output, [
        "INFO:absl:[Hyperparameters] {'learning_rate': 0.1, 'batch_size': 128}"
    ])

  def test_collection(self):
    writer = logging_writer.LoggingWriter(collection="train")
    with self.assertLogs(level="INFO") as logs:
      writer.write_scalars(0, {"a": 3, "b": 0.15})
      writer.write_images(4, {"input_images": np.zeros((2, 28, 28, 3))})
      writer.write_texts(4, {"samples": "bla"})
      writer.write_histograms(
          step=4,
          arrays={
              "a": np.asarray([-0.1, 0.1, 0.3]),
          },
          num_buckets={
              "a": 2,
          })
      writer.write_hparams({"learning_rate": 0.1})

    self.assertEqual(logs.output, [
        "INFO:absl:[0] collection=train a=3, b=0.15",
        "INFO:absl:[4] collection=train Got images: {'input_images': (2, 28, 28, 3)}.",
        "INFO:absl:[4] collection=train Got texts: {'samples': 'bla'}.",
        "INFO:absl:[4] collection=train Histogram for 'a' = {[-0.1, 0.1): 1, [0.1, 0.3]: 2}",
        "INFO:absl:[Hyperparameters] collection=train {'learning_rate': 0.1}",
    ])


if __name__ == "__main__":
  tf.test.main()


================================================
FILE: clu/metric_writers/multi_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MetricWriter that writes to multiple MetricWriters."""

from collections.abc import Mapping, Sequence
from typing import Any, Optional

from clu.metric_writers import interface

Array = interface.Array
Scalar = interface.Scalar


class MultiWriter(interface.MetricWriter):
  """MetricWriter that writes to multiple writers at once."""

  def __init__(self, writers: Sequence[interface.MetricWriter]):
    self._writers = tuple(writers)

  def write_summaries(
      self, step: int,
      values: Mapping[str, Array],
      metadata: Optional[Mapping[str, Any]] = None):
    for w in self._writers:
      w.write_summaries(step, values, metadata)

  def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    for w in self._writers:
      w.write_scalars(step, scalars)

  def write_images(self, step: int, images: Mapping[str, Array]):
    for w in self._writers:
      w.write_images(step, images)

  def write_videos(self, step: int, videos: Mapping[str, Array]):
    for w in self._writers:
      w.write_videos(step, videos)

  def write_audios(
      self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
    for w in self._writers:
      w.write_audios(step, audios, sample_rate=sample_rate)

  def write_texts(self, step: int, texts: Mapping[str, str]):
    for w in self._writers:
      w.write_texts(step, texts)

  def write_histograms(self,
                       step: int,
                       arrays: Mapping[str, Array],
                       num_buckets: Optional[Mapping[str, int]] = None):
    for w in self._writers:
      w.write_histograms(step, arrays, num_buckets)

  def write_pointcloud(
      self,
      step: int,
      point_clouds: Mapping[str, Array],
      *,
      point_colors: Mapping[str, Array] | None = None,
      configs: Mapping[str, str | float | bool | None] | None = None,
  ):
    for w in self._writers:
      w.write_pointcloud(
          step, point_clouds, point_colors=point_colors, configs=configs
      )

  def write_hparams(self, hparams: Mapping[str, Any]):
    for w in self._writers:
      w.write_hparams(hparams)

  def flush(self):
    for w in self._writers:
      w.flush()

  def close(self):
    for w in self._writers:
      w.close()


================================================
FILE: clu/metric_writers/multi_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for MultiWriter."""

from unittest import mock

from clu.metric_writers import interface
from clu.metric_writers import multi_writer
import numpy as np
import tensorflow as tf
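

# Illustrative sketch (hypothetical helpers, not part of CLU): the tests below
# assert keyword calls such as `mock.call(step=0, scalars=...)`, even though
# MultiWriter forwards its arguments positionally. This works because
# `mock.create_autospec` records the spec's signature and binds both forms of
# a call to it before comparing. A plain `mock.Mock()` has no signature, so it
# compares the raw positional/keyword split instead. `_StandInWriter` is a
# minimal, hypothetical stand-in for `interface.MetricWriter`.
class _StandInWriter:

  def write_scalars(self, step, scalars):
    raise NotImplementedError


def _keyword_assertion_matches_positional_call(use_autospec: bool) -> bool:
  """Returns whether a positional call satisfies a keyword assertion."""
  if use_autospec:
    writer = mock.create_autospec(_StandInWriter)
  else:
    writer = mock.Mock()
  writer.write_scalars(0, {"a": 3})  # Positional, as MultiWriter calls it.
  try:
    writer.write_scalars.assert_called_with(step=0, scalars={"a": 3})
  except AssertionError:
    return False
  return True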


class MultiWriterTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self.writers = [
        mock.create_autospec(interface.MetricWriter),
        mock.create_autospec(interface.MetricWriter)
    ]
    self.writer = multi_writer.MultiWriter(self.writers)

  def test_write_scalars(self):
    self.writer.write_scalars(0, {"a": 3, "b": 0.15})
    self.writer.write_scalars(2, {"a": 5, "b": 0.007})
    self.writer.flush()
    for w in self.writers:
      w.write_scalars.assert_has_calls([
          mock.call(step=0, scalars={
              "a": 3,
              "b": 0.15
          }),
          mock.call(step=2, scalars={
              "a": 5,
              "b": 0.007
          })
      ])
      w.flush.assert_called()

  def test_write_pointcloud(self):
    point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
    point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
    config = {
        "material": "PointCloudMaterial",
        "size": 0.09,
    }
    self.writer.write_pointcloud(
        step=0,
        point_clouds={"pcd": point_clouds},
        point_colors={"pcd": point_colors},
        configs={"config": config},
    )
    self.writer.flush()
    for w in self.writers:
      w.write_pointcloud.assert_called_with(
          step=0,
          point_clouds={"pcd": point_clouds},
          point_colors={"pcd": point_colors},
          configs={"config": config},
      )
      w.flush.assert_called()


if __name__ == "__main__":
  tf.test.main()


================================================
FILE: clu/metric_writers/summary_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MetricWriter for writing to TF summary files."""
# pylint: disable=unused-import

from .tf.summary_writer import SummaryWriter


================================================
FILE: clu/metric_writers/tf/__init__.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Package __init__ file."""


================================================
FILE: clu/metric_writers/tf/summary_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MetricWriter for writing to TF summary files.

Only works in eager mode. Does not work for PyTorch code; use
TorchTensorboardWriter instead.
"""

from collections.abc import Mapping
from typing import Any, Optional

from absl import logging

from clu.internal import utils
from clu.metric_writers import interface
from etils import epy
import tensorflow as tf

with epy.lazy_imports():
  # pylint: disable=g-import-not-at-top
  from tensorboard.plugins.hparams import api as hparams_api
  from tensorboard.plugins.mesh import summary as mesh_summary  # pylint: disable=line-too-long
  # pylint: enable=g-import-not-at-top


Array = interface.Array
Scalar = interface.Scalar


class SummaryWriter(interface.MetricWriter):
  """MetricWriter that writes TF summary files."""

  def __init__(self, logdir: str):
    super().__init__()
    self._summary_writer = tf.summary.create_file_writer(logdir)

  def write_summaries(
      self,
      step: int,
      values: Mapping[str, Array],
      metadata: Optional[Mapping[str, Any]] = None,
  ):
    with self._summary_writer.as_default():
      for key, value in values.items():
        md = metadata.get(key) if metadata is not None else None
        tf.summary.write(key, value, step=step, metadata=md)

  def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    with self._summary_writer.as_default():
      for key, value in scalars.items():
        tf.summary.scalar(key, value, step=step)

  def write_images(self, step: int, images: Mapping[str, Array]):
    with self._summary_writer.as_default():
      for key, value in images.items():
        if len(value.shape) == 3:
          value = value[None]
        tf.summary.image(key, value, step=step, max_outputs=value.shape[0])

  def write_videos(self, step: int, videos: Mapping[str, Array]):
    logging.log_first_n(
        logging.WARNING,
        "SummaryWriter does not support writing videos.", 1)

  def write_audios(
      self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
    with self._summary_writer.as_default():
      for key, value in audios.items():
        tf.summary.audio(key, value, sample_rate=sample_rate, step=step,
                         max_outputs=value.shape[0])

  def write_texts(self, step: int, texts: Mapping[str, str]):
    with self._summary_writer.as_default():
      for key, value in texts.items():
        tf.summary.text(key, value, step=step)

  def write_histograms(
      self,
      step: int,
      arrays: Mapping[str, Array],
      num_buckets: Optional[Mapping[str, int]] = None,
  ):
    with self._summary_writer.as_default():
      for key, value in arrays.items():
        buckets = None if num_buckets is None else num_buckets.get(key)
        tf.summary.histogram(key, value, step=step, buckets=buckets)

  def write_pointcloud(
      self,
      step: int,
      point_clouds: Mapping[str, Array],
      *,
      point_colors: Mapping[str, Array] | None = None,
      configs: Mapping[str, str | float | bool | None] | None = None,
  ):
    with self._summary_writer.as_default():
      for key, vertices in point_clouds.items():
        colors = None if point_colors is None else point_colors.get(key)
        config = None if configs is None else configs.get(key)
        mesh_summary.mesh(
            key,
            vertices=vertices,
            colors=colors,
            step=step,
            config_dict=config,
        )

  def write_hparams(self, hparams: Mapping[str, Any]):
    with self._summary_writer.as_default():
      hparams_api.hparams(dict(utils.flatten_dict(hparams)))

  def flush(self):
    self._summary_writer.flush()

  def close(self):
    self._summary_writer.close()


================================================
FILE: clu/metric_writers/tf/summary_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for SummaryWriter."""

import collections
import os

from clu.metric_writers.tf import summary_writer
import numpy as np
import tensorflow as tf

from tensorboard.plugins.hparams import plugin_data_pb2


def _load_summaries_data(logdir):
  """Loads raw summaries data from events in a logdir."""
  paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
  data = collections.defaultdict(dict)
  metadata = collections.defaultdict(dict)
  for path in paths:
    for event in tf.compat.v1.train.summary_iterator(path):
      for value in event.summary.value:
        data[event.step][value.tag] = tf.make_ndarray(value.tensor)
        if value.HasField("metadata"):
          metadata[event.step][value.tag] = value.metadata.SerializeToString()
  return data, metadata


def _load_histograms_data(logdir):
  """Loads tensor summaries from events in a logdir."""
  # Note: new versions of histograms don't use the HistogramProto type, but
  # they are written as tensors representing the bounds and counts of buckets,
  # with plugin_name = "histogram".
  paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
  data = {}
  for path in paths:
    for event in tf.compat.v1.train.summary_iterator(path):
      for value in event.summary.value:
        current_steps, current_tensors = data.get(value.tag, ([], []))
        data[value.tag] = (current_steps + [event.step],
                           current_tensors + [tf.make_ndarray(value.tensor)])
  return {
      tag: (np.stack(steps), np.stack(tensors))
      for tag, (steps, tensors) in data.items()
  }
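

# Illustrative sketch (hypothetical helper, not part of CLU or TensorBoard):
# computes (bucket_min, bucket_max, bucket_count) tuples using the convention
# described in `test_write_histograms`: `num_buckets` equal-width buckets
# spanning [min(values), max(values)], each half-open except the last, which
# also includes the maximum. It only makes the expected values in that test
# easy to re-derive by hand; TF's own bucketing code may differ in edge cases
# such as floating-point rounding at bucket boundaries.
def _equal_width_buckets(values, num_buckets):
  """Returns a list of (bucket_min, bucket_max, bucket_count) tuples."""
  lo, hi = min(values), max(values)
  width = (hi - lo) / num_buckets
  counts = [0] * num_buckets
  for v in values:
    if width == 0:
      index = 0  # Degenerate range (all values equal): first bucket gets all.
    else:
      # Clamp so that v == hi lands in the last (closed) bucket.
      index = min(int((v - lo) / width), num_buckets - 1)
    counts[index] += 1
  edges = [lo + i * width for i in range(num_buckets)] + [hi]
  return [(edges[i], edges[i + 1], counts[i]) for i in range(num_buckets)]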


def _load_scalars_data(logdir: str):
  """Loads scalar summaries from events in a logdir."""
  paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
  data = collections.defaultdict(dict)
  for path in paths:
    for event in tf.compat.v1.train.summary_iterator(path):
      for value in event.summary.value:
        data[event.step][value.tag] = tf.make_ndarray(value.tensor).flat[0]

  return data


def _load_pointcloud_data(logdir: str):
  """Loads pointcloud summaries from events in a logdir."""
  paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
  data = collections.defaultdict(dict)
  for path in paths:
    for event in tf.compat.v1.train.summary_iterator(path):
      for value in event.summary.value:
        if value.metadata.plugin_data.plugin_name == "mesh":
          if "config" not in value.tag:
            data[event.step][value.tag] = tf.make_ndarray(value.tensor)
          else:
            data[event.step][value.tag] = value.metadata.plugin_data.content
  return data


def _load_hparams(logdir: str):
  """Loads hparams summaries from events in a logdir."""
  paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
  hparams = []
  for path in paths:
    for event in tf.compat.v1.train.summary_iterator(path):
      for value in event.summary.value:
        if value.metadata.plugin_data.plugin_name == "hparams":
          hparams.append(plugin_data_pb2.HParamsPluginData.FromString(
              value.metadata.plugin_data.content))
  return hparams
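

# Illustrative sketch (hypothetical helper; this is not
# `clu.internal.utils.flatten_dict`, whose handling of other input types is
# not shown in this file): reproduces the dotted keys that
# `test_hparams_nested` expects. Nested mappings contribute `parent.child`
# keys and list/tuple elements are indexed by position (`list.0`, ...).
def _flatten_hparams_sketch(config, prefix=""):
  """Returns a flat dict with dotted keys for nested dicts and sequences."""
  flat = {}
  for key, value in config.items():
    name = f"{prefix}{key}"
    if isinstance(value, (list, tuple)):
      # Treat sequence positions as keys so they flatten like mappings.
      value = {str(i): item for i, item in enumerate(value)}
    if isinstance(value, dict):
      flat.update(_flatten_hparams_sketch(value, prefix=f"{name}."))
    else:
      flat[name] = value
  return flat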


class SummaryWriterTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self.logdir = self.get_temp_dir()
    self.writer = summary_writer.SummaryWriter(self.logdir)

  def test_write_summaries(self):
    self.writer.write_summaries(
        11,
        {"a": np.eye(3, dtype=np.uint8),
         "b": np.eye(2, dtype=np.float32)},
        {"a": np.ones((2, 3)).tobytes()})
    self.writer.flush()
    data, metadata = _load_summaries_data(self.logdir)
    self.assertAllClose(
        data[11],
        {"a": np.eye(3, dtype=np.uint8), "b": np.eye(2, dtype=np.float32)})
    self.assertIn("a", metadata[11])

  def test_write_scalar(self):
    self.writer.write_scalars(11, {"a": 0.6, "b": 15})
    self.writer.write_scalars(20, {"a": 0.8, "b": 12})
    self.writer.flush()
    data = _load_scalars_data(self.logdir)
    self.assertAllClose(data[11], {"a": 0.6, "b": 15})
    self.assertAllClose(data[20], {"a": 0.8, "b": 12})

  def test_write_histograms(self):
    self.writer.write_histograms(
        0, {
            "a": np.asarray([0.3, 0.1, 0.5, 0.7, 0.1]),
            "b": np.asarray([-0.1, 0.3, 0.2, 0.4, 0.4]),
        }, num_buckets={"a": 2, "b": 2})
    self.writer.write_histograms(
        2, {
            "a": np.asarray([0.2, 0.4, 0.5, 0.1, -0.1]),
            "b": np.asarray([0.7, 0.3, 0.2, 0.1, 0.0]),
        }, num_buckets={"a": 2, "b": 2})
    self.writer.flush()
    data = _load_histograms_data(self.logdir)
    # In the histograms, each tuple represents
    # (bucket_min, bucket_max, bucket_count), where bucket_min is inclusive and
    # bucket_max is exclusive (except the last bucket_max which is inclusive).
    expected_histograms_a = [
        # Step 0.
        [(0.1, 0.4, 3), (0.4, 0.7, 2)],
        # Step 2.
        [(-0.1, 0.2, 2), (0.2, 0.5, 3)],
    ]
    self.assertAllClose(data["a"], ([0, 2], expected_histograms_a))
    expected_histograms_b = [
        # Step 0.
        [(-0.1, 0.15, 1), (0.15, 0.4, 4)],
        # Step 2.
        [(0.0, 0.35, 4), (0.35, 0.7, 1)],
    ]
    self.assertAllClose(data["b"], ([0, 2], expected_histograms_b))

  def test_write_pointcloud(self):
    point_clouds = np.random.normal(0, 1, (1, 1024, 3)).astype(np.float32)
    point_colors = np.random.uniform(0, 1, (1, 1024, 3)).astype(np.float32)
    config = {
        "material": "PointCloudMaterial",
        "size": 0.09,
    }
    self.writer.write_pointcloud(
        step=0,
        point_clouds={"pcd": point_clouds},
        point_colors={"pcd": point_colors},
        configs={"config": config},
    )
    self.writer.flush()
    data = _load_pointcloud_data(self.logdir)
    self.assertAllClose(data[0]["pcd_VERTEX"], point_clouds)
    self.assertAllClose(data[0]["pcd_COLOR"], point_colors)

  def test_hparams(self):
    self.writer.write_hparams(dict(batch_size=512, num_epochs=90))
    hparams = _load_hparams(self.logdir)
    self.assertLen(hparams, 1)
    hparams_dict = hparams[0].session_start_info.hparams
    self.assertLen(hparams_dict, 2)
    self.assertEqual(512, hparams_dict["batch_size"].number_value)
    self.assertEqual(90, hparams_dict["num_epochs"].number_value)

  def test_hparams_nested(self):
    config = {
        "list": [1, 2],
        "tuple": (3, 4),
        "subconfig": {
            "value": "a",
            "list": [10, 20],
        },
    }
    self.writer.write_hparams(config)
    hparams = _load_hparams(self.logdir)
    self.assertLen(hparams, 1)
    hparams_dict = hparams[0].session_start_info.hparams
    self.assertLen(hparams_dict, 7)
    self.assertEqual(1, hparams_dict["list.0"].number_value)
    self.assertEqual(2, hparams_dict["list.1"].number_value)
    self.assertEqual(3, hparams_dict["tuple.0"].number_value)
    self.assertEqual(4, hparams_dict["tuple.1"].number_value)
    self.assertEqual("a", hparams_dict["subconfig.value"].string_value)
    self.assertEqual(10, hparams_dict["subconfig.list.0"].number_value)
    self.assertEqual(20, hparams_dict["subconfig.list.1"].number_value)

if __name__ == "__main__":
  tf.test.main()


================================================
FILE: clu/metric_writers/torch_tensorboard_writer.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MetricWriter for PyTorch summary files.

Use this writer for PyTorch-based code.

"""

from collections.abc import Mapping
from typing import Any, Optional
from absl import logging

from clu.metric_writers import interface
from torch.utils import tensorboard

Array = interface.Array
Scalar = interface.Scalar


class TorchTensorboardWriter(interface.MetricWriter):
  """MetricWriter that writes Pytorch summary files."""

  def __init__(self, logdir: str):
    super().__init__()
    self._writer = tensorboard.SummaryWriter(log_dir=logdir)

  def write_summaries(
      self, step: int,
      values: Mapping[str, Array],
      metadata: Optional[Mapping[str, Any]] = None):
    logging.log_first_n(
        logging.WARNING,
        "TorchTensorboardWriter does not support writing raw summaries.", 1)

  def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    for key, value in scalars.items():
      self._writer.add_scalar(key, value, global_step=step)

  def write_images(self, step: int, images: Mapping[str, Array]):
    for key, value in images.items():
      self._writer.add_image(key, value, global_step=step, dataformats="HWC")

  def write_videos(self, step: int, videos: Mapping[str, Array]):
    logging.log_first_n(
        logging.WARNING,
        "TorchTensorboardWriter does not support writing videos.", 1)

  def write_audios(
      self, step: int, audios: Mapping[str, Array], *, sample_rate: int):
    for key, value in audios.items():
      self._writer.add_audio(
          key, value, global_step=step, sample_rate=sample_rate)

  def write_texts(self, step: int, texts: Mapping[str, str]):
    raise NotImplementedError(
        "TorchTensorboardWriter does not support writing texts."
    )

  def write_histograms(self,
                       step: int,
                       arrays: Mapping[str, Array],
                       num_buckets: Optional[Mapping[str, int]] = None):
    for tag, values in arrays.items():
      bins = None if num_buckets is None else num_buckets.get(tag)
      self._writer.add_histogram(
          tag, values, global_step=step, bins="auto", max_bins=bins)

  def write_pointcloud(
      self,
      step: int,
      point_clouds: Mapping[str, Array],
      *,
      point_colors: Mapping[str, Array] | None = None,
      configs: Mapping[str, str | float | bool | None] | None = None,
  ):
    logging.log_first_n(
        logging.WARNING,
        "TorchTensorboardWriter does not support writing point clouds.",
        1,
    )

  def write_hparams(self, hparams: Mapping[str, Any]):
    self._writer.add_hparams(hparams, {})

  def flush(self):
    self._writer.flush()

  def close(self):
    self._writer.close()


================================================
FILE: clu/metric_writers/torch_tensorboard_writer_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for TorchTensorboardWriter."""

import collections
import os
from typing import Any, Dict

from clu.metric_writers import torch_tensorboard_writer
import numpy as np
import tensorflow as tf


def _load_scalars_data(logdir: str):
  """Loads scalar summaries from events in a logdir."""
  paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
  data = collections.defaultdict(dict)
  for path in paths:
    for event in tf.compat.v1.train.summary_iterator(path):
      for value in event.summary.value:
        data[event.step][value.tag] = value.simple_value

  return data


def _load_histograms_data(logdir: str) -> Dict[int, Dict[str, Any]]:
  """Loads histogram summaries from events in a logdir.

  Args:
    logdir: Directory containing the event files.

  Returns:
    A nested dict mapping step -> tag -> HistogramProto.
  """
  paths = tf.io.gfile.glob(os.path.join(logdir, "events.out.tfevents.*"))
  data = {}
  for path in paths:
    for event in tf.compat.v1.train.summary_iterator(path):
      if event.step not in data:
        data[event.step] = {}
      step_data = {}
      for value in event.summary.value:
        step_data[value.tag] = value.histo
      data[event.step].update(step_data)

  return data


class TorchTensorboardWriterTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self.logdir = self.get_temp_dir()
    self.writer = torch_tensorboard_writer.TorchTensorboardWriter(self.logdir)

  def test_write_scalar(self):
    self.writer.write_scalars(11, {"a": 0.6, "b": 15})
    self.writer.write_scalars(20, {"a": 0.8, "b": 12})
    self.writer.flush()
    data = _load_scalars_data(self.logdir)
    self.assertAllClose(data[11], {"a": 0.6, "b": 15})
    self.assertAllClose(data[20], {"a": 0.8, "b": 12})

  def test_write_histograms(self):
    self.writer.write_histograms(
        0, {
            "a": np.asarray([0.3, 0.1, 0.5, 0.7, 0.1]),
            "b": np.asarray([-0.1, 0.3, 0.2, 0.4, 0.4]),
        }, num_buckets={"a": 2, "b": 2})
    self.writer.write_histograms(
        2, {
            "a": np.asarray([0.2, 0.4, 0.5, 0.1, -0.1]),
            "b": np.asarray([0.7, 0.3, 0.2, 0.1, 0.0]),
        }, num_buckets={"a": 2, "b": 2})
    self.writer.flush()
    data = _load_histograms_data(self.logdir)
    self.assertNear(data[0]["a"].min, 0.1, 0.001)
    self.assertNear(data[0]["a"].max, 0.7, 0.001)
    self.assertNear(data[0]["b"].min, -0.1, 0.001)
    self.assertNear(data[0]["b"].max, 0.4, 0.001)
    self.assertNear(data[2]["a"].min, -0.1, 0.001)
    self.assertNear(data[2]["a"].max, 0.5, 0.001)
    self.assertNear(data[2]["b"].min, 0.0, 0.001)
    self.assertNear(data[2]["b"].max, 0.7, 0.001)


if __name__ == "__main__":
  tf.test.main()


================================================
FILE: clu/metric_writers/utils.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines a generic write interface.

The `write_values` helper accepts a MetricWriter object and a Mapping[str,
clu.values.Value] (raw scalars are also accepted), and calls the appropriate
typed write method of the writer depending on the type of each value.
"""

# pylint: disable=g-importing-member

import collections
import getpass
import os
import re
from typing import Any, List, Mapping, Optional, Tuple, Union

from absl import flags
from absl import logging
from clu import values
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
from etils import epath
import jax.numpy as jnp
import numpy as np


FLAGS = flags.FLAGS


def _is_scalar(value: Any) -> bool:
  if isinstance(value, values.Scalar) or isinstance(
      value, (int, float, np.number)
  ):
    return True
  if isinstance(value, (np.ndarray, jnp.ndarray)):
    return value.ndim == 0 or value.size <= 1
  return False


def write_values(
    writer: MetricWriter,
    step: int,
    metrics: Mapping[
        str, Union[values.Value, values.ArrayType, values.ScalarType]
    ],
):
  """Writes all provided metrics.

  Allows providing a mapping of name to Value object, where each Value
  specifies a type. The appropriate write method can then be called depending
  on the type.

  Args:
    writer: MetricWriter object
    step: Step at which the arrays were generated.
    metrics: Mapping from name to a clu.values.Value object or a raw scalar.
  """
  writes = collections.defaultdict(dict)
  histogram_num_buckets = collections.defaultdict(int)
  for k, v in metrics.items():
    if isinstance(v, values.Summary):
      writes[
          (writer.write_summaries, frozenset({"metadata": v.metadata}.items()))
      ][k] = v.value
    elif _is_scalar(v):
      if isinstance(v, values.Scalar):
        writes[(writer.write_scalars, frozenset())][k] = v.value
      else:
        writes[(writer.write_scalars, frozenset())][k] = v
    elif isinstance(v, values.Image):
      writes[(writer.write_images, frozenset())][k] = v.value
    elif isinstance(v, values.Text):
      writes[(writer.write_texts, frozenset())][k] = v.value
    elif isinstance(v, values.HyperParam):
      writes[(writer.write_hparams, frozenset())][k] = v.value
    elif isinstance(v, values.Histogram):
      writes[(writer.write_histograms, frozenset())][k] = v.value
      histogram_num_buckets[k] = v.num_buckets
    elif isinstance(v, values.Audio):
      writes[(
          writer.write_audios,
          frozenset({"sample_rate": v.sample_rate}.items()),
      )][k] = v.value
    else:
      raise ValueError(f"Metric {k!r} has unsupported value: {v!r}")

  for (fn, extra_args), vals in writes.items():
    if fn == writer.write_histograms:
      # for write_histograms, the num_buckets arg is a Dict indexed by name
      writer.write_histograms(step, vals, num_buckets=histogram_num_buckets)
    else:
      fn(step, vals, **dict(extra_args))
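

# Illustrative sketch (hypothetical helper, not part of the CLU API): the core
# idea of `write_values` is to batch metrics by (write method, extra keyword
# arguments), so that two audio values sharing a sample rate go into a single
# `write_audios` call while one with a different sample rate gets its own
# call. Plain strings stand in for the bound writer methods here, and each
# metric is given as a hypothetical `(kind, value, extra_kwargs)` triple.
def _group_writes_sketch(metrics):
  """Groups metrics into one {name: value} batch per (kind, kwargs) pair."""
  writes = collections.defaultdict(dict)
  for name, (kind, value, extra_kwargs) in metrics.items():
    # A frozenset of items makes the keyword arguments hashable, so they can
    # be part of the grouping key.
    writes[(kind, frozenset(extra_kwargs.items()))][name] = value
  return dict(writes)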


def create_default_writer(
    logdir: Optional[epath.PathLike] = None,
    *,
    just_logging: bool = False,
    asynchronous: bool = True,
    collection: Optional[str] = None,
) -> MultiWriter:
  """Creates the default writer for the platform.

  On most platforms this creates a MultiWriter that writes to multiple
  backends (logging, TF summaries, etc.).

  Args:
    logdir: Logging dir to use for TF summary files. If None, the returned
      writer will not write TF summary files.
    just_logging: If True only use a LoggingWriter. This is useful in multi-host
      setups when only the first host should write metrics and all other hosts
      should only write to their own logs.
    asynchronous: If True return an AsyncMultiWriter to not block when writing
      metrics.
    collection: If provided, all metrics are written to this collection (or
      grouping). It is passed to the LoggingWriter and, when `logdir` is set,
      appended to `logdir` as a subdirectory for the TF summary files.

  Returns:
    A `MetricWriter` according to the platform and arguments.
  """
  if just_logging:
    if asynchronous:
      return AsyncMultiWriter([LoggingWriter(collection=collection)])
    else:
      return MultiWriter([LoggingWriter(collection=collection)])
  writers = [LoggingWriter(collection=collection)]
  if logdir is not None:
    logdir = epath.Path(logdir)
    if collection is not None:
      logdir /= collection
    writers.append(SummaryWriter(os.fspath(logdir)))
  if asynchronous:
    return AsyncMultiWriter(writers)
  return MultiWriter(writers)
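

# Illustrative sketch (hypothetical helper, not part of the CLU API): mirrors
# the branching in `create_default_writer` without constructing any writers,
# returning the wrapper class name and the (backend, summary_dir) pairs it
# would create. `os.path.join` stands in for the `epath.Path` division used
# above; that equivalence is an assumption that holds for plain local paths.
def _default_writer_plan_sketch(logdir=None, *, just_logging=False,
                                asynchronous=True, collection=None):
  """Returns (wrapper_name, [(backend_name, summary_dir_or_None), ...])."""
  wrapper = "AsyncMultiWriter" if asynchronous else "MultiWriter"
  backends = [("LoggingWriter", None)]
  if not just_logging and logdir is not None:
    summary_dir = os.fspath(logdir)
    if collection is not None:
      # Each collection gets its own subdirectory of TF summary files.
      summary_dir = os.path.join(summary_dir, collection)
    backends.append(("SummaryWriter", summary_dir))
  return wrapper, backends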


================================================
FILE: clu/metric_writers/utils_test.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for metric_writers utils."""
# pylint: disable=g-importing-member

import itertools
from typing import Any
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from clu import values
from clu.metric_writers import utils
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.async_writer import AsyncWriter
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
import clu.metrics
import flax.struct
import jax.numpy as jnp
import tensorflow as tf


@flax.struct.dataclass
class HistogramMetric(clu.metrics.Metric):
  value: jnp.ndarray
  num_buckets: int

  def compute_value(self):
    return values.Histogram(self.value, self.num_buckets)


@flax.struct.dataclass
class ImageMetric(clu.metrics.Metric):
  value: jnp.ndarray

  def compute_value(self):
    return values.Image(self.value)


@flax.struct.dataclass
class AudioMetric(clu.metrics.Metric):
  value: jnp.ndarray
  sample_rate: int

  def compute_value(self):
    return values.Audio(self.value, self.sample_rate)


@flax.struct.dataclass
class TextMetric(clu.metrics.Metric):
  value: str

  def compute_value(self):
    return values.Text(self.value)


@flax.struct.dataclass
class HyperParamMetric(clu.metrics.Metric):
  value: float

  def compute_value(self):
    return values.HyperParam(self.value)


@flax.struct.dataclass
class SummaryMetric(clu.metrics.Metric):
  value: jnp.ndarray
  metadata: Any

  def compute_value(self):
    return values.Summary(self.value, self.metadata)


def _to_summary(metrics):
  return {k: v.value for k, v in metrics.items()}


def _to_list_of_dicts(d):
  return [{k: v} for k, v in d.items()]


class ONEOF(object):
  """ONEOF(options_list) check value in options_list."""

  def __init__(self, container):
    if not hasattr(container, "__contains__"):
      raise TypeError(f"{container!r} is not a container")
    if not container:
      raise ValueError(f"{container!r} is empty")
    self._c = container

  def __eq__(self, o):
    return o in self._c

  def __ne__(self, o):
    return o not in self._c

  def __repr__(self):
    return "<ONEOF({})>".format(",".join(repr(i) for i in self._c))


class MetricWriterTest(tf.test.TestCase, parameterized.TestCase):

  def test_write(self):
    writer = mock.Mock(spec_set=MetricWriter)
    step = 3
    num_buckets = 4
    sample_rate = 10
    scalar_metrics = {
        "loss": clu.metrics.Average.from_model_output(jnp.asarray([1, 2, 3])),
        "accuracy": clu.metrics.LastValue.from_model_output(jnp.asarray([5])),
    }
    image_metrics = {
        "image": ImageMetric(jnp.asarray([[4, 5], [1, 2]])),
    }
    histogram_metrics = {
        "hist": HistogramMetric(
            value=jnp.asarray([7, 8]), num_buckets=num_buckets
        ),
        "hist2": HistogramMetric(
            value=jnp.asarray([9, 10]), num_buckets=num_buckets
        ),
    }
    audio_metrics = {
        "audio": AudioMetric(
            value=jnp.asarray([1, 5]), sample_rate=sample_rate
        ),
        "audio2": AudioMetric(
            value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2
        ),
    }
    text_metrics = {
        "text": TextMetric(value="hello"),
    }
    hparam_metrics = {
        "lr": HyperParamMetric(value=0.01),
    }
    summary_metrics = {
        "summary": SummaryMetric(
            value=jnp.asarray([2, 3, 10]), metadata="some info"
        ),
        "summary2": SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5),
    }
    metrics = {
        **scalar_metrics,
        **image_metrics,
        **histogram_metrics,
        **audio_metrics,
        **text_metrics,
        **hparam_metrics,
        **summary_metrics,
    }
    metrics = {k: m.compute_value() for k, m in metrics.items()}
    utils.write_values(writer, step, metrics)

    writer.write_scalars.assert_called_once_with(
        step, {k: m.compute() for k, m in scalar_metrics.items()}
    )
    writer.write_images.assert_called_once_with(
        step, _to_summary(image_metrics)
    )
    writer.write_histograms.assert_called_once_with(
        step,
        _to_summary(histogram_metrics),
        num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()},
    )
    writer.write_audios.assert_called_with(
        step,
        ONEOF(_to_list_of_dicts(_to_summary(audio_metrics))),
        sample_rate=ONEOF([sample_rate, sample_rate + 2]),
    )
    writer.write_texts.assert_called_once_with(step, _to_summary(text_metrics))
    writer.write_hparams.assert_called_once_with(
        step, _to_summary(hparam_metrics)
    )
    writer.write_summaries.assert_called_with(
        step,
        ONEOF(_to_list_of_dicts(_to_summary(summary_metrics))),
        metadata=ONEOF(["some info", 5]),
    )

  def test_create_default_writer_summary_writer_is_added(self):
    writer = utils.create_default_writer(
        logdir=self.get_temp_dir(), asynchronous=False
    )
    self.assertTrue(any(isinstance(w, SummaryWriter) for w in writer._writers))
    writer = utils.create_default_writer(logdir=None, asynchronous=False)
    self.assertFalse(any(isinstance(w, SummaryWriter) for w in writer._writers))


if __name__ == "__main__":
  absltest.main()
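
`test_write` above combines `assert_called_with` with `ONEOF`. This pairing suggests that `write_values` may issue more than one `write_audios` (or `write_summaries`) call when the metrics carry different `sample_rate` (or `metadata`) values. `assert_called_with` only checks the most recent call, and `ONEOF` lets that check accept any one of the candidate arguments. The standalone sketch below reproduces the matcher (renamed `OneOf`) against a plain `unittest.mock.Mock`. The two calls are hypothetical and are written by hand rather than by `write_values`.

```python
from unittest import mock


class OneOf:
  """Compares equal to any value contained in `container` (mirrors ONEOF)."""

  def __init__(self, container):
    if not hasattr(container, "__contains__"):
      raise TypeError(f"{container!r} is not a container")
    if not container:
      raise ValueError(f"{container!r} is empty")
    self._c = container

  def __eq__(self, other):
    return other in self._c

  def __ne__(self, other):
    return other not in self._c

  def __repr__(self):
    return f"<OneOf({self._c!r})>"


writer = mock.Mock()
# Two calls, e.g. one per distinct sample rate.
writer.write_audios(3, {"audio": [1, 5]}, sample_rate=10)
writer.write_audios(3, {"audio2": [1, 5]}, sample_rate=12)

# Passes: only the last call is checked, and OneOf accepts either candidate.
writer.write_audios.assert_called_with(
    3,
    OneOf([{"audio": [1, 5]}, {"audio2": [1, 5]}]),
    sample_rate=OneOf([10, 12]),
)

# A candidate list that does not contain the last call's dict still fails.
mismatch_detected = False
try:
  writer.write_audios.assert_called_with(3, OneOf([{"x": 0}]), sample_rate=12)
except AssertionError:
  mismatch_detected = True
```

The matcher works because `mock` compares the expected arguments to the actual ones with `==`, which dispatches to `OneOf.__eq__`. This is the same mechanism that `mock.ANY` relies on.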

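`test_write` expects each scalar to equal `m.compute()` for metrics such as `clu.metrics.Average.from_model_output(jnp.asarray([1, 2, 3]))`. The accumulate-then-compute contract behind that check is documented on `Metric` in `clu/metrics.py`. The plain-Python sketch below walks through it:

- `from_model_output` turns one batch into intermediate values.
- `merge` adds those values together.
- `compute` derives the final number.

This `Average` is an illustrative stand-in that uses lists instead of `jnp.ndarray`. It is not CLU's class.

```python
import dataclasses
from typing import Sequence


@dataclasses.dataclass(frozen=True)
class Average:
  """Mean kept as (total, count) so partial results can be merged exactly."""
  total: float
  count: int

  @classmethod
  def from_model_output(cls, value: Sequence[float], **_) -> "Average":
    return cls(total=sum(value), count=len(value))

  def merge(self, other: "Average") -> "Average":
    return Average(self.total + other.total, self.count + other.count)

  def compute(self) -> float:
    return self.total / self.count


# Same accumulation loop as the `Metric` docstring, over hypothetical batches.
data = [[1, 2, 3], [4], [5, 6]]
average = None
for value in data:
  update = Average.from_model_output(value)
  average = update if average is None else average.merge(update)

# Averaging the per-batch means instead would over-weight the small batch [4].
mean_of_means = sum(sum(b) / len(b) for b in data) / len(data)
print(average.compute(), mean_of_means)
```

Keeping `total` and `count` separate, rather than a running mean, makes `merge` order-independent. It also makes the result identical to a single pass over all values.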

================================================
FILE: clu/metrics.py
================================================
# Copyright 2026 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functional metric computation library.

This library defines a functional metric computation interface `Metric` that
relies on metrics accumulating intermediate values (in a possibly distributed
manner), and then computes the final metric value from these intermediate
values. Note that most metrics can be created via `Average.from_fun()`. See also
`CollectingMetric` that collects all model outputs with a given name and lends
itself to metric computation in Python.

Some common metrics, such as accuracy and loss average/standard deviation, and a
`Collection` with the same interface, are also provided.

The "model output" is a dictionary of values with unique keys that all have a
specific meaning (such as `loss`, `logits`, and `labels`) and every metric
depends on at least one such model output by name. These outputs are usually
expected to be instances of `jnp.ndarray`.

Synopsis:

  # Note: Metrics do *not* work with `from __future__ import annotations`

  from clu import metrics
  import flax
  import jax

  @flax.struct.dataclass  # required for jax.tree_*
  class MyCollection(metrics.Collection):
    accuracy: metrics.Accuracy
    loss: metrics.Average.from_output("loss")
    loss_std: metrics.Std.from_output("loss")

  @jax.pmap
  def eval_step(variables, metrics, inputs, labels):
    loss, logits = get_loss_and_logits(variables, inputs, labels)
    return metrics.merge(MyCollection.gather_from_model_output(
        loss=loss, logits=logits, labels=labels))

  def evaluate(variables_p, test_ds):
    metrics = MyCollection.empty()
    for inputs, labels in test_ds:
      metrics = eval_step(variables_p, metrics, inputs, labels)
    return metrics.unreplicate().compute()
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
import inspect
from typing import Any, TypeVar, Protocol

from absl import logging

from clu.internal import utils
import clu.values
import flax
import jax
import jax.numpy as jnp
import numpy as np

Array = jax.Array
ArrayLike = jax.typing.ArrayLike


class FromFunCallable(Protocol):
  """The type of functions that can be passed to `Metric.from_fun()`."""

  def __call__(self, **kwargs: ArrayLike) -> Array | Mapping[str, Array]:
    """Returns the argument/arguments passed to the base from_model_output()."""


# TODO(b/200953513): Migrate away from logging imports (on module level)
#                    to logging the actual usage. See b/200953513.


def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray):
  """Raises a `ValueError` if shapes of `a` and `b` don't match."""
  if a.shape != b.shape:
    raise ValueError(f"Expected same shape: {a.shape} != {b.shape}")


M = TypeVar("M", bound="Metric")


class Metric:
  """Interface for computing metrics from intermediate values.

  Refer to `Collection` for computing multiple metrics at the same time.

  Synopsis:

    import jax.numpy as jnp
    import flax
    import numpy as np

    @flax.struct.dataclass
    class Average(Metric):
      total: jnp.ndarray
      count: jnp.ndarray

      @classmethod
      def from_model_output(cls, value: jnp.ndarray, **_) -> Metric:
        return cls(total=value.sum(), count=np.prod(value.shape))

      def merge(self, other: Metric) -> Metric:
        return type(self)(
          total=self.total + other.total,
          count=self.count + other.count,
        )

      def compute(self):
        return self.total / self.count

    average = None
    for value in data:
      update = Average.from_model_output(value)
      average = update if average is None else average.merge(update)
    print(average.compute())
  """

  @classmethod
  def from_model_output(cls: type[M], *args, **kwargs) -> M:
    """Creates a `Metric` from model outputs."""
    raise NotImplementedError("Must override from_model_output()")

  def merge(self: M, other: M) -> M:
    """Returns `Metric` that is the accumulation of `self` and `other`.

    Args:
      other: A `Metric` whose intermediate values should be accumulated onto the
        values of `self`. Note that in a distributed setting, `other` will
        typicall
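
The module docstring of `clu/metrics.py` describes a `Collection` whose fields each read named model outputs (`loss`, `logits`, `labels`) and are merged independently. The sketch below shows that idea without JAX. It uses hypothetical classes `LossStd`, `Accuracy`, and `MyCollection`, which are not CLU's. The standard deviation is kept as a running `(total, sum_of_squares, count)` so that batches can be merged before the final `compute()`. The `pmap`/`gather_from_model_output` distribution step is left out entirely.

```python
import dataclasses
import math


@dataclasses.dataclass(frozen=True)
class LossStd:
  """Std of the "loss" output, kept as (total, sum_of_squares, count)."""
  total: float
  sum_of_squares: float
  count: int

  @classmethod
  def from_model_output(cls, *, loss, **_):
    # Reads only the "loss" entry; other model outputs are ignored.
    return cls(sum(loss), sum(x * x for x in loss), len(loss))

  def merge(self, other):
    return LossStd(self.total + other.total,
                   self.sum_of_squares + other.sum_of_squares,
                   self.count + other.count)

  def compute(self):
    mean = self.total / self.count
    # Var[x] = E[x^2] - E[x]^2, clipped at 0 to absorb rounding error.
    return math.sqrt(max(self.sum_of_squares / self.count - mean * mean, 0.0))


@dataclasses.dataclass(frozen=True)
class Accuracy:
  """Fraction of examples whose argmax logit equals the label."""
  correct: int
  count: int

  @classmethod
  def from_model_output(cls, *, logits, labels, **_):
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return cls(sum(p == y for p, y in zip(preds, labels)), len(labels))

  def merge(self, other):
    return Accuracy(self.correct + other.correct, self.count + other.count)

  def compute(self):
    return self.correct / self.count


@dataclasses.dataclass(frozen=True)
class MyCollection:
  """Builds, merges and computes each field independently."""
  accuracy: Accuracy
  loss_std: LossStd

  @classmethod
  def from_model_output(cls, **model_output):
    return cls(accuracy=Accuracy.from_model_output(**model_output),
               loss_std=LossStd.from_model_output(**model_output))

  def merge(self, other):
    return MyCollection(self.accuracy.merge(other.accuracy),
                        self.loss_std.merge(other.loss_std))

  def compute(self):
    return {"accuracy": self.accuracy.compute(),
            "loss_std": self.loss_std.compute()}


# Hypothetical model outputs for two evaluation batches.
batches = [
    dict(loss=[1.0, 3.0], logits=[[0.1, 0.9], [0.8, 0.2]], labels=[1, 1]),
    dict(loss=[1.0, 3.0], logits=[[0.3, 0.7], [0.6, 0.4]], labels=[1, 0]),
]
metrics = None
for batch in batches:
  update = MyCollection.from_model_output(**batch)
  metrics = update if metrics is None else metrics.merge(update)
result = metrics.compute()
print(result)
```

Storing sums rather than a per-batch standard deviation is what makes `merge` exact. However, the `E[x^2] - E[x]^2` form can lose precision when the mean is large relative to the spread. That loss of precision is why the sketch clips the variance at zero.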
SYMBOL INDEX (535 symbols across 36 files)

FILE: clu/asynclib.py
  class AsyncError (line 27) | class AsyncError(Exception):
  class Pool (line 31) | class Pool:
    method __init__ (line 50) | def __init__(self, thread_name_prefix: str = "",
    method _reraise (line 69) | def _reraise(self) -> None:
    method close (line 76) | def close(self) -> None:
    method join (line 81) | def join(self) -> None:
    method queue_length (line 98) | def queue_length(self) -> int:
    method has_errors (line 103) | def has_errors(self) -> bool:
    method clear_errors (line 107) | def clear_errors(self) -> List[Exception]:
    method __call__ (line 113) | def __call__(self, fn: Callable):  # pylint: disable=g-bare-generic

FILE: clu/asynclib_test.py
  class AsyncWriterTest (line 23) | class AsyncWriterTest(absltest.TestCase):
    method test_async_execution (line 25) | def test_async_execution(self):
    method test_reraise (line 39) | def test_reraise(self):
    method test_queue_length (line 65) | def test_queue_length(self, executor_mock):
    method test_flush (line 95) | def test_flush(self, executor_mock):

FILE: clu/checkpoint.py
  function safe_normpath (line 77) | def safe_normpath(path: str) -> str:
  function load_state_dict (line 83) | def load_state_dict(base_directory) -> Dict[str, Any]:
  class CheckpointInfo (line 106) | class CheckpointInfo(
    method initialize (line 113) | def initialize(cls, base_directory, checkpoint_name: str) -> "Checkpoi...
    method from_path (line 118) | def from_path(cls, checkpoint: str) -> "CheckpointInfo":
    method increment (line 134) | def increment(self) -> "CheckpointInfo":
    method __str__ (line 138) | def __str__(self):
  class Checkpoint (line 143) | class Checkpoint:
    method __init__ (line 162) | def __init__(self,
    method get_latest_checkpoint_to_restore_from (line 194) | def get_latest_checkpoint_to_restore_from(self):
    method latest_checkpoint (line 207) | def latest_checkpoint(self) -> Optional[str]:
    method current_checkpoint (line 220) | def current_checkpoint(self) -> Optional[str]:
    method _flax_path (line 240) | def _flax_path(self, checkpoint: str) -> str:
    method _next_checkpoint (line 243) | def _next_checkpoint(self, checkpoint: Optional[str]) -> str:
    method _checkpoint_number (line 249) | def _checkpoint_number(self, checkpoint: Optional[str]) -> Optional[int]:
    method _delete_future_checkpoints (line 254) | def _delete_future_checkpoints(self):
    method save (line 273) | def save(self, state) -> str:
    method restore_or_initialize (line 330) | def restore_or_initialize(self, state: T) -> T:
    method restore_dict (line 350) | def restore_dict(self, checkpoint: Optional[str] = None) -> Dict[str, ...
    method _checkpoint_or_latest (line 371) | def _checkpoint_or_latest(self, checkpoint: Optional[str] = None) -> str:
    method load_state (line 378) | def load_state(self,
    method restore (line 410) | def restore(self,
  class MultihostCheckpoint (line 448) | class MultihostCheckpoint(Checkpoint):
    method __init__ (line 463) | def __init__(self,
    method get_latest_checkpoint_to_restore_from (line 503) | def get_latest_checkpoint_to_restore_from(self) -> Optional[str]:

FILE: clu/checkpoint_test.py
  function _make_dataset (line 26) | def _make_dataset():
  class TrainState (line 34) | class TrainState:
  class TrainStateExtended (line 39) | class TrainStateExtended:
  class NotTrainState (line 44) | class NotTrainState:
  function _checkpoint_number (line 48) | def _checkpoint_number(path):
  class CheckpointTest (line 54) | class CheckpointTest(tf.test.TestCase):
    method test_safe_normpath (line 56) | def test_safe_normpath(self):
    method test_initialize_mkdir (line 63) | def test_initialize_mkdir(self):
    method test_restores_flax_state (line 75) | def test_restores_flax_state(self):
    method test_load_state_dict (line 102) | def test_load_state_dict(self):
    method test_fails_when_restoring_subset (line 114) | def test_fails_when_restoring_subset(self):
    method test_fails_when_restoring_superset (line 125) | def test_fails_when_restoring_superset(self):
    method test_restores_tf_state (line 136) | def test_restores_tf_state(self):
    method test_restore_flax_alone (line 172) | def test_restore_flax_alone(self):
    method test_restore_dict (line 185) | def test_restore_dict(self):
    method test_ignores_incomplete_checkpoint (line 213) | def test_ignores_incomplete_checkpoint(self):
    method test_max_to_keep (line 249) | def test_max_to_keep(self):
    method test_checkpoint_name (line 262) | def test_checkpoint_name(self):
    method test_fails_if_not_registered (line 269) | def test_fails_if_not_registered(self):
    method test_overwrite (line 276) | def test_overwrite(self):
  class MultihostCheckpoint (line 312) | class MultihostCheckpoint(tf.test.TestCase):
    method test_initialize_mkdir (line 315) | def test_initialize_mkdir(self, process_index_mock):
    method test_synchronize_multiple_hosts (line 328) | def test_synchronize_multiple_hosts(self, process_index_mock):
    method test_preemption (line 360) | def test_preemption(self):

FILE: clu/data/dataset_iterator.py
  class ArraySpec (line 52) | class ArraySpec:
    method __repr__ (line 57) | def __repr__(self):
    method __str__ (line 60) | def __str__(self):
  class DatasetIterator (line 77) | class DatasetIterator(collections.abc.Iterator):  # pytype: disable=igno...
    method get_next (line 92) | def get_next(self) -> Element:
    method reset (line 99) | def reset(self):
    method element_spec (line 105) | def element_spec(self) -> ElementSpec:
    method save (line 109) | def save(self, filename: epath.Path):
    method restore (line 119) | def restore(self, filename: epath.Path):
    method load (line 129) | def load(self, filename: epath.Path):
  class TfDatasetIterator (line 134) | class TfDatasetIterator(DatasetIterator):
    method __init__ (line 137) | def __init__(self, dataset, *, checkpoint: bool):
    method get_next (line 174) | def get_next(self) -> Element:
    method __next__ (line 177) | def __next__(self) -> Element:
    method reset (line 180) | def reset(self):
    method element_spec (line 185) | def element_spec(self) -> ElementSpec:
    method save (line 202) | def save(self, filename: epath.Path):
    method restore (line 206) | def restore(self, filename: epath.Path):
  class PeekableDatasetIterator (line 211) | class PeekableDatasetIterator(DatasetIterator):
    method __init__ (line 230) | def __init__(self, it: DatasetIterator):
    method __next__ (line 238) | def __next__(self) -> Element:
    method reset (line 246) | def reset(self):
    method element_spec (line 254) | def element_spec(self) -> ElementSpec:
    method peek (line 257) | def peek(self) -> Element:
    method peek_async (line 270) | def peek_async(self) -> concurrent.futures.Future[Element]:
    method save (line 286) | def save(self, filename: epath.Path):
    method restore (line 290) | def restore(self, filename: epath.Path):

FILE: clu/data/dataset_iterator_test.py
  class DatasetIteratorTest (line 28) | class DatasetIteratorTest(parameterized.TestCase, tf.test.TestCase):
    method _create_iterator (line 30) | def _create_iterator(self, start_index: int, checkpoint: bool = True):
    method test_tf_iterator (line 40) | def test_tf_iterator(self):
    method test_tf_iterator_save_and_load (line 53) | def test_tf_iterator_save_and_load(self):
    method test_tf_iterator_save_and_load_no_checkpoint (line 70) | def test_tf_iterator_save_and_load_no_checkpoint(self):
    method test_peekable_dataset_iterator (line 84) | def test_peekable_dataset_iterator(self):
    method test_peekable_dataset_iterator_async (line 92) | def test_peekable_dataset_iterator_async(self, wait: bool, peek_first:...

FILE: clu/deterministic_data.py
  class DatasetBuilder (line 82) | class DatasetBuilder(typing_extensions.Protocol):
    method as_dataset (line 85) | def as_dataset(
  class RemainderOptions (line 92) | class RemainderOptions(enum.Enum):
  function _shard_read_instruction (line 111) | def _shard_read_instruction(
  function get_read_instruction_for_host (line 175) | def get_read_instruction_for_host(
  function _preprocess_with_per_example_rng (line 272) | def _preprocess_with_per_example_rng(ds: tf.data.Dataset,
  function pad_dataset (line 303) | def pad_dataset(dataset: tf.data.Dataset,
  function create_dataset (line 362) | def create_dataset(dataset_builder: DatasetBuilder,
  function create_distributed_dataset (line 484) | def create_distributed_dataset(

FILE: clu/deterministic_data_test.py
  class MyDatasetBuilder (line 35) | class MyDatasetBuilder:
    method as_dataset (line 39) | def as_dataset(self, split: tfds.core.ReadInstruction, shuffle_files: ...
  class FakeDatasetInfo (line 57) | class FakeDatasetInfo:
    method splits (line 62) | def splits(self):
  class DeterministicDataTest (line 69) | class DeterministicDataTest(tf.test.TestCase, parameterized.TestCase):
    method test_get_read_instruction_for_host_deprecated (line 86) | def test_get_read_instruction_for_host_deprecated(self, num_examples: ...
    method test_get_read_instruction_for_host (line 140) | def test_get_read_instruction_for_host(self, host_id: int, host_count:...
    method test_get_read_instruction_balance_remainder (line 168) | def test_get_read_instruction_balance_remainder(self, host_id: int,
    method test_get_read_instruction_for_host_fails (line 191) | def test_get_read_instruction_for_host_fails(self, host_id: int,
    method test_preprocess_with_per_example_rng (line 197) | def test_preprocess_with_per_example_rng(self):
    method test_create_dataset_padding (line 223) | def test_create_dataset_padding(self, pad_up_to_batches, cardinality):
    method test_create_dataset_padding_raises_error_cardinality (line 258) | def test_create_dataset_padding_raises_error_cardinality(self):
    method test_pad_dataset (line 278) | def test_pad_dataset(self):
    method test_pad_nested_dataset (line 292) | def test_pad_nested_dataset(self):
    method test_same_cardinality_on_all_hosts (line 309) | def test_same_cardinality_on_all_hosts(self, num_examples: int,
    method test_same_cardinality_on_all_hosts_with_pad (line 326) | def test_same_cardinality_on_all_hosts_with_pad(self, num_examples: int,

FILE: clu/internal/utils.py
  function log_activity (line 30) | def log_activity(activity_name: str):
  function logged_with (line 47) | def logged_with(activity_name: str):
  function check_param (line 57) | def check_param(value, *, ndim=None, dtype=jnp.float32):
  function flatten_dict (line 77) | def flatten_dict(

FILE: clu/internal/utils_test.py
  class TestError (line 23) | class TestError(BaseException):
  class HelpersTest (line 28) | class HelpersTest(absltest.TestCase):
    method test_log_activity (line 30) | def test_log_activity(
    method test_log_activity_fails (line 41) | def test_log_activity_fails(
    method test_logged_with (line 53) | def test_logged_with(self):
    method test_logged_with_fails (line 66) | def test_logged_with_fails(self):
    method test_check_param (line 80) | def test_check_param(self):
    method test_flatten_dict (line 91) | def test_flatten_dict(self):

FILE: clu/metric_writers/async_writer.py
  function _wrap_exceptions (line 39) | def _wrap_exceptions(wrapped, instance, args, kwargs):
  class AsyncWriter (line 51) | class AsyncWriter(interface.MetricWriter):
    method __init__ (line 64) | def __init__(self,
    method write_summaries (line 78) | def write_summaries(
    method write_scalars (line 86) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    method write_images (line 90) | def write_images(self, step: int, images: Mapping[str, Array]):
    method write_videos (line 94) | def write_videos(self, step: int, videos: Mapping[str, Array]):
    method write_audios (line 98) | def write_audios(
    method write_texts (line 104) | def write_texts(self, step: int, texts: Mapping[str, str]):
    method write_histograms (line 108) | def write_histograms(self,
    method write_pointcloud (line 116) | def write_pointcloud(
    method write_hparams (line 132) | def write_hparams(self, hparams: Mapping[str, Any]):
    method flush (line 135) | def flush(self):
    method close (line 141) | def close(self):
  class AsyncMultiWriter (line 148) | class AsyncMultiWriter(multi_writer.MultiWriter):
    method __init__ (line 151) | def __init__(self,
  function ensure_flushes (line 159) | def ensure_flushes(*writers: interface.MetricWriter):

FILE: clu/metric_writers/async_writer_test.py
  class AsyncWriterTest (line 27) | class AsyncWriterTest(tf.test.TestCase):
    method setUp (line 29) | def setUp(self):
    method test_write_summaries_async (line 34) | def test_write_summaries_async(self):
    method test_write_scalars_async (line 46) | def test_write_scalars_async(self):
    method test_write_images (line 61) | def test_write_images(self):
    method test_write_videos (line 68) | def test_write_videos(self):
    method test_write_pointcloud (line 75) | def test_write_pointcloud(self):
    method test_write_texts (line 96) | def test_write_texts(self):
    method test_ensure_flushes (line 101) | def test_ensure_flushes(self):
    method test_ensure_flushes_with_multiple_writers (line 117) | def test_ensure_flushes_with_multiple_writers(self):
    method test_flush_before_close (line 142) | def test_flush_before_close(self):
    method test_reraises_exception (line 147) | def test_reraises_exception(self):

FILE: clu/metric_writers/interface.py
  class MetricWriter (line 33) | class MetricWriter(abc.ABC):
    method write_summaries (line 37) | def write_summaries(
    method write_scalars (line 53) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    method write_images (line 65) | def write_images(self, step: int, images: Mapping[str, Array]):
    method write_videos (line 83) | def write_videos(self, step: int, videos: Mapping[str, Array]):
    method write_audios (line 104) | def write_audios(
    method write_texts (line 126) | def write_texts(self, step: int, texts: Mapping[str, str]):
    method write_histograms (line 137) | def write_histograms(self,
    method write_pointcloud (line 156) | def write_pointcloud(
    method write_hparams (line 177) | def write_hparams(self, hparams: Mapping[str, Any]):
    method flush (line 187) | def flush(self):
    method close (line 191) | def close(self):

FILE: clu/metric_writers/logging_writer.py
  class LoggingWriter (line 28) | class LoggingWriter(interface.MetricWriter):
    method __init__ (line 31) | def __init__(self, collection: Optional[str] = None):
    method write_summaries (line 37) | def write_summaries(
    method write_scalars (line 44) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    method write_images (line 51) | def write_images(self, step: int, images: Mapping[str, Array]):
    method write_videos (line 55) | def write_videos(self, step: int, videos: Mapping[str, Array]):
    method write_audios (line 59) | def write_audios(
    method write_texts (line 64) | def write_texts(self, step: int, texts: Mapping[str, str]):
    method write_histograms (line 67) | def write_histograms(self,
    method write_pointcloud (line 80) | def write_pointcloud(
    method write_hparams (line 101) | def write_hparams(self, hparams: Mapping[str, Any]):
    method flush (line 104) | def flush(self):
    method close (line 107) | def close(self):
  function _compute_histogram_as_tf (line 111) | def _compute_histogram_as_tf(
  function _get_histogram_as_string (line 152) | def _get_histogram_as_string(histo: np.ndarray, bins: np.ndarray):

FILE: clu/metric_writers/logging_writer_test.py
  class LoggingWriterTest (line 22) | class LoggingWriterTest(tf.test.TestCase):
    method setUp (line 24) | def setUp(self):
    method test_write_scalars (line 28) | def test_write_scalars(self):
    method test_write_images (line 36) | def test_write_images(self):
    method test_write_videos (line 44) | def test_write_videos(self):
    method test_write_texts (line 52) | def test_write_texts(self):
    method test_write_histogram (line 59) | def test_write_histogram(self):
    method test_write_pointcloud (line 83) | def test_write_pointcloud(self):
    method test_write_hparams (line 106) | def test_write_hparams(self):
    method test_collection (line 113) | def test_collection(self):

FILE: clu/metric_writers/multi_writer.py
  class MultiWriter (line 26) | class MultiWriter(interface.MetricWriter):
    method __init__ (line 29) | def __init__(self, writers: Sequence[interface.MetricWriter]):
    method write_summaries (line 32) | def write_summaries(
    method write_scalars (line 39) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    method write_images (line 43) | def write_images(self, step: int, images: Mapping[str, Array]):
    method write_videos (line 47) | def write_videos(self, step: int, videos: Mapping[str, Array]):
    method write_audios (line 51) | def write_audios(
    method write_texts (line 56) | def write_texts(self, step: int, texts: Mapping[str, str]):
    method write_histograms (line 60) | def write_histograms(self,
    method write_pointcloud (line 67) | def write_pointcloud(
    method write_hparams (line 80) | def write_hparams(self, hparams: Mapping[str, Any]):
    method flush (line 84) | def flush(self):
    method close (line 88) | def close(self):

FILE: clu/metric_writers/multi_writer_test.py
  class MultiWriterTest (line 25) | class MultiWriterTest(tf.test.TestCase):
    method setUp (line 27) | def setUp(self):
    method test_write_scalars (line 35) | def test_write_scalars(self):
    method test_write_pointcloud (line 52) | def test_write_pointcloud(self):

FILE: clu/metric_writers/tf/summary_writer.py
  class SummaryWriter (line 42) | class SummaryWriter(interface.MetricWriter):
    method __init__ (line 45) | def __init__(self, logdir: str):
    method write_summaries (line 50) | def write_summaries(
    method write_scalars (line 61) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    method write_images (line 66) | def write_images(self, step: int, images: Mapping[str, Array]):
    method write_videos (line 73) | def write_videos(self, step: int, videos: Mapping[str, Array]):
    method write_audios (line 78) | def write_audios(
    method write_texts (line 85) | def write_texts(self, step: int, texts: Mapping[str, str]):
    method write_histograms (line 90) | def write_histograms(
    method write_pointcloud (line 101) | def write_pointcloud(
    method write_hparams (line 121) | def write_hparams(self, hparams: Mapping[str, Any]):
    method flush (line 125) | def flush(self):
    method close (line 128) | def close(self):

FILE: clu/metric_writers/tf/summary_writer_test.py
  function _load_summaries_data (line 27) | def _load_summaries_data(logdir):
  function _load_histograms_data (line 41) | def _load_histograms_data(logdir):
  function _load_scalars_data (line 60) | def _load_scalars_data(logdir: str):
  function _load_pointcloud_data (line 72) | def _load_pointcloud_data(logdir: str):
  function _load_hparams (line 87) | def _load_hparams(logdir: str):
  class SummaryWriterTest (line 101) | class SummaryWriterTest(tf.test.TestCase):
    method setUp (line 103) | def setUp(self):
    method test_write_summaries (line 108) | def test_write_summaries(self):
    method test_write_scalar (line 121) | def test_write_scalar(self):
    method test_write_histograms (line 129) | def test_write_histograms(self):
    method test_write_pointcloud (line 160) | def test_write_pointcloud(self):
    method test_hparams (line 178) | def test_hparams(self):
    method test_hparams_nested (line 187) | def test_hparams_nested(self):

FILE: clu/metric_writers/torch_tensorboard_writer.py
  class TorchTensorboardWriter (line 32) | class TorchTensorboardWriter(interface.MetricWriter):
    method __init__ (line 35) | def __init__(self, logdir: str):
    method write_summaries (line 40) | def write_summaries(
    method write_scalars (line 48) | def write_scalars(self, step: int, scalars: Mapping[str, Scalar]):
    method write_images (line 52) | def write_images(self, step: int, images: Mapping[str, Array]):
    method write_videos (line 56) | def write_videos(self, step: int, videos: Mapping[str, Array]):
    method write_audios (line 61) | def write_audios(
    method write_texts (line 67) | def write_texts(self, step: int, texts: Mapping[str, str]):
    method write_histograms (line 72) | def write_histograms(self,
    method write_pointcloud (line 81) | def write_pointcloud(
    method write_hparams (line 95) | def write_hparams(self, hparams: Mapping[str, Any]):
    method flush (line 98) | def flush(self):
    method close (line 101) | def close(self):

FILE: clu/metric_writers/torch_tensorboard_writer_test.py
  function _load_scalars_data (line 26) | def _load_scalars_data(logdir: str):
  function _load_histograms_data (line 38) | def _load_histograms_data(logdir: str) -> Dict[int, Dict[str, Any]]:
  class TorchTensorboardWriterTest (line 62) | class TorchTensorboardWriterTest(tf.test.TestCase):
    method setUp (line 64) | def setUp(self):
    method test_write_scalar (line 69) | def test_write_scalar(self):
    method test_write_histograms (line 77) | def test_write_histograms(self):

FILE: clu/metric_writers/utils.py
  function _is_scalar (line 46) | def _is_scalar(value: Any) -> bool:
  function write_values (line 56) | def write_values(
  function create_default_writer (line 113) | def create_default_writer(

FILE: clu/metric_writers/utils_test.py
  class HistogramMetric (line 39) | class HistogramMetric(clu.metrics.Metric):
    method compute_value (line 43) | def compute_value(self):
  class ImageMetric (line 48) | class ImageMetric(clu.metrics.Metric):
    method compute_value (line 51) | def compute_value(self):
  class AudioMetric (line 56) | class AudioMetric(clu.metrics.Metric):
    method compute_value (line 60) | def compute_value(self):
  class TextMetric (line 65) | class TextMetric(clu.metrics.Metric):
    method compute_value (line 68) | def compute_value(self):
  class HyperParamMetric (line 73) | class HyperParamMetric(clu.metrics.Metric):
    method compute_value (line 76) | def compute_value(self):
  class SummaryMetric (line 81) | class SummaryMetric(clu.metrics.Metric):
    method compute_value (line 85) | def compute_value(self):
  function _to_summary (line 89) | def _to_summary(metrics):
  function _to_list_of_dicts (line 93) | def _to_list_of_dicts(d):
  class ONEOF (line 97) | class ONEOF(object):
    method __init__ (line 100) | def __init__(self, container):
    method __eq__ (line 107) | def __eq__(self, o):
    method __ne__ (line 110) | def __ne__(self, o):
    method __repr__ (line 113) | def __repr__(self):
  class MetricWriterTest (line 117) | class MetricWriterTest(tf.test.TestCase, parameterized.TestCase):
    method test_write (line 119) | def test_write(self):
    method test_create_default_writer_summary_writer_is_added (line 198) | def test_create_default_writer_summary_writer_is_added(self):

FILE: clu/metrics.py
  class FromFunCallable (line 76) | class FromFunCallable(Protocol):
    method __call__ (line 79) | def __call__(self, **kwargs: ArrayLike) -> Array | Mapping[str, Array]:
  function _assert_same_shape (line 88) | def _assert_same_shape(a: jnp.ndarray, b: jnp.ndarray):
  class Metric (line 97) | class Metric:
    method from_model_output (line 133) | def from_model_output(cls: type[M], *args, **kwargs) -> M:
    method merge (line 137) | def merge(self: M, other: M) -> M:
    method _reduce_merge (line 160) | def _reduce_merge(self: M, other: M) -> M:
    method compute (line 163) | def compute(self) -> jnp.ndarray:
    method empty (line 168) | def empty(cls: type[M]) -> M:
    method compute_value (line 172) | def compute_value(self) -> clu.values.Value:
    method reduce (line 176) | def reduce(self: M) -> M:
    method from_fun (line 235) | def from_fun(cls, fun: FromFunCallable):  # No way to annotate return ...
    method from_output (line 314) | def from_output(cls, name: str):  # No way to annotate return type
  class CollectingMetric (line 358) | class CollectingMetric(Metric):
    method empty (line 416) | def empty(cls) -> CollectingMetric:
    method merge (line 419) | def merge(self, other: CollectingMetric) -> CollectingMetric:
    method reduce (line 434) | def reduce(self) -> CollectingMetric:
    method compute (line 440) | def compute(self):  # No return type annotation, so subclasses can ove...
    method from_outputs (line 444) | def from_outputs(cls, names: Sequence[str]) -> type[CollectingMetric]:
  class _ReductionCounter (line 465) | class _ReductionCounter(Metric):
    method empty (line 471) | def empty(cls) -> _ReductionCounter:
    method merge (line 474) | def merge(self, other: _ReductionCounter) -> _ReductionCounter:
  function _check_reduction_counter_ndim (line 478) | def _check_reduction_counter_ndim(reduction_counter: _ReductionCounter):
  class Collection (line 490) | class Collection:
    method create (line 512) | def create(cls, **metrics: type[Metric]) -> type[Collection]:
    method create_collection (line 538) | def create_collection(cls, **metrics: Metric) -> Collection:
    method empty (line 566) | def empty(cls: type[C]) -> C:
    method _from_model_output (line 576) | def _from_model_output(cls: type[C], **kwargs) -> C:
    method single_from_model_output (line 587) | def single_from_model_output(cls: type[C], **kwargs) -> C:
    method gather_from_model_output (line 602) | def gather_from_model_output(cls: type[C], axis_name="batch", **kwargs...
    method merge (line 617) | def merge(self: C, other: C) -> C:
    method reduce (line 624) | def reduce(self: C) -> C:
    method compute (line 656) | def compute(self) -> dict[str, jnp.ndarray]:
    method compute_values (line 665) | def compute_values(self) -> dict[str, clu.values.Value]:
    method unreplicate (line 674) | def unreplicate(self: C) -> C:
  class LastValue (line 692) | class LastValue(Metric):
    method __init__ (line 707) | def __init__(  # pytype: disable=missing-parameter  # jnp-array
    method empty (line 743) | def empty(cls) -> LastValue:
    method from_model_output (line 747) | def from_model_output(
    method merge (line 757) | def merge(self, other: LastValue) -> LastValue:
    method _reduce_merge (line 761) | def _reduce_merge(self, other: LastValue) -> LastValue:
    method value (line 770) | def value(self) -> jnp.ndarray:
    method compute (line 775) | def compute(self) -> Any:
  function _broadcast_masks (line 779) | def _broadcast_masks(values: jnp.ndarray, mask: jnp.ndarray | None):
  class Average (line 801) | class Average(Metric):
    method empty (line 820) | def empty(cls) -> Average:
    method from_model_output (line 824) | def from_model_output(
    method merge (line 837) | def merge(self, other: Average) -> Average:
    method compute (line 844) | def compute(self) -> Any:
  class Std (line 849) | class Std(Metric):
    method empty (line 861) | def empty(cls) -> Std:
    method from_model_output (line 868) | def from_model_output(
    method merge (line 882) | def merge(self, other: Std) -> Std:
    method compute (line 890) | def compute(self) -> Any:
  class Accuracy (line 906) | class Accuracy(Average):
    method from_model_output (line 916) | def from_model_output(
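The `Metric` entries above follow a build/merge/compute pattern. `from_model_output` builds a metric from one batch, `merge` combines two partial metrics, and `compute` produces the final value. The sketch below imitates that pattern for a masked average in plain Python, for illustration only. `SimpleAverage` is a hypothetical class, not CLU's `Average`, which works on JAX arrays. The `total`/`count` fields are an assumption about how such a metric might store its state.

```python
import dataclasses
from typing import Optional, Sequence


@dataclasses.dataclass(frozen=True)
class SimpleAverage:
  """Hypothetical stand-in for an averaging metric (not CLU's implementation)."""

  total: float
  count: int

  @classmethod
  def empty(cls) -> "SimpleAverage":
    # Neutral element for merge: merging with it changes nothing.
    return cls(total=0.0, count=0)

  @classmethod
  def from_model_output(
      cls, values: Sequence[float], mask: Optional[Sequence[bool]] = None
  ) -> "SimpleAverage":
    # Keep only the values whose mask entry is True (all values if no mask).
    if mask is None:
      mask = [True] * len(values)
    kept = [v for v, m in zip(values, mask) if m]
    return cls(total=float(sum(kept)), count=len(kept))

  def merge(self, other: "SimpleAverage") -> "SimpleAverage":
    # Accumulate sufficient statistics rather than averaging the averages.
    return SimpleAverage(self.total + other.total, self.count + other.count)

  def compute(self) -> float:
    return self.total / self.count


m = SimpleAverage.empty()
for batch in ([1.0, 2.0], [3.0, 4.0, 5.0]):
  m = m.merge(SimpleAverage.from_model_output(batch))
print(m.compute())  # 3.0
```

Merging totals and counts, rather than per-batch means, keeps the result correct when batches differ in size.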

FILE: clu/metrics_test.py
  class CollectingMetricAccuracy (line 32) | class CollectingMetricAccuracy(
    method compute (line 35) | def compute(self):
  class Collection (line 45) | class Collection(metrics.Collection):
  class CollectionMixed (line 51) | class CollectionMixed(metrics.Collection):
  class MetricsTest (line 56) | class MetricsTest(parameterized.TestCase):
    method setUp (line 58) | def setUp(self):
    method make_compute_metric (line 117) | def make_compute_metric(self, metric_class, reduce, jit=True):
    method test_metric_last_value_reduce (line 151) | def test_metric_last_value_reduce(self):
    method test_metric_last_value (line 177) | def test_metric_last_value(self):
    method test_metric_last_value_legacy_kwarg_value (line 192) | def test_metric_last_value_legacy_kwarg_value(self):
    method test_metric_last_value_tree_manipulation (line 198) | def test_metric_last_value_tree_manipulation(self):
    method test_from_fun_with_single_output (line 213) | def test_from_fun_with_single_output(self):
    method test_from_fun_with_mapping_output (line 229) | def test_from_fun_with_mapping_output(self):
    method test_average_masked (line 262) | def test_average_masked(self, values, mask, expected_result):
    method test_merge_asserts_shape (line 284) | def test_merge_asserts_shape(self, metric_cls):
    method test_accuracy (line 296) | def test_accuracy(self, reduce):
    method test_last_value_asserts_shape (line 301) | def test_last_value_asserts_shape(self):
    method test_loss_average (line 313) | def test_loss_average(self, reduce):
    method test_loss_std (line 328) | def test_loss_std(self, reduce):
    method test_collection_create (line 341) | def test_collection_create(self):
    method test_collection_create_custom_mask (line 350) | def test_collection_create_custom_mask(self):
    method test_collection_create_collection (line 374) | def test_collection_create_collection(self):
    method test_collection_single (line 394) | def test_collection_single(self, masked):
    method test_collection_gather (line 418) | def test_collection_gather(self, masked, all_gather_mock):
    method test_collection_gather_pmap (line 442) | def test_collection_gather_pmap(self, masked):
    method test_collection_asserts_replication (line 455) | def test_collection_asserts_replication(self):
    method test_collecting_metric (line 466) | def test_collecting_metric(self):
    method test_collecting_metric_reduce (line 480) | def test_collecting_metric_reduce(self):
    method test_collecting_metric_async (line 486) | def test_collecting_metric_async(self):
    method test_collecting_metric_tracer (line 505) | def test_collecting_metric_tracer(self):
    method test_collection_mixed_async (line 512) | def test_collection_mixed_async(self):
    method test_metric_empty_types_doesnt_cause_retrace (line 530) | def test_metric_empty_types_doesnt_cause_retrace(self):
    method test_tensor_aggregation_metrics_with_masks (line 569) | def test_tensor_aggregation_metrics_with_masks(

FILE: clu/parameter_overview.py
  class _ParamRow (line 32) | class _ParamRow:
  class _ParamRowWithSharding (line 40) | class _ParamRowWithSharding(_ParamRow):
  class _ParamRowWithStats (line 45) | class _ParamRowWithStats(_ParamRow):
  class _ParamRowWithStatsAndSharding (line 51) | class _ParamRowWithStatsAndSharding(_ParamRowWithStats):
  function _mean_std_jit (line 56) | def _mean_std_jit(x):
  function _mean_std (line 60) | def _mean_std(x):
  function flatten_dict (line 66) | def flatten_dict(
  function _count_parameters (line 82) | def _count_parameters(params: _ParamsContainer) -> int:
  function _parameters_size (line 88) | def _parameters_size(params: _ParamsContainer) -> int:
  function count_parameters (line 98) | def count_parameters(params: _ParamsContainer) -> int:
  function _make_row (line 104) | def _make_row(name, value) -> _ParamRow:
  function _make_row_with_sharding (line 120) | def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
  function _make_row_with_stats (line 132) | def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
  function _make_row_with_stats_and_sharding (line 143) | def _make_row_with_stats_and_sharding(
  function _get_parameter_rows (line 154) | def _get_parameter_rows(
  function _default_table_value_formatter (line 209) | def _default_table_value_formatter(value):
  function make_table (line 221) | def make_table(
  function _get_parameter_overview (line 286) | def _get_parameter_overview(
  function get_parameter_overview (line 310) | def get_parameter_overview(
  function _log_parameter_overview (line 347) | def _log_parameter_overview(
  function log_parameter_overview (line 368) | def log_parameter_overview(
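`parameter_overview.py` lists helpers for flattening nested parameter dicts and counting parameters. The sketch below shows that idea under stated assumptions. Parameters are taken to be nested dicts of NumPy arrays, flattened with a `/` delimiter. `flatten_params` and `count_params` are hypothetical names, not CLU's `flatten_dict` or `count_parameters`, whose exact signatures may differ.

```python
from collections.abc import Mapping
from typing import Dict

import numpy as np


def flatten_params(tree: Mapping, prefix: str = "") -> Dict[str, np.ndarray]:
  """Flattens nested dicts into {"outer/inner": array} (hypothetical helper)."""
  flat = {}
  for key, value in tree.items():
    path = f"{prefix}/{key}" if prefix else str(key)
    if isinstance(value, Mapping):
      flat.update(flatten_params(value, path))  # Recurse into sub-modules.
    else:
      flat[path] = np.asarray(value)  # Leaf: a parameter array.
  return flat


def count_params(tree: Mapping) -> int:
  """Total number of scalar entries across all parameter arrays."""
  return sum(int(v.size) for v in flatten_params(tree).values())


params = {"dense": {"kernel": np.zeros((3, 4)), "bias": np.zeros(4)}}
for name, value in flatten_params(params).items():
  print(f"{name:15s} {str(value.shape):10s} {value.size}")
print("Total:", count_params(params))  # Total: 16
```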

FILE: clu/parameter_overview_test.py
  class CNN (line 72) | class CNN(nn.Module):
    method __call__ (line 75) | def __call__(self, x):
  class JaxParameterOverviewTest (line 79) | class JaxParameterOverviewTest(absltest.TestCase):
    method test_count_parameters_empty (line 81) | def test_count_parameters_empty(self):
    method test_count_parameters (line 84) | def test_count_parameters(self):
    method test_get_parameter_overview_empty (line 92) | def test_get_parameter_overview_empty(self):
    method test_get_parameter_overview (line 98) | def test_get_parameter_overview(self):
    method test_get_parameter_overview_shape_dtype_struct (line 126) | def test_get_parameter_overview_shape_dtype_struct(self):
    method test_printing_bool (line 134) | def test_printing_bool(self):

FILE: clu/periodic_actions.py
  function _squareit (line 44) | def _squareit(x):
  function _format_secs (line 49) | def _format_secs(secs: float):
  class PeriodicAction (line 65) | class PeriodicAction(abc.ABC):
    method __init__ (line 74) | def __init__(self,
    method _init_and_check (line 99) | def _init_and_check(self, step: int, t: float):
    method _should_trigger (line 112) | def _should_trigger(self, step: int, t: float) -> bool:
    method _after_apply (line 123) | def _after_apply(self, step: int, t: float):
    method __call__ (line 128) | def __call__(self, step: int, t: Optional[float] = None) -> bool:
    method _apply (line 150) | def _apply(self, step: int, t: float):
  class ReportProgress (line 154) | class ReportProgress(PeriodicAction):
    method __init__ (line 157) | def __init__(self,
    method set_persistent_notes (line 197) | def set_persistent_notes(self, message: str):
    method _should_trigger (line 201) | def _should_trigger(self, step: int, t: float) -> bool:
    method _apply (line 205) | def _apply(self, step: int, t: float):
    method timed (line 227) | def timed(self, name: str, wait_jax_async_dispatch: bool = True):
  class Profile (line 304) | class Profile(PeriodicAction):
    method __init__ (line 309) | def __init__(
    method _should_trigger (line 350) | def _should_trigger(self, step: int, t: float) -> bool:
    method _apply (line 364) | def _apply(self, step: int, t: float):
    method _start_session (line 368) | def _start_session(self):
    method _end_session (line 376) | def _end_session(self, url: Optional[str]):
  class ProfileAllHosts (line 385) | class ProfileAllHosts(PeriodicAction):
    method __init__ (line 390) | def __init__(self,
    method _should_trigger (line 419) | def _should_trigger(self, step: int, t: float) -> bool:
    method _apply (line 422) | def _apply(self, step: int, t: float):
    method _start_session (line 426) | def _start_session(self):
    method _end_session (line 435) | def _end_session(self, url: Optional[str], *, step: int):
  class PeriodicCallback (line 443) | class PeriodicCallback(PeriodicAction):
    method __init__ (line 446) | def __init__(self,
    method __call__ (line 476) | def __call__(self, step: int, t: Optional[float] = None, **kwargs) -> ...
    method get_last_callback_result (line 488) | def get_last_callback_result(self):
    method _apply (line 492) | def _apply(self, step, t, **kwargs):
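Judging from the index, `PeriodicAction` exposes `__call__(step, t=None)` and an overridable `_should_trigger`. The sketch below illustrates only the step-based part of such a trigger, firing on every N-th step or on explicitly listed steps. `EveryNSteps` is a hypothetical class and omits the time-based (`every_secs`-style) triggering that CLU's actions also support.

```python
from typing import Callable, Iterable, Optional


class EveryNSteps:
  """Hypothetical step-based trigger (not CLU's PeriodicAction)."""

  def __init__(self,
               every_steps: Optional[int] = None,
               on_steps: Iterable[int] = (),
               callback: Optional[Callable[[int], None]] = None):
    self._every_steps = every_steps
    self._on_steps = set(on_steps)
    self._callback = callback

  def _should_trigger(self, step: int) -> bool:
    # Explicitly listed steps always fire; otherwise fire on multiples.
    if step in self._on_steps:
      return True
    return bool(self._every_steps) and step % self._every_steps == 0

  def __call__(self, step: int) -> bool:
    """Runs the callback if this step triggers; returns whether it fired."""
    if not self._should_trigger(step):
      return False
    if self._callback is not None:
      self._callback(step)
    return True


fired = []
action = EveryNSteps(every_steps=3, on_steps=[1], callback=fired.append)
for step in range(1, 8):
  action(step)
print(fired)  # [1, 3, 6]
```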

FILE: clu/periodic_actions_test.py
  class ReportProgressTest (line 26) | class ReportProgressTest(parameterized.TestCase):
    method test_every_steps (line 28) | def test_every_steps(self):
    method test_every_secs (line 50) | def test_every_secs(self):
    method test_without_num_train_steps (line 72) | def test_without_num_train_steps(self):
    method test_with_persistent_notes (line 83) | def test_with_persistent_notes(self):
    method test_unknown_cardinality (line 96) | def test_unknown_cardinality(self):
    method test_called_every_step (line 107) | def test_called_every_step(self):
    method test_named (line 121) | def test_named(self, wait_jax_async_dispatch, mock_time):
    method test_write_metrics (line 156) | def test_write_metrics(self, time_mock):
  class DummyProfilerSession (line 175) | class DummyProfilerSession:
    method __init__ (line 178) | def __init__(self):
    method start_session (line 183) | def start_session(self):
    method end_session_and_get_url (line 186) | def end_session_and_get_url(self, tag):
  class ProfileTest (line 191) | class ProfileTest(absltest.TestCase):
    method test_every_steps (line 195) | def test_every_steps(self, mock_time, mock_profiler):
  class ProfileAllHostsTest (line 224) | class ProfileAllHostsTest(absltest.TestCase):
    method test_every_steps (line 227) | def test_every_steps(self, mock_profiler):
  class PeriodicCallbackTest (line 247) | class PeriodicCallbackTest(absltest.TestCase):
    method test_every_steps (line 249) | def test_every_steps(self):
    method test_every_secs (line 267) | def test_every_secs(self, mock_time):
    method test_on_steps (line 281) | def test_on_steps(self):
    method test_async_execution (line 290) | def test_async_execution(self):
    method test_error_async_is_forwarded (line 309) | def test_error_async_is_forwarded(self):
    method test_function_without_step_and_time (line 325) | def test_function_without_step_and_time(self):

FILE: clu/platform/__init__.py
  function work_unit (line 35) | def work_unit() -> WorkUnit:

FILE: clu/platform/interface.py
  class ArtifactType (line 22) | class ArtifactType(enum.Enum):
  class WorkUnit (line 31) | class WorkUnit(abc.ABC):
    method experiment_id (line 42) | def experiment_id(self):
    method id (line 47) | def id(self):
    method name (line 51) | def name(self):
    method set_notes (line 63) | def set_notes(self, msg: str):
    method set_task_status (line 67) | def set_task_status(self, msg: str):
    method create_artifact (line 71) | def create_artifact(self, artifact_type: ArtifactType, artifact: Any,

FILE: clu/platform/local.py
  class LocalWorkUnit (line 26) | class LocalWorkUnit(WorkUnit):
    method experiment_id (line 30) | def experiment_id(self):
    method id (line 35) | def id(self):
    method set_notes (line 39) | def set_notes(self, msg: str):
    method set_task_status (line 43) | def set_task_status(self, msg: str):
    method create_artifact (line 47) | def create_artifact(self, artifact_type: ArtifactType, artifact: Any,

FILE: clu/preprocess_spec.py
  class PreprocessOp (line 77) | class PreprocessOp(Protocol):
    method __call__ (line 90) | def __call__(self, features: Features) -> Features:
  class MapTransform (line 95) | class MapTransform(abc.ABC):
    method __new__ (line 109) | def __new__(cls, *args, **kwargs):
    method __call__ (line 121) | def __call__(self, features: D) -> D:
    method _transform (line 130) | def _transform(self, features: FlatFeatures) -> FlatFeatures:
  class RandomMapTransform (line 135) | class RandomMapTransform(MapTransform, abc.ABC):
    method __call__ (line 149) | def __call__(self, features: D) -> D:
    method _transform (line 162) | def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatF...
  class FilterTransform (line 167) | class FilterTransform(abc.ABC):
    method __call__ (line 169) | def __call__(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    method _predicate (line 175) | def _predicate(self, features: FlatFeatures) -> tf.Tensor:
  function get_all_ops (line 179) | def get_all_ops(module_name: str) -> List[Tuple[str, Type[PreprocessOp]]]:
  function _jax_supported_tf_types (line 204) | def _jax_supported_tf_types():
  class OnlyJaxTypes (line 214) | class OnlyJaxTypes:
    method __call__ (line 228) | def __call__(self, features: Features) -> Features:
  class PreprocessFn (line 252) | class PreprocessFn:
    method __call__ (line 265) | def __call__(self, features: Features) -> Features:
    method __add__ (line 280) | def __add__(self, other: "PreprocessFn") -> "PreprocessFn":
    method __getitem__ (line 289) | def __getitem__(self, op_index: Union[int, slice]) -> "PreprocessFn":
  function _get_op_class (line 298) | def _get_op_class(
  function _parse_single_preprocess_op (line 315) | def _parse_single_preprocess_op(
  function parse (line 359) | def parse(spec: str,
  function _describe_features (line 374) | def _describe_features(features: Features) -> str:
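`preprocess_spec.py` parses a textual spec into a chain of preprocessing ops. The test names above cover positional, keyword, and multi-op specs. The sketch below shows one way to parse such strings with Python's `ast` module. It assumes a `|`-separated spec of call-like expressions such as `"to_float|rescale(2, offset=1)"`. `parse_spec` is hypothetical and returns `(name, args, kwargs)` tuples. Presumably CLU then resolves each name to an op class (compare `get_all_ops` and `_get_op_class`), but that step is not modeled here.

```python
import ast
from typing import Any, Dict, List, Tuple

ParsedOp = Tuple[str, List[Any], Dict[str, Any]]


def parse_spec(spec: str) -> List[ParsedOp]:
  """Parses e.g. "to_float|rescale(2, offset=1)" (hypothetical format)."""
  ops = []
  # Naive split: string arguments containing "|" are not supported.
  for op_str in filter(None, (s.strip() for s in spec.split("|"))):
    node = ast.parse(op_str, mode="eval").body
    if isinstance(node, ast.Name):  # Bare op name without arguments.
      ops.append((node.id, [], {}))
    elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
      # Arguments must be Python literals (numbers, strings, tuples, ...).
      args = [ast.literal_eval(a) for a in node.args]
      kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
      ops.append((node.func.id, args, kwargs))
    else:
      raise ValueError(f"Invalid op: {op_str!r}")
  return ops


print(parse_spec("to_float|rescale(2, offset=1)"))
```

Using `ast.literal_eval` for arguments keeps the parser from executing arbitrary code in a spec string.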

FILE: clu/preprocess_spec_test.py
  class ToFloat (line 27) | class ToFloat:
    method __call__ (line 29) | def __call__(self, features: Features) -> Features:
  class Rescale (line 34) | class Rescale:
    method __call__ (line 38) | def __call__(self, features: Features) -> Features:
  class AddRandomInteger (line 45) | class AddRandomInteger(preprocess_spec.RandomMapTransform):
    method _transform (line 47) | def _transform(self, features, seed):
  class PreprocessSpecTest (line 55) | class PreprocessSpecTest(parameterized.TestCase, tf.test.TestCase):
    method test_no_arguments (line 58) | def test_no_arguments(self):
    method test_positional_argument (line 63) | def test_positional_argument(self):
    method test_keyword_argument (line 69) | def test_keyword_argument(self):
    method test_invalid_op_name (line 75) | def test_invalid_op_name(self):
    method test_invalid_spec (line 82) | def test_invalid_spec(self):
    method test_pos_and_kw_arg (line 87) | def test_pos_and_kw_arg(self):
    method test_parsing_empty_string (line 95) | def test_parsing_empty_string(self):
    method test_multi_op_spec (line 100) | def test_multi_op_spec(self):
    method test_two_tensors (line 105) | def test_two_tensors(self):
    method test_only_jax_types (line 114) | def test_only_jax_types(self):
    method test_only_jax_types_nested_inputs (line 128) | def test_only_jax_types_nested_inputs(self):
    method test_not_only_jax_types (line 139) | def test_not_only_jax_types(self):
    method test_add_preprocess_fn (line 145) | def test_add_preprocess_fn(self):
    method test_slice_preprocess_fn (line 156) | def test_slice_preprocess_fn(self):
    method test_random_map_transform (line 166) | def test_random_map_transform(self):

FILE: clu/profiler.py
  function start (line 29) | def start(logdir: str, options=None):
  function stop (line 41) | def stop() -> Optional[str]:
  function collect (line 49) | def collect(logdir: str,

FILE: clu/values.py
  class Value (line 31) | class Value(Protocol):
  class Summary (line 41) | class Summary(Value):
  class Scalar (line 47) | class Scalar(Value):
  class Image (line 52) | class Image(Value):
  class Audio (line 65) | class Audio(Value):
  class Text (line 81) | class Text(Value):
  class Histogram (line 86) | class Histogram(Value):
  class HyperParam (line 93) | class HyperParam(Value):
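`values.py` defines typed value wrappers (`Scalar`, `Text`, `Histogram`, ...), and `metric_writers/utils.py` has a `write_values` function that feeds them to a writer. The sketch below illustrates that dispatch idea: values are grouped by wrapper type so each writer method is called once per step. All names here are hypothetical stand-ins, not CLU's classes or `write_values`. It also assumes each wrapper exposes a `.value` attribute.

```python
import collections
import dataclasses
from typing import Any, Dict, List, Mapping, Tuple


@dataclasses.dataclass
class Scalar:  # Hypothetical stand-in for a typed value wrapper.
  value: float


@dataclasses.dataclass
class Text:  # Hypothetical stand-in for a typed value wrapper.
  value: str


class RecordingWriter:
  """Hypothetical writer that records calls instead of writing event files."""

  def __init__(self):
    self.calls: List[Tuple[str, int, Dict[str, Any]]] = []

  def write_scalars(self, step: int, scalars: Mapping[str, float]):
    self.calls.append(("scalars", step, dict(scalars)))

  def write_texts(self, step: int, texts: Mapping[str, str]):
    self.calls.append(("texts", step, dict(texts)))


def dispatch_values(writer: RecordingWriter, step: int,
                    values: Mapping[str, Any]) -> None:
  """Groups typed values so each writer method is called once per step."""
  handlers = {Scalar: writer.write_scalars, Text: writer.write_texts}
  grouped = collections.defaultdict(dict)
  for name, wrapped in values.items():
    if type(wrapped) not in handlers:
      raise TypeError(f"Unsupported value type for {name!r}: {type(wrapped)}")
    grouped[type(wrapped)][name] = wrapped.value
  for value_type, batch in grouped.items():
    handlers[value_type](step, batch)


writer = RecordingWriter()
dispatch_values(writer, 5, {"loss": Scalar(0.5), "note": Text("hi")})
print(writer.calls)
```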
Condensed preview of all 52 files, each showing the path, character count, and a content snippet:
[
  {
    "path": ".github/workflows/build.yml",
    "chars": 1049,
    "preview": "# This workflow will install Python dependencies, run tests and lint.\n# For more information see: https://help.github.co"
  },
  {
    "path": ".github/workflows/python-publish.yml",
    "chars": 846,
    "preview": "# This workflows will upload a Python Package using Twine when a release is created\n# For more information see: https://"
  },
  {
    "path": "AUTHORS",
    "chars": 307,
    "preview": "# This is the list of Common Loop Utils significant contributors.\n#\n# This does not necessarily list everyone who has co"
  },
  {
    "path": "CHANGELOG.md",
    "chars": 3390,
    "preview": "# Changelog\n\n## v0.0.1-alpha.1\n\nInitial PyPi Release\n\nCurrent list of modules:\n\n-   `clu.checkpoint`\n-   `clu.determinis"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 923,
    "preview": "# How to Contribute\n\nAt this time we are focused on supporting research done by Google Research and\nare not accepting pa"
  },
  {
    "path": "LICENSE",
    "chars": 11358,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "README.md",
    "chars": 761,
    "preview": "# CLU - Common Loop Utils\n\nThis repository contains common functionality for writing ML training loops. The\ngoal is to m"
  },
  {
    "path": "clu/__init__.py",
    "chars": 581,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/asynclib.py",
    "chars": 5155,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/asynclib_test.py",
    "chars": 3341,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/checkpoint.py",
    "chars": 21322,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/checkpoint_test.py",
    "chars": 14741,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/data/__init__.py",
    "chars": 892,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/data/dataset_iterator.py",
    "chars": 9756,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/data/dataset_iterator_test.py",
    "chars": 4146,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/deterministic_data.py",
    "chars": 24040,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/deterministic_data_test.py",
    "chars": 13529,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/internal/__init__.py",
    "chars": 581,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/internal/utils.py",
    "chars": 3249,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/internal/utils_test.py",
    "chars": 3391,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/__init__.py",
    "chars": 2411,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/async_writer.py",
    "chars": 5496,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/async_writer_test.py",
    "chars": 5076,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/interface.py",
    "chars": 7012,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/logging_writer.py",
    "chars": 5536,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/logging_writer_test.py",
    "chars": 5045,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/multi_writer.py",
    "chars": 2814,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/multi_writer_test.py",
    "chars": 2289,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/summary_writer.py",
    "chars": 711,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/tf/__init__.py",
    "chars": 610,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/tf/summary_writer.py",
    "chars": 4285,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/tf/summary_writer_test.py",
    "chars": 7805,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/torch_tensorboard_writer.py",
    "chars": 3285,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/torch_tensorboard_writer_test.py",
    "chars": 3341,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/utils.py",
    "chars": 5308,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metric_writers/utils_test.py",
    "chars": 5967,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metrics.py",
    "chars": 32656,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/metrics_test.py",
    "chars": 20817,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/parameter_overview.py",
    "chars": 12445,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/parameter_overview_test.py",
    "chars": 6122,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/periodic_actions.py",
    "chars": 18561,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/periodic_actions_test.py",
    "chars": 9895,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/platform/__init__.py",
    "chars": 1344,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/platform/interface.py",
    "chars": 2075,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/platform/local.py",
    "chars": 1634,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/preprocess_spec.py",
    "chars": 14204,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/preprocess_spec_test.py",
    "chars": 6550,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/profiler.py",
    "chars": 1826,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu/run_pytest.google.sh",
    "chars": 373,
    "preview": "#!/bin/bash\n\nset -e -x\n\nCLU_DST=\"${CLU_DST:-/tmp/clu}\"\nCLU_ENV=\"${CLU_ENV:-/tmp/clu_env}\"\n\ncopybara third_party/py/clu/c"
  },
  {
    "path": "clu/values.py",
    "chars": 2772,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  },
  {
    "path": "clu_synopsis.ipynb",
    "chars": 882835,
    "preview": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"4CldxEhqQac_\"\n      },\n      \"sou"
  },
  {
    "path": "setup.py",
    "chars": 2052,
    "preview": "# Copyright 2026 The CLU Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use "
  }
]

About this extraction

This page contains the full source code of the google/CommonLoopUtils GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 52 files (1.2 MB), approximately 657.2k tokens, and a symbol index with 535 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.