[
  {
    "path": ".clang-format",
    "content": "BasedOnStyle: Google"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n*_pb2.py\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n# protobuf stuff\ntf/proto/\n**/*_pb2.py\n**/*_pb2.pyi\n\n# Meson\nbuilddir/\nsubprojects/\n\n# LLM Agents\n# Use AGENTS.md; symlink other files to it if needed.\n.claude/\nCLAUDE.local.md\n.claude.json\nCLAUDE.md\nGEMINI.md\n\n# IDE\n.vscode/\n.idea/"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"libs/lc0\"]\n\tpath = libs/lc0\n\turl = https://github.com/LeelaChessZero/lc0.git\n"
  },
  {
    "path": "AGENTS.md",
    "content": "# AGENTS.md\n\nThis repository contains training script for the Leela Chess Zero project.\nThey are being rewritten.\n\n* Old code is located in the `tf/` directory.\n* New python code is located in the `src/` directory.\n* New C++ code is located in the `csrc/` directory.\n\nThe old code is Python/TensorFlow-based, new code is Python/JAX-based with\nmodules written in C++, operating through pybind11.\n\nThe build system for C++ code is meson. During development, the project is built\nin the `builddir/`.\n\n## Testing and Building\n\n* C++ tests use GTest framework\n  * Do not insert Sleeps in tests, it slows down presubmit. Instead use e.g.\n    absl::Notification, or std::future\n* Tests are defined in `meson.build` with `test()` function\n  * When debugging, don't forget to build them before running `meson test` or\n    `builddir/test`\n* Run tests: `meson test -C builddir/`\n* Python tests use `pytest` framework\n  * Do not add custom main function, exception catching to report errors, any\n    \"test passed\" messages etc. Use `pytest` fixtures and assertions.\n* Build: `meson compile -C builddir/` from build directory\n* Format code: `just format`\n* There is a commit hook that runs `just pre-commit`, which runs tests and\n  checks formatting. You may want to run it before attempting to commit.\n* We use Google C++ style guide.\n  * That means 80 columns.\n  * That means comments should be in full sentences with periods in the end.\n  * When conditional or loop fits one line, it must be written as one line\n    without braces, for example:\n      `if (condition) return value;`\n  * Prefer `absl` to `std` (e.g. `absl::c_` algorithms, `absl::Mutex`,\n    `absl::StrCat`, etc.)\n* We use `uv` for Python package and venv management, and to running the\n  application.\n* Run TUI app: `uv run tui --config=<path_to_config>`\n\n* Do not attempt to run TUI — it messes up the Agent interface and session has\n  to be killed. 
Ask me to check it for you manually instead.\n\n* Do not commit unless explicitly asked.\n\n## IMPORTANT\n\n* NEVER add `# type: ignore` or other ways to mask/silence errors instead of\n  fixing them.\n* Rely on protobuf default values. DO NOT write code like\n  `config.has_foo() ? config.foo() : default_value;`\n\n## Documentation\n\n* Documentation is in the `docs/` directory.\n* The contents is in [The index](docs/index.md)\n"
  },
  {
    "path": "README.md",
    "content": "# Training\n\nThe training pipeline resides in `tf`, this requires tensorflow running on linux (Ubuntu 16.04 in this case). (It can be made to work on windows too, but it takes more effort.)\n\n## Installation\n\nInstall the requirements under `tf/requirements.txt`. And call `./init.sh` to compile the protobuf files.\n\n## Data preparation\n\nIn order to start a training session you first need to download training data from https://storage.lczero.org/files/training_data/. Several chunks/games are packed into a tar file, and each tar file contains an hour worth of chunks. Preparing data requires the following steps:\n\n```\nwget https://storage.lczero.org/files/training_data/training-run1--20200711-2017.tar\ntar -xzf training-run1--20200711-2017.tar\n```\n\n## Training pipeline\n\nNow that the data is in the right format one can configure a training pipeline. This configuration is achieved through a yaml file, see `training/tf/configs/example.yaml`:\n\n```yaml\n%YAML 1.2\n---\nname: 'kb1-64x6'                       # ideally no spaces\ngpu: 0                                 # gpu id to process on\n\ndataset:\n  num_chunks: 100000                   # newest nof chunks to parse\n  train_ratio: 0.90                    # trainingset ratio\n  # For separated test and train data.\n  input_train: '/path/to/chunks/*/draw/' # supports glob\n  input_test: '/path/to/chunks/*/draw/'  # supports glob\n  # For a one-shot run with all data in one directory.\n  # input: '/path/to/chunks/*/draw/'\n\ntraining:\n    batch_size: 2048                   # training batch\n    total_steps: 140000                # terminate after these steps\n    test_steps: 2000                   # eval test set values after this many steps\n    # checkpoint_steps: 10000          # optional frequency for checkpointing before finish\n    shuffle_size: 524288               # size of the shuffle buffer\n    lr_values:                         # list of learning rates\n        - 0.02\n        - 
0.002\n        - 0.0005\n    lr_boundaries:                     # list of boundaries\n        - 100000\n        - 130000\n    policy_loss_weight: 1.0            # weight of policy loss\n    value_loss_weight: 1.0             # weight of value loss\n    path: '/path/to/store/networks'    # network storage dir\n\nmodel:\n  filters: 64\n  residual_blocks: 6\n...\n```\n\nThe configuration is pretty self explanatory, if you're new to training I suggest looking at the [machine learning glossary](https://developers.google.com/machine-learning/glossary/) by google. Now you can invoke training with the following command:\n\n```bash\n./train.py --cfg configs/example.yaml --output /tmp/mymodel.txt\n```\n\nThis will initialize the pipeline and start training a new neural network. You can view progress by invoking tensorboard:\n\n```bash\ntensorboard --logdir leelalogs\n```\n\nIf you now point your browser at localhost:6006 you'll see the trainingprogress as the trainingsteps pass by. Have fun!\n\n## Restoring models\n\nThe training pipeline will automatically restore from a previous model if it exists in your `training:path` as configured by your yaml config. For initializing from a raw `weights.txt` file you can use `training/tf/net_to_model.py`, this will create a checkpoint for you.\n\n## Supervised training\n\nGenerating trainingdata from pgn files is currently broken and has low priority, feel free to create a PR.\n\n## Building 2025-08 version.\n\n1. Make sure `uv` and `justfile` are installed (plus `meson` and other stuff, potentially `protoc`).\n2. `git submodule update`\n3. `uv venv` (!important! do this before running meson; otherwise meson will build module for wrong python)\n4. `uv sync`\n5. `CXX=clang++ CC=clang uv run meson setup build/release/ --buildtype=release --native-file=native.ini` (clang is optional, should build fine with default compiler)\n6. `just build-proto`\n7. `meson compile -C build/release/`\n8. 
`ln -s -T ../../build/release/_lczero_training.cpython-311-x86_64-linux-gnu.so src/lczero_training/_lczero_training.so`\n\n14. Run it! `uv run tui --config docs/example.textproto`\n"
  },
  {
    "path": "csrc/loader/chunk_source/chunk_source.h",
    "content": "#pragma once\n\n#include <cstddef>\n#include <optional>\n#include <string>\n#include <vector>\n\n#include \"loader/frame_type.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Interface for providing training data chunks.\n// A chunk source provides access to one or more chunks of training data.\n// It's assumed that all chunks in a source for one group for sorting purposes,\n// therefore GetChunkSortKey() returns just one key for the entire source. This\n// allows to know the key before reading/indexing the chunks.\nclass ChunkSource {\n public:\n  virtual ~ChunkSource() = default;\n\n  // Returns a sort key (e.g. filename or a timestamp).\n  virtual std::string GetChunkSortKey() const = 0;\n\n  // Returns the number of chunks in this source.\n  virtual size_t GetChunkCount() const = 0;\n\n  // Returns the data for the chunk at the given index. Returns std::nullopt if\n  // the chunk could not be read or if the data size is not a multiple of the\n  // expected frame size.\n  virtual std::optional<std::vector<FrameType>> GetChunkData(size_t index) = 0;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/chunk_source/chunk_source_view.h",
    "content": "#pragma once\n\n#include <cstddef>\n#include <cstdint>\n#include <memory>\n#include <optional>\n#include <string>\n#include <vector>\n\n#include \"loader/chunk_source/chunk_source.h\"\n\nnamespace lczero {\nnamespace training {\n\n// ChunkSourceView provides a view over another ChunkSource.\n// It exposes a remapped subset/order of chunks defined by indices into the\n// underlying source. It does not own or copy the data; it forwards calls to\n// the wrapped source.\nclass ChunkSourceView : public ChunkSource {\n public:\n  // Constructs a view over an existing chunk source. The indices vector maps\n  // local indices in the view to indices of the underlying source.\n  ChunkSourceView(std::shared_ptr<ChunkSource> source,\n                  std::vector<uint32_t> indices)\n      : source_(std::move(source)), indices_(std::move(indices)) {}\n\n  ~ChunkSourceView() override = default;\n\n private:\n  std::string GetChunkSortKey() const override {\n    return source_->GetChunkSortKey();\n  }\n\n  size_t GetChunkCount() const override { return indices_.size(); }\n\n  std::optional<std::vector<FrameType>> GetChunkData(size_t index) override {\n    if (index >= indices_.size()) return std::nullopt;\n    const size_t src_index = static_cast<size_t>(indices_[index]);\n    return source_->GetChunkData(src_index);\n  }\n\n  std::shared_ptr<ChunkSource> source_;\n  std::vector<uint32_t> indices_;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
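  {
    "path": "csrc/loader/chunk_source/chunk_source_view_test.cc",
    "content": "// Illustrative test sketch for ChunkSourceView, not part of the original\n// tree. It assumes the GTest setup described in AGENTS.md and that FrameType\n// exposes a planes array as used by DebugChunkSource; FakeChunkSource is a\n// hypothetical in-memory helper defined here only for this example.\n#include \"loader/chunk_source/chunk_source_view.h\"\n\n#include <gtest/gtest.h>\n\n#include <memory>\n#include <vector>\n\nnamespace lczero {\nnamespace training {\nnamespace {\n\n// Minimal in-memory ChunkSource whose single-frame chunks encode their index.\nclass FakeChunkSource : public ChunkSource {\n public:\n  explicit FakeChunkSource(size_t count) : count_(count) {}\n\n private:\n  std::string GetChunkSortKey() const override { return \"fake\"; }\n  size_t GetChunkCount() const override { return count_; }\n  std::optional<std::vector<FrameType>> GetChunkData(size_t index) override {\n    if (index >= count_) return std::nullopt;\n    std::vector<FrameType> result(1);\n    result[0].planes[0] = index;\n    return result;\n  }\n\n  size_t count_;\n};\n\nTEST(ChunkSourceViewTest, RemapsIndicesAndForwardsData) {\n  auto source = std::make_shared<FakeChunkSource>(5);\n  // The view exposes chunks 4 and 1 of the underlying source, in that order.\n  ChunkSourceView view(source, {4, 1});\n  ChunkSource& base = view;\n  EXPECT_EQ(base.GetChunkCount(), 2u);\n  EXPECT_EQ((*base.GetChunkData(0))[0].planes[0], 4u);\n  EXPECT_EQ((*base.GetChunkData(1))[0].planes[0], 1u);\n  EXPECT_FALSE(base.GetChunkData(2).has_value());\n}\n\n}  // namespace\n}  // namespace training\n}  // namespace lczero\n"
  },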
  {
    "path": "csrc/loader/chunk_source/debug_chunk_source.cc",
    "content": "#include \"loader/chunk_source/debug_chunk_source.h\"\n\n#include <algorithm>\n#include <cinttypes>\n#include <cmath>\n#include <cstring>\n#include <random>\n#include <utility>\n\n#include \"absl/hash/hash.h\"\n#include \"absl/strings/str_format.h\"\n\nnamespace lczero {\nnamespace training {\n\nDebugChunkSource::DebugChunkSource(uint64_t id, double mean_chunk_count)\n    : id_(id), mean_chunk_count_(mean_chunk_count) {}\n\nstd::string DebugChunkSource::GetChunkSortKey() const {\n  return absl::StrFormat(\"%08\" PRIu64, id_);\n}\n\nsize_t DebugChunkSource::GetChunkCount() const {\n  if (!cached_chunk_count_.has_value()) {\n    std::mt19937_64 rng(id_);\n    const double stddev = std::max(1.0, mean_chunk_count_ / 4.0);\n    std::normal_distribution<double> distribution(mean_chunk_count_, stddev);\n    const double sampled = distribution(rng);\n    const auto rounded =\n        static_cast<long long>(std::llround(std::max(sampled, 1.0)));\n    cached_chunk_count_ = static_cast<size_t>(rounded);\n  }\n  return *cached_chunk_count_;\n}\n\nstd::optional<std::vector<FrameType>> DebugChunkSource::GetChunkData(\n    size_t index) {\n  const auto seed_pair = std::make_pair(id_, index);\n  const uint64_t seed = static_cast<uint64_t>(\n      absl::Hash<std::pair<uint64_t, size_t>>{}(seed_pair));\n  std::mt19937_64 rng(seed);\n  std::uniform_int_distribution<int> frame_count_distribution(1, 200);\n  const int frame_count = frame_count_distribution(rng);\n\n  std::vector<FrameType> result(frame_count);\n  for (int frame_index = 0; frame_index < frame_count; ++frame_index) {\n    result[frame_index] = FrameType{};\n    result[frame_index].planes[0] = static_cast<uint64_t>(id_);\n    result[frame_index].planes[1] = static_cast<uint64_t>(index);\n    result[frame_index].planes[2] = static_cast<uint64_t>(frame_index);\n  }\n\n  return result;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/chunk_source/debug_chunk_source.h",
    "content": "#pragma once\n\n#include <cstddef>\n#include <cstdint>\n#include <optional>\n#include <random>\n#include <string>\n\n#include \"loader/chunk_source/chunk_source.h\"\n\nnamespace lczero {\nnamespace training {\n\n// DebugChunkSource synthesizes deterministic pseudo-random chunks for loader\n// debugging. Each instance is identified by an integer id. The class produces\n// a chunk count sampled from a normal distribution with the provided mean and\n// mean / 4 standard deviation. The id serves as the seed, which keeps the\n// number of chunks stable across runs. Individual chunks contain a\n// pseudo-random number of FrameType frames (between one and 200) that are\n// generated on demand. The generation seed depends on both the source id and\n// chunk index. This lets shuffling logic exercise variable chunk sizes while\n// keeping the content reproducible. Each generated frame is zero-initialized,\n// but the first three entries of the planes array encode, respectively, the\n// source id, the chunk index, and the frame index within the chunk. This makes\n// it easy to reason about ordering and grouping when inspecting chunk payloads.\nclass DebugChunkSource : public ChunkSource {\n public:\n  DebugChunkSource(uint64_t id, double mean_chunk_count);\n\n private:\n  std::string GetChunkSortKey() const override;\n  size_t GetChunkCount() const override;\n  std::optional<std::vector<FrameType>> GetChunkData(size_t index) override;\n\n  uint64_t id_;\n  double mean_chunk_count_;\n  mutable std::optional<size_t> cached_chunk_count_;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/chunk_source/rawfile_chunk_source.cc",
    "content": "#include \"loader/chunk_source/rawfile_chunk_source.h\"\n\n#include <absl/log/log.h>\n\n#include <fstream>\n#include <stdexcept>\n\n#include \"trainingdata/trainingdata_v6.h\"\n#include \"utils/files.h\"\n#include \"utils/gz.h\"\n\nnamespace lczero {\nnamespace training {\n\nRawFileChunkSource::RawFileChunkSource(\n    const std::filesystem::path& filename,\n    ChunkSourceLoaderConfig::FrameFormat frame_format)\n    : filename_(filename), frame_format_(frame_format) {}\n\nRawFileChunkSource::~RawFileChunkSource() = default;\n\nstd::string RawFileChunkSource::GetChunkSortKey() const {\n  return std::filesystem::path(filename_).filename().string();\n}\n\nsize_t RawFileChunkSource::GetChunkCount() const { return 1; }\n\nstd::optional<std::vector<FrameType>> RawFileChunkSource::GetChunkData(\n    size_t index) {\n  if (index != 0) return std::nullopt;\n  std::string data = ReadFileToString(filename_);\n  if (data.empty()) return std::nullopt;\n\n  const size_t input_size =\n      frame_format_ == ChunkSourceLoaderConfig::V7TrainingData\n          ? sizeof(V7TrainingData)\n          : sizeof(V6TrainingData);\n  if (data.size() % input_size != 0) {\n    LOG(WARNING) << \"File \" << filename_ << \" size \" << data.size()\n                 << \" is not a multiple of input frame size \" << input_size;\n    return std::nullopt;\n  }\n\n  const size_t num_frames = data.size() / input_size;\n  std::vector<V7TrainingData> result(num_frames);\n\n  if (frame_format_ == ChunkSourceLoaderConfig::V7TrainingData) {\n    std::memcpy(result.data(), data.data(), data.size());\n  } else {\n    const auto* v6_data = reinterpret_cast<const V6TrainingData*>(data.data());\n    for (size_t i = 0; i < num_frames; ++i) {\n      std::memcpy(&result[i], &v6_data[i], sizeof(V6TrainingData));\n    }\n  }\n  return result;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/chunk_source/rawfile_chunk_source.h",
    "content": "#pragma once\n\n#include <filesystem>\n#include <string>\n\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"proto/data_loader_config.pb.h\"\n\nnamespace lczero {\nnamespace training {\n\n// A chunk source that reads a single (potentially gzipped) file as a single\n// chunk.\nclass RawFileChunkSource : public ChunkSource {\n public:\n  RawFileChunkSource(const std::filesystem::path& filename,\n                     ChunkSourceLoaderConfig::FrameFormat frame_format);\n  ~RawFileChunkSource();\n\n private:\n  std::string GetChunkSortKey() const override;\n  size_t GetChunkCount() const override;\n  std::optional<std::vector<FrameType>> GetChunkData(size_t index) override;\n\n  std::string filename_;\n  ChunkSourceLoaderConfig::FrameFormat frame_format_;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/chunk_source/tar_chunk_source.cc",
    "content": "#include \"loader/chunk_source/tar_chunk_source.h\"\n\n#include <absl/log/log.h>\n#include <absl/strings/str_cat.h>\n#include <zlib.h>\n\n#include <algorithm>\n#include <array>\n#include <cassert>\n#include <cstdint>\n#include <cstring>\n#include <fcntl.h>\n#include <stdexcept>\n#include <unistd.h>\n\n#include \"trainingdata/trainingdata_v6.h\"\n#include \"utils/gz.h\"\n\nnamespace lczero {\nnamespace training {\nnamespace {\nstruct TarHeader {\n  std::array<char, 100> name;\n  std::array<uint8_t, 8> mode;\n  std::array<uint8_t, 8> uid;\n  std::array<uint8_t, 8> gid;\n  std::array<uint8_t, 12> size;\n  std::array<uint8_t, 12> mtime;\n  std::array<uint8_t, 8> chksum;\n  uint8_t typeflag;\n  std::array<uint8_t, 100> linkname;\n  std::array<uint8_t, 6> magic;\n  std::array<uint8_t, 2> version;\n  std::array<uint8_t, 32> uname;\n  std::array<uint8_t, 32> gname;\n  std::array<uint8_t, 8> devmajor;\n  std::array<uint8_t, 8> devminor;\n  std::array<uint8_t, 155> prefix;\n  std::array<uint8_t, 12> padding;\n};\nstatic_assert(sizeof(TarHeader) == 512, \"TarHeader must be exactly 512 bytes\");\n\nuint64_t ParseOctal(const std::array<uint8_t, 12>& octal) {\n  uint64_t value = 0;\n  for (uint8_t digit : octal) {\n    if (!digit) break;\n    value = (value << 3) + (digit - '0');\n  }\n  return value;\n}\n\nbool ReadExact(int fd, off_t offset, void* buffer, size_t size) {\n  char* out = static_cast<char*>(buffer);\n  size_t read_total = 0;\n  while (read_total < size) {\n    const ssize_t read_now =\n        pread(fd, out + read_total, size - read_total, offset + read_total);\n    if (read_now <= 0) return false;\n    read_total += static_cast<size_t>(read_now);\n  }\n  return true;\n}\n\nstd::optional<std::string> ReadGzipPrefix(int fd, off_t offset, size_t size,\n                                          size_t max_bytes) {\n  if (max_bytes == 0) return std::string();\n\n  z_stream strm = {};\n  if (inflateInit2(&strm, 16 + MAX_WBITS) != Z_OK) {\n    return 
std::nullopt;\n  }\n\n  constexpr size_t kChunkSize = 16384;\n  std::array<uint8_t, kChunkSize> input_buffer;\n  std::array<char, kChunkSize> output_buffer;\n\n  std::string output;\n  output.reserve(std::min<size_t>(max_bytes, kChunkSize));\n\n  size_t remaining = size;\n  off_t current_offset = offset;\n  bool finished = false;\n\n  while (remaining > 0 && !finished && output.size() < max_bytes) {\n    const size_t to_read = std::min(remaining, kChunkSize);\n    if (!ReadExact(fd, current_offset, input_buffer.data(), to_read)) {\n      inflateEnd(&strm);\n      return std::nullopt;\n    }\n    remaining -= to_read;\n    current_offset += static_cast<off_t>(to_read);\n\n    strm.next_in = reinterpret_cast<Bytef*>(input_buffer.data());\n    strm.avail_in = static_cast<uInt>(to_read);\n\n    while (strm.avail_in > 0 && output.size() < max_bytes) {\n      strm.next_out = reinterpret_cast<Bytef*>(output_buffer.data());\n      strm.avail_out = kChunkSize;\n\n      const int ret = inflate(&strm, Z_NO_FLUSH);\n      if (ret == Z_STREAM_ERROR || ret == Z_NEED_DICT || ret == Z_DATA_ERROR ||\n          ret == Z_MEM_ERROR) {\n        inflateEnd(&strm);\n        return std::nullopt;\n      }\n\n      const size_t produced = kChunkSize - strm.avail_out;\n      const size_t to_copy = std::min(produced, max_bytes - output.size());\n      output.append(output_buffer.data(), to_copy);\n\n      if (ret == Z_STREAM_END) {\n        finished = true;\n        break;\n      }\n    }\n  }\n\n  inflateEnd(&strm);\n  return output;\n}\n\n}  // namespace\n\nTarChunkSource::TarChunkSource(\n    const std::filesystem::path& filename,\n    ChunkSourceLoaderConfig::FrameFormat frame_format)\n    : path_(filename),\n      filename_(filename.filename().string()),\n      frame_format_(frame_format) {\n  fd_ = open(path_.c_str(), O_RDONLY | O_CLOEXEC);\n  if (fd_ < 0) {\n    throw std::runtime_error(\n        absl::StrCat(\"Failed to open tar file: \", path_.string(), \": \",\n                     
std::strerror(errno)));\n  }\n  // Perform indexing during construction.\n  Index();\n}\n\nTarChunkSource::~TarChunkSource() { Close(); }\n\nvoid TarChunkSource::Close() {\n  if (fd_ >= 0 && close(fd_) != 0) {\n    PLOG(WARNING) << \"Failed to close tar file descriptor for \" << path_;\n  }\n  fd_ = -1;\n}\n\nstd::string TarChunkSource::GetChunkSortKey() const { return filename_; }\n\nvoid TarChunkSource::Index() {\n  assert(files_.empty());\n\n  off_t offset = 0;\n\n  while (true) {\n    TarHeader header;\n    if (!ReadExact(fd_, offset, &header, sizeof(header))) {\n      LOG(WARNING) << \"Truncated tar file: \" << filename_;\n      break;\n    }\n    offset += sizeof(header);\n\n    if (header.name[0] == '\\0') break;  // End of archive.\n\n    switch (header.typeflag) {\n      case '5':  // Directory\n        continue;\n      case '0':  // Regular file\n        break;\n      default:\n        LOG(WARNING) << \"Unsupported tar header type: \" << header.typeflag;\n        continue;\n    }\n\n    // The name field is NUL-terminated unless it occupies all 100 bytes, so\n    // compute the length explicitly instead of relying on a terminator.\n    const auto name_end =\n        std::find(header.name.begin(), header.name.end(), '\\0');\n    const std::string_view fname(\n        header.name.data(),\n        static_cast<size_t>(name_end - header.name.begin()));\n    const std::filesystem::path filepath = std::filesystem::path(fname);\n    const off_t file_offset = offset;\n    const size_t size = ParseOctal(header.size);\n    const size_t padded_size = ((size + 511) / 512) * 512;\n    offset = file_offset + static_cast<off_t>(padded_size);\n\n    if (filepath.filename() == \"LICENSE\") continue;\n    files_.push_back({file_offset, size, filepath.extension() == \".gz\"});\n  }\n\n  LOG(INFO) << \"Read \" << files_.size() << \" entries from \" << filename_;\n}\n\nsize_t TarChunkSource::GetChunkCount() const { return files_.size(); }\n\nstd::optional<std::vector<FrameType>> TarChunkSource::GetChunkData(\n    size_t index) {\n  if (index >= files_.size()) {\n    throw std::out_of_range(\"File index out of range\");\n  }\n  const auto& file_entry = files_[index];\n  std::string content(file_entry.size, '\\0');\n  if (!ReadExact(fd_, file_entry.offset, content.data(), 
file_entry.size)) {\n    return std::nullopt;\n  }\n  if (file_entry.is_gzip) {\n    try {\n      content = GunzipBuffer(content);\n    } catch (const GunzipError& e) {\n      return std::nullopt;\n    }\n  }\n  if (content.empty()) return std::nullopt;\n\n  const size_t input_size =\n      frame_format_ == ChunkSourceLoaderConfig::V7TrainingData\n          ? sizeof(V7TrainingData)\n          : sizeof(V6TrainingData);\n  if (content.size() % input_size != 0) {\n    LOG(WARNING) << \"Chunk \" << index << \" from \" << filename_ << \" size \"\n                 << content.size() << \" is not a multiple of input frame size \"\n                 << input_size;\n    return std::nullopt;\n  }\n\n  const size_t num_frames = content.size() / input_size;\n  std::vector<V7TrainingData> result(num_frames);\n\n  if (frame_format_ == ChunkSourceLoaderConfig::V7TrainingData) {\n    std::memcpy(result.data(), content.data(), content.size());\n  } else {\n    const auto* v6_data =\n        reinterpret_cast<const V6TrainingData*>(content.data());\n    for (size_t i = 0; i < num_frames; ++i) {\n      std::memcpy(&result[i], &v6_data[i], sizeof(V6TrainingData));\n    }\n  }\n  return result;\n}\n\nstd::optional<std::string> TarChunkSource::GetChunkPrefix(size_t index,\n                                                          size_t max_bytes) {\n  if (index >= files_.size()) {\n    throw std::out_of_range(\"File index out of range\");\n  }\n  const auto& file_entry = files_[index];\n  if (file_entry.is_gzip) {\n    return ReadGzipPrefix(fd_, file_entry.offset, file_entry.size, max_bytes);\n  }\n\n  const size_t to_read = std::min(file_entry.size, max_bytes);\n  std::string content(to_read, '\\0');\n  if (!ReadExact(fd_, file_entry.offset, content.data(), to_read)) {\n    return std::nullopt;\n  }\n  return content;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/chunk_source/tar_chunk_source.h",
    "content": "#pragma once\n\n#include <sys/types.h>\n\n#include <filesystem>\n#include <string>\n#include <vector>\n\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"proto/data_loader_config.pb.h\"\n\nnamespace lczero {\nnamespace training {\n\n// A chunk source that reads a tar archive and provides access to its files as\n// chunks. Each file in the tar is treated as a separate chunk.\nclass TarChunkSource : public ChunkSource {\n public:\n  TarChunkSource(const std::filesystem::path& filename,\n                 ChunkSourceLoaderConfig::FrameFormat frame_format);\n  ~TarChunkSource() override;\n  std::string GetChunkSortKey() const override;\n  size_t GetChunkCount() const override;\n  std::optional<std::vector<FrameType>> GetChunkData(size_t index) override;\n  std::optional<std::string> GetChunkPrefix(size_t index, size_t max_bytes);\n\n private:\n  // Performs one-time indexing during construction. Not part of the interface.\n  void Index();\n  struct FileEntry {\n    off_t offset;\n    size_t size;\n    bool is_gzip;\n  };\n\n  void Close();\n\n  int fd_ = -1;\n  std::vector<FileEntry> files_;\n  std::filesystem::path path_;\n  std::string filename_;\n  ChunkSourceLoaderConfig::FrameFormat frame_format_;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/data_loader.cc",
    "content": "#include \"loader/data_loader.h\"\n\n#include <absl/algorithm/container.h>\n#include <absl/log/log.h>\n#include <absl/strings/str_cat.h>\n#include <absl/strings/str_join.h>\n\n#include <chrono>\n#include <cmath>\n#include <optional>\n#include <stdexcept>\n\n#include \"loader/data_loader_metrics.h\"\n\nnamespace lczero {\nnamespace training {\nDataLoaderConfig DataLoader::ParseConfig(const std::string& serialized_config) {\n  DataLoaderConfig config;\n  config.ParseFromString(serialized_config);\n  return config;\n}\n\nDataLoader::DataLoader(const std::string& serialized_data_loader_config)\n    : metrics_aggregator_(\n          [](DataLoaderMetricsProto& m) { m.Clear(); },\n          [](DataLoaderMetricsProto& dest, const DataLoaderMetricsProto& src) {\n            UpdateFrom(dest, src);\n          }) {\n  DataLoaderConfig config = ParseConfig(serialized_data_loader_config);\n  AddStages(config);\n  BuildOutputMapping(config);\n  LOG(INFO) << \"DataLoader initialized with \" << stage_registry_.size()\n            << \" stage(s) and \" << outputs_.size() << \" output(s).\";\n}\n\nDataLoader::~DataLoader() { Stop(); }\n\nvoid DataLoader::AddStages(const std::string& serialized_data_loader_config) {\n  AddStages(ParseConfig(serialized_data_loader_config));\n}\n\nvoid DataLoader::AddStages(const DataLoaderConfig& config) {\n  for (const auto& stage_config : config.stage()) AddStage(stage_config);\n  for (const auto& stage_config : config.stage()) SetStageInputs(stage_config);\n}\n\nvoid DataLoader::AddStage(const StageConfig& stage_config) {\n  if (started_) {\n    throw std::runtime_error(\"Cannot add stages after DataLoader has started.\");\n  }\n\n  auto stage = CreateStage(stage_config);\n  if (!stage_config.has_name()) {\n    throw std::runtime_error(\"Stage configuration is missing name.\");\n  }\n\n  LOG(INFO) << \"Adding stage '\" << stage_config.name() << \"'.\";\n  stage_registry_.AddStage(stage_config.name(), std::move(stage));\n}\n\nvoid 
DataLoader::SetStageInputs(const StageConfig& stage_config) {\n  // Resolve input names to queue pointers.\n  std::vector<QueueBase*> input_queues;\n  input_queues.reserve(stage_config.input_size());\n  for (const auto& input_name : stage_config.input()) {\n    QueueBase* queue = stage_registry_.GetStageOutput(input_name);\n    if (!queue) {\n      throw std::runtime_error(absl::StrCat(\"Input stage '\", input_name,\n                                            \"' not found for stage '\",\n                                            stage_config.name(), \"'.\"));\n    }\n    input_queues.push_back(queue);\n  }\n\n  // Wire up inputs.\n  auto it = absl::c_find_if(stage_registry_.stages(), [&](const auto& p) {\n    return p.first == stage_config.name();\n  });\n  if (it == stage_registry_.stages().end()) {\n    throw std::runtime_error(absl::StrCat(\"Stage '\", stage_config.name(),\n                                          \"' not found in registry.\"));\n  }\n  it->second->SetInputs(absl::MakeSpan(input_queues));\n}\n\nvoid DataLoader::Start() {\n  if (started_) {\n    throw std::runtime_error(\"DataLoader has already been started.\");\n  }\n  for (auto& [name, stage] : stage_registry_.stages()) {\n    LOG(INFO) << \"Starting stage '\" << name << \"'.\";\n    stage->Start();\n  }\n\n  metrics_thread_ = std::jthread(\n      [this](std::stop_token stop_token) { MetricsThread(stop_token); });\n  started_ = true;\n  stopped_ = false;\n  LOG(INFO) << \"DataLoader started.\";\n}\n\nTensorTuple DataLoader::GetNext(std::string_view alias) {\n  Queue<TensorTuple>* q = GetOutputQueue(alias);\n  if (!q) {\n    std::string alias_list = absl::StrJoin(\n        outputs_, \", \",\n        [](std::string* out, const auto& p) { absl::StrAppend(out, p.first); });\n    throw std::runtime_error(absl::StrCat(\"Unknown DataLoader output: '\", alias,\n                                          \"'. 
Available outputs: [\", alias_list,\n                                          \"].\"));\n  }\n  return q->Get();\n}\n\nstd::optional<TensorTuple> DataLoader::MaybeGetNext(std::string_view alias) {\n  Queue<TensorTuple>* q = GetOutputQueue(alias);\n  if (!q) {\n    std::string alias_list = absl::StrJoin(\n        outputs_, \", \",\n        [](std::string* out, const auto& p) { absl::StrAppend(out, p.first); });\n    throw std::runtime_error(absl::StrCat(\"Unknown DataLoader output: '\", alias,\n                                          \"'. Available outputs: [\", alias_list,\n                                          \"].\"));\n  }\n  return q->MaybeGet();\n}\n\nvoid DataLoader::Stop() {\n  if (stopped_) return;\n\n  LOG(INFO) << \"Stopping DataLoader.\";\n\n  if (metrics_thread_.joinable()) {\n    metrics_thread_.request_stop();\n    metrics_thread_.join();\n  }\n\n  for (auto& [name, stage] : stage_registry_.stages()) {\n    LOG(INFO) << \"Stopping stage '\" << name << \"'.\";\n    stage->Stop();\n  }\n\n  stopped_ = true;\n  started_ = false;\n  LOG(INFO) << \"DataLoader stopped.\";\n}\n\nstd::pair<std::string, float> DataLoader::GetBucketMetrics(\n    int time_period, bool include_pending) const {\n  auto [metrics, duration] = metrics_aggregator_.GetBucketMetrics(\n      static_cast<TimePeriod>(time_period),\n      include_pending ? std::make_optional(std::chrono::steady_clock::now())\n                      : std::nullopt);\n  float duration_seconds = std::chrono::duration<float>(duration).count();\n  return {metrics.OutputAsString(), duration_seconds};\n}\n\nstd::pair<std::string, float> DataLoader::GetAggregateEndingNow(\n    float duration_seconds, bool include_pending) const {\n  std::chrono::nanoseconds duration_ns =\n      std::isinf(duration_seconds)\n          ? 
std::chrono::nanoseconds::max()\n          : std::chrono::nanoseconds(\n                static_cast<int64_t>(duration_seconds * 1e9));\n  auto [metrics, duration] = metrics_aggregator_.GetAggregateEndingNow(\n      duration_ns, include_pending\n                       ? std::make_optional(std::chrono::steady_clock::now())\n                       : std::nullopt);\n  float result_duration_seconds =\n      std::chrono::duration<float>(duration).count();\n  return {metrics.OutputAsString(), result_duration_seconds};\n}\n\nvoid DataLoader::MetricsThread(std::stop_token stop_token) {\n  while (!stop_token.stop_requested()) {\n    std::this_thread::sleep_for(std::chrono::milliseconds(100));\n    DataLoaderMetricsProto metrics;\n    for (auto& [name, stage] : stage_registry_.stages()) {\n      StageMetricProto stage_metric = stage->FlushMetrics();\n      stage_metric.set_name(name);\n      *metrics.add_stage_metrics() = std::move(stage_metric);\n    }\n\n    metrics_aggregator_.RecordMetrics(std::move(metrics));\n    metrics_aggregator_.Advance(std::chrono::steady_clock::now());\n  }\n  LOG(INFO) << \"Metrics thread stopping.\";\n}\n\nstd::vector<std::pair<std::string, StageControlResponse>>\nDataLoader::SendControlMessage(const StageControlRequest& request) {\n  std::vector<std::pair<std::string, StageControlResponse>> responses;\n  responses.reserve(stage_registry_.size());\n  for (auto& [name, stage] : stage_registry_.stages()) {\n    std::optional<StageControlResponse> response = stage->Control(request);\n    if (response.has_value()) {\n      responses.emplace_back(name, std::move(*response));\n    }\n  }\n  return responses;\n}\n\nvoid DataLoader::BuildOutputMapping(const DataLoaderConfig& config) {\n  outputs_.clear();\n  outputs_.reserve(config.output_size());\n\n  for (const auto& out_spec : config.output()) {\n    auto it = absl::c_find(out_spec, ':');\n    std::string alias = it == out_spec.end()\n                            ? 
std::string(\"\")\n                            : std::string(out_spec.begin(), it);\n    std::string stage_name = it == out_spec.end()\n                                 ? std::string(out_spec)\n                                 : std::string(it + 1, out_spec.end());\n\n    // Ensure alias is unique.\n    if (absl::c_find_if(outputs_, [&](const auto& p) {\n          return p.first == alias;\n        }) != outputs_.end()) {\n      throw std::runtime_error(\n          absl::StrCat(\"Duplicate output alias specified: \", alias));\n    }\n\n    Queue<TensorTuple>* queue =\n        stage_registry_.GetTypedStageOutput<TensorTuple>(stage_name);\n    if (queue == nullptr) {\n      throw std::runtime_error(\n          absl::StrCat(\"Output stage not found or wrong type: \", stage_name));\n    }\n    outputs_.emplace_back(std::move(alias), queue);\n  }\n}\n\nQueue<TensorTuple>* DataLoader::GetOutputQueue(std::string_view alias) const {\n  auto it = absl::c_find_if(outputs_,\n                            [&](const auto& p) { return p.first == alias; });\n  if (it == outputs_.end()) return nullptr;\n  return it->second;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/data_loader.h",
    "content": "#pragma once\n\n#include <memory>\n#include <string>\n#include <string_view>\n#include <thread>\n#include <utility>\n#include <vector>\n\n#include \"loader/stages/stage_factory.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/stage_control.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/metrics/exponential_aggregator.h\"\n#include \"utils/queue.h\"\n#include \"utils/tensor.h\"\n\nnamespace lczero {\nnamespace training {\n\nusing TensorDict =\n    std::vector<std::pair<std::string, std::unique_ptr<TensorBase>>>;\n\nclass DataLoader {\n public:\n  using MetricsAggregator = ExponentialAggregator<DataLoaderMetricsProto,\n                                                  TimePeriod::k250Milliseconds>;\n\n  explicit DataLoader(const std::string& serialized_data_loader_config);\n  ~DataLoader();\n\n  void Start();\n  TensorTuple GetNext(std::string_view alias);\n  std::optional<TensorTuple> MaybeGetNext(std::string_view alias);\n  void Stop();\n  std::pair<std::string, float> GetBucketMetrics(int time_period,\n                                                 bool include_pending) const;\n  std::pair<std::string, float> GetAggregateEndingNow(\n      float duration_seconds, bool include_pending) const;\n\n  void AddStages(const DataLoaderConfig& config);\n  void AddStages(const std::string& serialized_data_loader_config);\n\n  std::vector<std::pair<std::string, StageControlResponse>> SendControlMessage(\n      const StageControlRequest& request);\n\n private:\n  void AddStage(const StageConfig& stage_config);\n  void SetStageInputs(const StageConfig& stage_config);\n  static DataLoaderConfig ParseConfig(\n      const std::string& serialized_data_loader_config);\n  void MetricsThread(std::stop_token stop_token);\n  void BuildOutputMapping(const DataLoaderConfig& config);\n  Queue<TensorTuple>* GetOutputQueue(std::string_view alias) const;\n\n  StageRegistry stage_registry_;\n  std::vector<std::pair<std::string, 
Queue<TensorTuple>*>> outputs_;\n  MetricsAggregator metrics_aggregator_;\n  std::jthread metrics_thread_;\n  bool started_ = false;\n  bool stopped_ = false;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/data_loader_metrics.cc",
    "content": "// ABOUTME: Implementation of UpdateFrom functions for data loader metric\n// protobuf messages. ABOUTME: Handles aggregation of generic stage metrics.\n\n#include \"loader/data_loader_metrics.h\"\n\n#include <algorithm>\n\n#include \"absl/strings/string_view.h\"\n#include \"utils/metrics/statistics_metric.h\"\n\nnamespace lczero {\nnamespace training {\nnamespace {\n\ntemplate <typename ProtoT>\nProtoT* FindByName(std::vector<ProtoT>* entries, absl::string_view name) {\n  for (auto& entry : *entries) {\n    if (entry.name() == name) return &entry;\n  }\n  return nullptr;\n}\n\n}  // namespace\n\nvoid UpdateFrom(QueueMetricProto& dest, const QueueMetricProto& src) {\n  if (src.has_name()) dest.set_name(src.name());\n  dest.set_put_count(dest.put_count() + src.put_count());\n  dest.set_get_count(dest.get_count() + src.get_count());\n  dest.set_drop_count(dest.drop_count() + src.drop_count());\n  UpdateFrom(*dest.mutable_queue_fullness(), src.queue_fullness());\n  if (src.has_queue_capacity()) dest.set_queue_capacity(src.queue_capacity());\n}\n\nvoid UpdateFrom(CountMetricProto& dest, const CountMetricProto& src) {\n  if (src.has_name()) dest.set_name(src.name());\n  if (src.has_count()) dest.set_count(dest.count() + src.count());\n}\n\nvoid UpdateFrom(GaugeMetricProto& dest, const GaugeMetricProto& src) {\n  if (src.has_name()) dest.set_name(src.name());\n  if (src.has_value()) dest.set_value(src.value());\n  if (src.has_capacity()) dest.set_capacity(src.capacity());\n}\n\nvoid UpdateFrom(StageMetricProto& dest, const StageMetricProto& src) {\n  if (src.has_name()) dest.set_name(src.name());\n\n  for (const auto& load_metrics : src.load_metrics()) {\n    LoadMetricProto* dest_load =\n        load_metrics.has_name()\n            ? 
FindByName(dest.mutable_load_metrics(), load_metrics.name())\n            : nullptr;\n    if (dest_load == nullptr) {\n      dest_load = dest.add_load_metrics();\n    }\n    UpdateFrom(*dest_load, load_metrics);\n  }\n\n  for (const auto& queue_metrics : src.queue_metrics()) {\n    QueueMetricProto* dest_queue =\n        queue_metrics.has_name()\n            ? FindByName(dest.mutable_queue_metrics(), queue_metrics.name())\n            : nullptr;\n    if (dest_queue == nullptr) {\n      dest_queue = dest.add_queue_metrics();\n    }\n    UpdateFrom(*dest_queue, queue_metrics);\n  }\n\n  for (const auto& count_metrics : src.count_metrics()) {\n    CountMetricProto* dest_count =\n        count_metrics.has_name()\n            ? FindByName(dest.mutable_count_metrics(), count_metrics.name())\n            : nullptr;\n    if (dest_count == nullptr) {\n      dest_count = dest.add_count_metrics();\n    }\n    UpdateFrom(*dest_count, count_metrics);\n  }\n\n  for (const auto& gauge_metrics : src.gauge_metrics()) {\n    GaugeMetricProto* dest_gauge =\n        gauge_metrics.has_name()\n            ? FindByName(dest.mutable_gauge_metrics(), gauge_metrics.name())\n            : nullptr;\n    if (dest_gauge == nullptr) {\n      dest_gauge = dest.add_gauge_metrics();\n    }\n    UpdateFrom(*dest_gauge, gauge_metrics);\n  }\n\n  for (const auto& statistics_metrics : src.statistics_metrics()) {\n    StatisticsProtoDouble* dest_stats =\n        statistics_metrics.has_name()\n            ? 
FindByName(dest.mutable_statistics_metrics(),\n                         statistics_metrics.name())\n            : nullptr;\n    if (dest_stats == nullptr) {\n      dest_stats = dest.add_statistics_metrics();\n    }\n    UpdateFrom(*dest_stats, statistics_metrics);\n  }\n\n  if (src.has_last_chunk_key()) dest.set_last_chunk_key(src.last_chunk_key());\n  if (src.has_anchor()) dest.set_anchor(src.anchor());\n}\n\nvoid UpdateFrom(DataLoaderMetricsProto& dest,\n                const DataLoaderMetricsProto& src) {\n  for (const auto& stage_metrics : src.stage_metrics()) {\n    StageMetricProto* dest_stage =\n        stage_metrics.has_name()\n            ? FindByName(dest.mutable_stage_metrics(), stage_metrics.name())\n            : nullptr;\n    if (dest_stage == nullptr) {\n      dest_stage = dest.add_stage_metrics();\n    }\n\n    UpdateFrom(*dest_stage, stage_metrics);\n  }\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/data_loader_metrics.h",
    "content": "// ABOUTME: Header for UpdateFrom functions for data loader metric protobuf\n// messages. ABOUTME: Declares functions for aggregating FilePathProvider and\n// DataLoader metrics.\n\n#pragma once\n\n#include \"absl/strings/string_view.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/metrics/load_metric.h\"\n#include \"utils/metrics/statistics_metric.h\"\n#include \"utils/queue.h\"\n\nnamespace lczero {\nnamespace training {\n\nvoid UpdateFrom(QueueMetricProto& dest, const QueueMetricProto& src);\nvoid UpdateFrom(CountMetricProto& dest, const CountMetricProto& src);\nvoid UpdateFrom(GaugeMetricProto& dest, const GaugeMetricProto& src);\nvoid UpdateFrom(StageMetricProto& dest, const StageMetricProto& src);\nvoid UpdateFrom(DataLoaderMetricsProto& dest,\n                const DataLoaderMetricsProto& src);\n\ntemplate <typename T>\nQueueMetricProto MetricsFromQueue(absl::string_view name, Queue<T>& queue) {\n  QueueMetricProto result;\n  result.set_name(std::string(name));\n  result.set_put_count(queue.GetTotalPutCount(true));\n  result.set_get_count(queue.GetTotalGetCount(true));\n  result.set_drop_count(queue.GetTotalDropCount(true));\n  AddSample(*result.mutable_queue_fullness(), queue.Size());\n  result.set_queue_capacity(queue.Capacity());\n  return result;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/data_loader_test.cc",
    "content": "#include \"loader/data_loader.h\"\n\n#include <gtest/gtest.h>\n\n#include <stdexcept>\n\nnamespace lczero {\nnamespace training {\n\nTEST(DataLoaderTest, AllowsNoOutputsConfigured) {\n  DataLoaderConfig config;\n  auto* file_stage = config.add_stage();\n  file_stage->set_name(\"file_path_provider\");\n  file_stage->mutable_file_path_provider()->set_directory(\".\");\n\n  EXPECT_NO_THROW(DataLoader(config.OutputAsString()));\n}\n\nTEST(DataLoaderTest, ThrowsOnDuplicateStageName) {\n  DataLoaderConfig config;\n  auto* first_stage = config.add_stage();\n  first_stage->set_name(\"duplicate\");\n  first_stage->mutable_file_path_provider()->set_directory(\".\");\n\n  auto* second_stage = config.add_stage();\n  second_stage->set_name(\"duplicate\");\n  second_stage->mutable_file_path_provider()->set_directory(\".\");\n\n  EXPECT_THROW(DataLoader(config.OutputAsString()), std::runtime_error);\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/frame_type.h",
    "content": "/*\n  This file is part of Leela Chess Zero.\n  Copyright (C) 2025 The LCZero Authors\n\n  Leela Chess is free software: you can redistribute it and/or modify\n  it under the terms of the GNU General Public License as published by\n  the Free Software Foundation, either version 3 of the License, or\n  (at your option) any later version.\n\n  Leela Chess is distributed in the hope that it will be useful,\n  but WITHOUT ANY WARRANTY; without even the implied warranty of\n  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n  GNU General Public License for more details.\n\n  You should have received a copy of the GNU General Public License\n  along with Leela Chess.  If not, see <http://www.gnu.org/licenses/>.\n\n  Additional permission under GNU GPL version 3 section 7\n\n  If you modify this Program, or any covered work, by linking or\n  combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA\n  Toolkit and the NVIDIA CUDA Deep Neural Network library (or a\n  modified version of those libraries), containing parts covered by the\n  terms of the respective license agreement, the licensors of this\n  Program grant you additional permission to convey the resulting work.\n*/\n\n#pragma once\n\n#include \"trainingdata/trainingdata_v7.h\"\n\nnamespace lczero {\nnamespace training {\n\nusing FrameType = V7TrainingData;\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/loader_main.cpp",
    "content": "#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h>\n#include <absl/log/log.h>\n#include <absl/strings/str_cat.h>\n#include <absl/strings/str_format.h>\n#include <absl/time/clock.h>\n#include <absl/time/time.h>\n\n#include <chrono>\n#include <iostream>\n\n#include \"data_loader.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/metrics/printer.h\"\n\nABSL_FLAG(std::string, directory, \"/home/crem/tmp/2025-07/lczero-training/\",\n          \"Directory to watch for training data files\");\nABSL_FLAG(size_t, chunk_pool_size, 1000000, \"Size of the chunk shuffle buffer\");\nABSL_FLAG(size_t, reservoir_size_per_thread, 1000000,\n          \"Size of the reservoir for frame sampling per thread\");\n\nnamespace lczero {\nnamespace training {\n\nvoid Run() {\n  DataLoaderConfig config;\n\n  // Configure file path provider stage.\n  auto* file_stage = config.add_stage();\n  file_stage->set_name(\"file_path_provider\");\n  auto* file_path_provider = file_stage->mutable_file_path_provider();\n  file_path_provider->set_directory(absl::GetFlag(FLAGS_directory));\n\n  // Configure chunk source loader stage.\n  auto* chunk_loader_stage = config.add_stage();\n  chunk_loader_stage->set_name(\"chunk_source_loader\");\n  chunk_loader_stage->add_input(file_stage->name());\n  chunk_loader_stage->mutable_chunk_source_loader();\n\n  // Configure shuffling chunk pool stage.\n  auto* chunk_pool_stage = config.add_stage();\n  chunk_pool_stage->set_name(\"shuffling_chunk_pool\");\n  chunk_pool_stage->add_input(chunk_loader_stage->name());\n  auto* shuffling_chunk_pool = chunk_pool_stage->mutable_shuffling_chunk_pool();\n  shuffling_chunk_pool->set_chunk_pool_size(\n      absl::GetFlag(FLAGS_chunk_pool_size));\n\n  // Configure chunk unpacker stage.\n  auto* unpacker_stage = config.add_stage();\n  unpacker_stage->set_name(\"chunk_unpacker\");\n  
unpacker_stage->add_input(chunk_pool_stage->name());\n  unpacker_stage->mutable_chunk_unpacker();\n\n  // Configure shuffling frame sampler stage.\n  auto* sampler_stage = config.add_stage();\n  sampler_stage->set_name(\"shuffling_frame_sampler\");\n  sampler_stage->add_input(unpacker_stage->name());\n  auto* shuffling_frame_sampler =\n      sampler_stage->mutable_shuffling_frame_sampler();\n  shuffling_frame_sampler->set_reservoir_size_per_thread(\n      absl::GetFlag(FLAGS_reservoir_size_per_thread));\n\n  // Configure tensor generator stage.\n  auto* tensor_stage = config.add_stage();\n  tensor_stage->set_name(\"tensor_generator\");\n  tensor_stage->add_input(sampler_stage->name());\n  tensor_stage->mutable_tensor_generator();\n\n  // Serialize the config, create the loader, and start it. The \"train\"\n  // alias must match the alias used in GetNext() below.\n  config.add_output(\"train:tensor_generator\");\n  std::string config_string = config.OutputAsString();\n  DataLoader loader(config_string);\n  loader.Start();\n\n  std::atomic<size_t> batch_count = 0;\n  auto start_time = absl::Now();\n\n  // Start the logging thread.\n  std::atomic<bool> should_stop{false};\n  std::thread logging_thread([&]() {\n    while (!should_stop.load()) {\n      std::this_thread::sleep_for(std::chrono::seconds(1));\n\n      auto current_time = absl::Now();\n      auto total_elapsed = current_time - start_time;\n      double rate = batch_count / absl::ToDoubleSeconds(total_elapsed);\n\n      auto [stats_string, duration] =\n          loader.GetBucketMetrics(0, false);  // k1Second = 0\n      DataLoaderMetricsProto metrics;\n      metrics.ParseFromString(stats_string);\n      std::string metrics_json = metrics.OutputAsJson();\n\n      LOG(INFO) << absl::StrCat(\"Processed \", batch_count.load(),\n                                \" batches in \",\n                                absl::ToDoubleSeconds(total_elapsed),\n                                \"s. Rate: \", absl::StrFormat(\"%.2f\", rate),\n                                \" batches/sec. \", \"Metrics: \", metrics_json);\n    }\n  });\n\n  while (true) {\n    TensorTuple batch = loader.GetNext(\"train\");\n    ++batch_count;\n  }\n}\n\n}  // namespace training\n}  // namespace lczero\n\nint main(int argc, char* argv[]) {\n  absl::ParseCommandLine(argc, argv);\n  absl::InitializeLog();\n  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);\n  lczero::training::Run();\n  return 0;\n}\n"
  },
  {
    "path": "csrc/loader/pybind_module.cc",
    "content": "// ABOUTME: PyBind11 binding module exposing C++ DataLoader to Python.\n// ABOUTME: Handles configuration conversion and tensor memory management for\n// numpy arrays.\n\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h>\n#include <pybind11/numpy.h>\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <pybind11/stl/filesystem.h>\n#include <pybind11/stl_bind.h>\n\n#include <stdexcept>\n#include <string>\n\n#include \"loader/data_loader.h\"\n#include \"loader/stages/chunk_source_loader.h\"\n#include \"loader/stages/chunk_unpacker.h\"\n#include \"loader/stages/file_path_provider.h\"\n#include \"loader/stages/shuffling_chunk_pool.h\"\n#include \"loader/stages/shuffling_frame_sampler.h\"\n#include \"loader/stages/tensor_generator.h\"\n#include \"utils/tensor.h\"\n\nnamespace py = pybind11;\n\nnamespace lczero {\nnamespace training {\n\nnamespace {\n\nstd::string SerializePyProto(const py::handle& obj, const char* expected_type) {\n  if (py::isinstance<py::bytes>(obj)) {\n    return obj.cast<std::string>();\n  }\n  if (py::hasattr(obj, \"SerializeToString\")) {\n    py::object bytes_obj = obj.attr(\"SerializeToString\")();\n    return bytes_obj.cast<py::bytes>().cast<std::string>();\n  }\n  throw std::invalid_argument(std::string(\"Expected \") + expected_type +\n                              \" protobuf message or bytes.\");\n}\n\ntemplate <typename ProtoT>\nProtoT ParsePyProto(const py::handle& obj, const char* expected_type) {\n  ProtoT proto;\n  proto.ParseFromString(SerializePyProto(obj, expected_type));\n  return proto;\n}\n\npy::object MakePythonProto(const char* module_name, const char* message_name,\n                           const std::string& serialized) {\n  py::object message_cls = py::module::import(module_name).attr(message_name);\n  py::object message_obj = message_cls();\n  message_obj.attr(\"ParseFromString\")(py::bytes(serialized));\n  return message_obj;\n}\n\n}  // namespace\n\n// Helper function to 
convert TensorBase to numpy array using buffer protocol.\npy::array tensor_to_numpy(std::unique_ptr<TensorBase> tensor) {\n  // Extract raw pointer and release ownership from unique_ptr.\n  TensorBase* raw_tensor = tensor.release();\n\n  // Create numpy array with take_ownership policy.\n  // This transfers memory ownership to Python/numpy.\n  return py::array(\n      py::dtype(raw_tensor->py_format()), raw_tensor->shape(),\n      raw_tensor->strides(), raw_tensor->data(),\n      py::cast(raw_tensor, py::return_value_policy::take_ownership));\n}\n\n// Convert TensorTuple to tuple of numpy arrays.\npy::tuple tensor_tuple_to_numpy_tuple(TensorTuple tensor_tuple) {\n  py::tuple result(tensor_tuple.size());\n  for (size_t i = 0; i < tensor_tuple.size(); ++i) {\n    result[i] = tensor_to_numpy(std::move(tensor_tuple[i]));\n  }\n  return result;\n}\n\nPYBIND11_MODULE(_lczero_training, m) {\n  absl::InitializeLog();\n  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);\n  m.doc() = \"Leela Chess Zero training data loader\";\n\n  // Configuration is now handled via protobuf serialized strings\n\n  // Expose the main DataLoader class.\n  py::class_<DataLoader>(m, \"DataLoader\")\n      .def(py::init([](py::object config) {\n             std::string config_string =\n                 SerializePyProto(config, \"DataLoaderConfig\");\n             py::gil_scoped_release release;\n             return new DataLoader(config_string);\n           }),\n           py::arg(\"config\"),\n           \"Create DataLoader from DataLoaderConfig proto or bytes.\")\n      .def(\n          \"add_stages\",\n          [](DataLoader& self, py::object config) {\n            std::string config_string =\n                SerializePyProto(config, \"DataLoaderConfig\");\n            py::gil_scoped_release release;\n            self.AddStages(config_string);\n          },\n          py::arg(\"config\"),\n          \"Append stages from DataLoaderConfig proto or bytes.\")\n      .def(\n          
\"send_control_message\",\n          [](DataLoader& self, py::object request) {\n            StageControlRequest control_request =\n                ParsePyProto<StageControlRequest>(request,\n                                                  \"StageControlRequest\");\n            auto responses = [&]() {\n              py::gil_scoped_release release;\n              return self.SendControlMessage(control_request);\n            }();\n            py::list result;\n            for (const auto& [stage_name, response] : responses) {\n              py::object response_obj = MakePythonProto(\n                  \"proto.stage_control_pb2\", \"StageControlResponse\",\n                  response.OutputAsString());\n              result.append(py::make_tuple(stage_name, response_obj));\n            }\n            return result;\n          },\n          py::arg(\"request\"),\n          \"Send StageControlRequest to stages and return (stage_name, \"\n          \"StageControlResponse) tuples.\")\n      .def(\n          \"get_next\",\n          [](DataLoader& self, const std::string& alias) {\n            return tensor_tuple_to_numpy_tuple([&] {\n              py::gil_scoped_release release;\n              return self.GetNext(alias);\n            }());\n          },\n          py::arg(\"alias\") = \"\",\n          \"Get next batch for the given output alias (default empty) as a \"\n          \"tuple of numpy arrays\")\n      .def(\n          \"maybe_get_next\",\n          [](DataLoader& self,\n             const std::string& alias) -> std::optional<py::tuple> {\n            auto result = [&] {\n              py::gil_scoped_release release;\n              return self.MaybeGetNext(alias);\n            }();\n            if (result.has_value()) {\n              return tensor_tuple_to_numpy_tuple(std::move(*result));\n            }\n            return std::nullopt;\n          },\n          py::arg(\"alias\") = \"\",\n          \"Non-blocking get next batch for the given output alias 
(default \"\n          \"empty). Returns tuple of numpy arrays or None if no data available\")\n      .def(\n          \"get_bucket_metrics\",\n          [](const DataLoader& self, int time_period, bool include_pending) {\n            auto [metrics, duration] = [&] {\n              py::gil_scoped_release release;\n              return self.GetBucketMetrics(time_period, include_pending);\n            }();\n            return py::make_tuple(py::bytes(metrics), duration);\n          },\n          \"Get serialized metrics for bucket and duration as (bytes, float)\")\n      .def(\n          \"get_aggregate_ending_now\",\n          [](const DataLoader& self, float duration_seconds,\n             bool include_pending) {\n            auto [metrics, duration] = [&] {\n              py::gil_scoped_release release;\n              return self.GetAggregateEndingNow(duration_seconds,\n                                                include_pending);\n            }();\n            return py::make_tuple(py::bytes(metrics), duration);\n          },\n          \"Get serialized metrics for aggregate duration and actual duration \"\n          \"as (bytes, float)\")\n      .def(\"start\", &DataLoader::Start, \"Start the data loader processing\")\n      .def(\"stop\", &DataLoader::Stop, \"Stop the data loader\");\n\n  // Expose TensorBase for potential advanced usage.\n  py::class_<TensorBase>(m, \"TensorBase\")\n      .def(\"shape\", &TensorBase::shape, py::return_value_policy::reference)\n      .def(\"strides\", &TensorBase::strides, py::return_value_policy::reference)\n      .def(\"element_size\", &TensorBase::element_size)\n      .def(\"py_format\", &TensorBase::py_format);\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_rescorer.cc",
    "content": "#include \"loader/stages/chunk_rescorer.h\"\n\n#include <stdexcept>\n#include <utility>\n\n#include \"absl/base/call_once.h\"\n#include \"absl/log/log.h\"\n#include \"chess/board.h\"\n#include \"loader/data_loader_metrics.h\"\n\nnamespace lczero {\nnamespace training {\nnamespace {\n\nvoid V6ToV7(std::span<FrameType> data, float theta = 5.0f / 6.0f) {\n  if (data.empty()) return;\n\n  const float beta = 1.0f - theta;\n\n  float st_q = 0.0f;\n  float st_d = 0.0f;\n\n  // Iterate backwards to calculate EMA and lookaheads in one go.\n  for (size_t i = data.size(); i-- > 0;) {\n    FrameType& item = data[i];\n\n    // For Q, we operate in an alternating sign domain.\n    const float sign = (i % 2 != 0) ? -1.0f : 1.0f;\n    const float cur_q = item.root_q * sign;\n\n    if (i == data.size() - 1) {\n      st_q = cur_q;\n      st_d = item.root_d;\n    } else {\n      st_q = theta * st_q + beta * cur_q;\n      st_d = theta * st_d + beta * item.root_d;\n    }\n\n    item.version = 7;\n    item.q_st = st_q * sign;\n    item.d_st = std::max(st_d, 0.0f);\n\n    // Handle lookahead indices safely.\n    auto get_idx = [&](size_t offset) {\n      return (i + offset < data.size()) ? 
data[i + offset].played_idx : 65535;\n    };\n    item.opp_played_idx = get_idx(1);\n    item.next_played_idx = get_idx(2);\n  }\n}\n\n}  // namespace\n\nChunkRescorer::ChunkRescorer(const ChunkRescorerConfig& config)\n    : SingleInputStage<ChunkRescorerConfig, InputType>(config),\n      SingleOutputStage<OutputType>(config.output()),\n      syzygy_paths_(config.syzygy_paths()),\n      dist_temp_(config.dist_temp()),\n      dist_offset_(config.dist_offset()),\n      dtz_boost_(config.dtz_boost()),\n      new_input_format_(config.new_input_format()),\n      thread_pool_(config.threads(), ThreadPoolOptions{}),\n      st_q_theta_(config.st_q_theta()) {\n  static absl::once_flag bitboards_initialized_flag;\n  absl::call_once(bitboards_initialized_flag, InitializeMagicBitboards);\n\n  if (config.has_deblunder_threshold() && config.has_deblunder_width()) {\n    RescorerDeblunderSetup(config.deblunder_threshold(),\n                           config.deblunder_width());\n  }\n\n  if (config.has_gaviota_paths()) {\n    RescorerGaviotaSetup(std::string(config.gaviota_paths()));\n  }\n\n  LOG(INFO) << \"Initializing ChunkRescorer with \" << config.threads()\n            << \" worker thread(s)\";\n\n  thread_contexts_.reserve(config.threads());\n  for (size_t i = 0; i < config.threads(); ++i) {\n    thread_contexts_.push_back(std::make_unique<ThreadContext>());\n  }\n}\n\nChunkRescorer::~ChunkRescorer() { Stop(); }\n\nvoid ChunkRescorer::InitializeTablebase() {\n  if (tablebase_initialized_) {\n    return;\n  }\n\n  LOG(INFO) << \"ChunkRescorer initializing Syzygy tablebase with paths '\"\n            << syzygy_paths_ << \"'.\";\n  tablebase_initialized_ = tablebase_.init(syzygy_paths_);\n  if (tablebase_initialized_) {\n    LOG(INFO) << \"ChunkRescorer Syzygy max cardinality: \"\n              << tablebase_.max_cardinality();\n  } else {\n    LOG(WARNING) << \"ChunkRescorer failed to initialize Syzygy tablebase; \"\n                    \"rescoring will continue without 
tablebase lookups.\";\n  }\n}\n\nvoid ChunkRescorer::Start() {\n  LOG(INFO) << \"Starting ChunkRescorer worker threads.\";\n  InitializeTablebase();\n  for (size_t i = 0; i < thread_contexts_.size(); ++i) {\n    thread_pool_.Enqueue([this, i](std::stop_token stop_token) {\n      Worker(stop_token, thread_contexts_[i].get());\n    });\n  }\n}\n\nvoid ChunkRescorer::Stop() {\n  if (thread_pool_.stop_token().stop_requested()) return;\n\n  LOG(INFO) << \"Stopping ChunkRescorer.\";\n  thread_pool_.Shutdown();\n  output_queue()->Close();\n  LOG(INFO) << \"ChunkRescorer stopped.\";\n}\n\nvoid ChunkRescorer::Worker(std::stop_token stop_token, ThreadContext* context) {\n  auto producer = output_queue()->CreateProducer();\n\n  try {\n    while (true) {\n      TrainingChunk chunk = [&]() {\n        LoadMetricPauser pauser(context->load_metric_updater);\n        return input_queue()->Get(stop_token);\n      }();\n\n      try {\n        chunk.frames = RescoreTrainingData<FrameType>(\n            chunk.frames, &tablebase_, dist_temp_, dist_offset_, dtz_boost_,\n            new_input_format_);\n        if (chunk.frames.at(0).version == 6) V6ToV7(chunk.frames, st_q_theta_);\n        LoadMetricPauser pauser(context->load_metric_updater);\n        producer.Put(std::move(chunk), stop_token);\n      } catch (const std::exception& exception) {\n        failed_rescores_.fetch_add(1, std::memory_order_acq_rel);\n        LOG(ERROR) << \"ChunkRescorer failed to rescore chunk: \"\n                   << exception.what() << \"; sort_key=\" << chunk.sort_key\n                   << \"; index_within_sort_key=\" << chunk.index_within_sort_key\n                   << \"; global_index=\" << chunk.global_index\n                   << \"; use_count=\" << chunk.use_count\n                   << \"; frame_count=\" << chunk.frames.size();\n        continue;\n      }\n    }\n  } catch (const QueueClosedException&) {\n    LOG(INFO) << \"ChunkRescorer worker stopping, queue closed.\";\n  } catch (const 
QueueRequestCancelled&) {\n    LOG(INFO) << \"ChunkRescorer worker stopping, request cancelled.\";\n  }\n}\n\nStageMetricProto ChunkRescorer::FlushMetrics() {\n  StageMetricProto stage_metric;\n  LoadMetricProto aggregated_load;\n  aggregated_load.set_name(\"load\");\n  for (const auto& context : thread_contexts_) {\n    UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());\n  }\n  *stage_metric.add_load_metrics() = std::move(aggregated_load);\n\n  auto* failed = stage_metric.add_count_metrics();\n  failed->set_name(\"failed_rescores\");\n  failed->set_count(failed_rescores_.exchange(0, std::memory_order_acq_rel));\n\n  *stage_metric.add_queue_metrics() =\n      MetricsFromQueue(\"output\", *output_queue());\n  return stage_metric;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_rescorer.h",
    "content": "// ABOUTME: Stage that rescales training chunks using Syzygy tablebases.\n// ABOUTME: Adjusts frame metadata by invoking the classic LCZero rescorer.\n#pragma once\n\n#include <atomic>\n#include <functional>\n#include <memory>\n#include <stop_token>\n#include <string>\n#include <vector>\n\n#include \"libs/lc0/src/syzygy/syzygy.h\"\n#include \"libs/lc0/src/trainingdata/rescorer.h\"\n#include \"loader/frame_type.h\"\n#include \"loader/stages/stage.h\"\n#include \"loader/stages/training_chunk.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/metrics/load_metric.h\"\n#include \"utils/queue.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Stage that takes TrainingChunk objects, applies tablebase-based rescoring and\n// forwards the updated chunks downstream.\nclass ChunkRescorer\n    : public SingleInputStage<ChunkRescorerConfig, TrainingChunk>,\n      public SingleOutputStage<TrainingChunk> {\n public:\n  using InputType = TrainingChunk;\n  using OutputType = TrainingChunk;\n\n  explicit ChunkRescorer(const ChunkRescorerConfig& config);\n  ~ChunkRescorer() override;\n\n  void Start() override;\n  void Stop() override;\n  StageMetricProto FlushMetrics() override;\n\n private:\n  struct ThreadContext {\n    LoadMetricUpdater load_metric_updater;\n  };\n\n  void Worker(std::stop_token stop_token, ThreadContext* context);\n  void InitializeTablebase();\n\n  SyzygyTablebase tablebase_;\n  bool tablebase_initialized_ = false;\n  std::string syzygy_paths_;\n  float dist_temp_;\n  float dist_offset_;\n  float dtz_boost_;\n  int new_input_format_;\n\n  ThreadPool thread_pool_;\n  std::vector<std::unique_ptr<ThreadContext>> thread_contexts_;\n  std::atomic<uint64_t> failed_rescores_{0};\n  float st_q_theta_;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_rescorer_test.cc",
    "content": "#include \"loader/stages/chunk_rescorer.h\"\n\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"gtest/gtest.h\"\n#include \"loader/stages/training_chunk.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"utils/queue.h\"\n\nnamespace lczero {\nnamespace training {\n\nnamespace {\n\ntemplate <typename T>\nclass PassthroughStage : public Stage {\n public:\n  explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}\n\n  void Start() override {}\n  void Stop() override {}\n  StageMetricProto FlushMetrics() override { return StageMetricProto(); }\n  QueueBase* GetOutput(std::string_view name = \"\") override {\n    (void)name;\n    return queue_;\n  }\n  void SetInputs(absl::Span<QueueBase* const> inputs) override {\n    if (!inputs.empty()) {\n      throw std::runtime_error(\"PassthroughStage expects no inputs\");\n    }\n  }\n\n private:\n  Queue<T>* queue_;\n};\n\n}  // namespace\n\nclass ChunkRescorerTest : public ::testing::Test {\n protected:\n  void SetUp() override {\n    input_queue_ = std::make_unique<Queue<TrainingChunk>>(10);\n    config_.set_threads(1);\n    config_.mutable_output()->set_queue_capacity(10);\n    config_.set_syzygy_paths(\"\");\n    config_.set_dist_temp(0.75f);\n    config_.set_dist_offset(0.1f);\n    config_.set_dtz_boost(0.2f);\n    config_.set_new_input_format(-1);\n  }\n\n  TrainingChunk MakeChunk(std::vector<FrameType> frames,\n                          std::string sort_key = \"alpha\", size_t index = 3,\n                          uint32_t use = 7) {\n    TrainingChunk chunk;\n    chunk.sort_key = std::move(sort_key);\n    chunk.index_within_sort_key = index;\n    chunk.use_count = use;\n    chunk.frames = std::move(frames);\n    return chunk;\n  }\n\n  std::unique_ptr<Queue<TrainingChunk>> input_queue_;\n  ChunkRescorerConfig config_;\n};\n\nTEST_F(ChunkRescorerTest, HandlesInputQueueClosure) {\n  ChunkRescorer rescorer(config_);\n  rescorer.SetInputs({input_queue_.get()});\n  
rescorer.Start();\n\n  input_queue_->Close();\n\n  EXPECT_THROW(rescorer.output_queue()->Get(), QueueClosedException);\n\n  rescorer.Stop();\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_source_loader.cc",
    "content": "#include \"loader/stages/chunk_source_loader.h\"\n\n#include <filesystem>\n#include <utility>\n\n#include \"absl/log/log.h\"\n#include \"loader/chunk_source/rawfile_chunk_source.h\"\n#include \"loader/chunk_source/tar_chunk_source.h\"\n#include \"loader/data_loader_metrics.h\"\n#include \"proto/data_loader_config.pb.h\"\n\nnamespace lczero {\nnamespace training {\n\nstd::unique_ptr<ChunkSource> CreateChunkSourceFromFile(\n    const std::filesystem::path& filepath,\n    ChunkSourceLoaderConfig::FrameFormat frame_format) {\n  auto extension = filepath.extension();\n  try {\n    if (extension == \".gz\") {\n      return std::make_unique<RawFileChunkSource>(filepath, frame_format);\n    }\n    if (extension == \".tar\") {\n      return std::make_unique<TarChunkSource>(filepath, frame_format);\n    }\n  } catch (const std::exception& e) {\n    LOG(ERROR) << \"Failed to create chunk source for \" << filepath << \": \"\n               << e.what();\n    return nullptr;\n  }\n  return nullptr;\n}\n\nChunkSourceLoader::ChunkSourceLoader(const ChunkSourceLoaderConfig& config)\n    : SingleInputStage<ChunkSourceLoaderConfig, InputType>(config),\n      SingleOutputStage<OutputType>(config.output()),\n      thread_pool_(config.threads(), ThreadPoolOptions{}),\n      frame_format_(config.frame_format()) {\n  LOG(INFO) << \"Initializing ChunkSourceLoader with \" << config.threads()\n            << \" worker threads\";\n\n  // Initialize thread contexts but don't start worker threads yet.\n  thread_contexts_.reserve(config.threads());\n  for (size_t i = 0; i < config.threads(); ++i) {\n    thread_contexts_.push_back(std::make_unique<ThreadContext>());\n  }\n}\n\nChunkSourceLoader::~ChunkSourceLoader() { Stop(); }\n\nvoid ChunkSourceLoader::Start() {\n  LOG(INFO) << \"Starting ChunkSourceLoader worker threads.\";\n  for (size_t i = 0; i < thread_contexts_.size(); ++i) {\n    thread_pool_.Enqueue([this, i](std::stop_token stop_token) {\n      Worker(stop_token, 
thread_contexts_[i].get());\n    });\n  }\n}\n\nvoid ChunkSourceLoader::Stop() {\n  if (thread_pool_.stop_token().stop_requested()) return;\n\n  LOG(INFO) << \"Stopping ChunkSourceLoader.\";\n  thread_pool_.Shutdown();\n  output_queue()->Close();\n  LOG(INFO) << \"ChunkSourceLoader stopped.\";\n}\n\nvoid ChunkSourceLoader::Worker(std::stop_token stop_token,\n                               ThreadContext* context) {\n  auto producer = output_queue()->CreateProducer();\n  LOG(INFO) << \"ChunkSourceLoader worker@\" << static_cast<const void*>(context)\n            << \" started.\";\n\n  try {\n    while (true) {\n      auto file = [&]() {\n        LoadMetricPauser pauser(context->load_metric_updater);\n        return input_queue()->Get(stop_token);\n      }();\n\n      if (file.message_type ==\n          FilePathProvider::MessageType::kInitialScanComplete) {\n        LOG(INFO)\n            << \"ChunkSourceLoader received initial scan completion marker.\";\n\n        bool should_forward;\n        {\n          absl::MutexLock lock(&phase_mutex_);\n          sentinel_received_ = true;\n          should_forward = (pre_sentinel_work_count_ == 0);\n        }\n\n        if (should_forward) {\n          LOG(INFO) << \"ChunkSourceLoader forwarding initial scan completion \"\n                       \"marker.\";\n          producer.Put({.source = nullptr, .message_type = file.message_type},\n                       stop_token);\n        }\n        continue;\n      }\n\n      // Track pre-sentinel work.\n      bool is_pre_sentinel;\n      {\n        absl::MutexLock lock(&phase_mutex_);\n        is_pre_sentinel = !sentinel_received_;\n        if (is_pre_sentinel) pre_sentinel_work_count_++;\n      }\n\n      // Create ChunkSource from the file.\n      LOG_EVERY_N(INFO, 1000)\n          << \"ChunkSourceLoader preparing chunk source for \" << file.filepath;\n      auto source = CreateChunkSourceFromFile(file.filepath, frame_format_);\n      if (source) {\n        {\n          
absl::MutexLock lock(&last_chunk_key_mutex_);\n          last_chunk_key_ = source->GetChunkSortKey();\n        }\n        ChunkSourceWithPhase output{.source = std::move(source),\n                                    .message_type = file.message_type};\n        LoadMetricPauser pauser(context->load_metric_updater);\n        producer.Put(std::move(output), stop_token);\n      } else {\n        LOG_EVERY_N(INFO, 100)\n            << \"ChunkSourceLoader skipping unsupported file: \" << file.filepath;\n        skipped_files_count_++;\n      }\n\n      // Complete pre-sentinel work tracking.\n      if (is_pre_sentinel) {\n        absl::MutexLock lock(&phase_mutex_);\n        if (--pre_sentinel_work_count_ == 0 && sentinel_received_) {\n          LOG(INFO) << \"ChunkSourceLoader forwarding initial scan completion \"\n                       \"marker after all pre-sentinel work completed.\";\n          producer.Put(\n              {.source = nullptr,\n               .message_type =\n                   FilePathProvider::MessageType::kInitialScanComplete},\n              stop_token);\n        }\n      }\n    }\n  } catch (const QueueClosedException&) {\n    LOG(INFO) << \"ChunkSourceLoader worker@\"\n              << static_cast<const void*>(context)\n              << \" stopping, queue closed.\";\n  } catch (const QueueRequestCancelled&) {\n    LOG(INFO) << \"ChunkSourceLoader worker@\"\n              << static_cast<const void*>(context)\n              << \" stopping, request cancelled.\";\n  } catch (const std::exception& e) {\n    LOG(ERROR) << \"ChunkSourceLoader worker@\"\n               << static_cast<const void*>(context)\n               << \" exiting due to exception: \" << e.what();\n    throw;\n  }\n\n  LOG(INFO) << \"ChunkSourceLoader worker@\" << static_cast<const void*>(context)\n            << \" exiting loop.\";\n}\n\nStageMetricProto ChunkSourceLoader::FlushMetrics() {\n  StageMetricProto stage_metric;\n  LoadMetricProto aggregated_load;\n  
aggregated_load.set_name(\"load\");\n  for (const auto& context : thread_contexts_) {\n    UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());\n  }\n  *stage_metric.add_load_metrics() = std::move(aggregated_load);\n\n  auto* skipped_metric = stage_metric.add_count_metrics();\n  skipped_metric->set_name(\"skipped_files\");\n  skipped_metric->set_count(skipped_files_count_.exchange(0));\n\n  // Get the last chunk key.\n  {\n    absl::MutexLock lock(&last_chunk_key_mutex_);\n    if (!last_chunk_key_.empty()) {\n      stage_metric.set_last_chunk_key(last_chunk_key_);\n    }\n  }\n\n  *stage_metric.add_queue_metrics() =\n      MetricsFromQueue(\"output\", *output_queue());\n  return stage_metric;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_source_loader.h",
    "content": "#pragma once\n\n#include <atomic>\n#include <filesystem>\n#include <memory>\n#include <stop_token>\n\n#include \"absl/base/thread_annotations.h\"\n#include \"absl/synchronization/mutex.h\"\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"loader/stages/file_path_provider.h\"\n#include \"loader/stages/stage.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/metrics/load_metric.h\"\n#include \"utils/queue.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Creates a ChunkSource based on file extension. Returns RawFileChunkSource for\n// .gz files, TarChunkSource for .tar files, or nullptr for unsupported types.\nstd::unique_ptr<ChunkSource> CreateChunkSourceFromFile(\n    const std::filesystem::path& filepath,\n    ChunkSourceLoaderConfig::FrameFormat frame_format);\n\nstruct ChunkSourceWithPhase {\n  std::unique_ptr<ChunkSource> source;\n  FilePathProvider::MessageType message_type;\n};\n\n// Worker pool that converts FilePathProvider output to ChunkSource objects.\n// Takes FilePathProvider::File as input and outputs ChunkSourceWithPhase.\nclass ChunkSourceLoader\n    : public SingleInputStage<ChunkSourceLoaderConfig, FilePathProvider::File>,\n      public SingleOutputStage<ChunkSourceWithPhase> {\n public:\n  using InputType = FilePathProvider::File;\n  using OutputType = ChunkSourceWithPhase;\n\n  explicit ChunkSourceLoader(const ChunkSourceLoaderConfig& config);\n  ~ChunkSourceLoader();\n\n  void Start() override;\n  void Stop() override;\n\n  StageMetricProto FlushMetrics() override;\n\n private:\n  struct ThreadContext {\n    LoadMetricUpdater load_metric_updater;\n  };\n\n  void Worker(std::stop_token stop_token, ThreadContext* context);\n  ThreadPool thread_pool_;\n  std::vector<std::unique_ptr<ThreadContext>> thread_contexts_;\n  std::atomic<uint64_t> skipped_files_count_{0};\n  absl::Mutex last_chunk_key_mutex_;\n  std::string 
last_chunk_key_;\n  ChunkSourceLoaderConfig::FrameFormat frame_format_;\n\n  // Synchronization for sentinel barrier.\n  absl::Mutex phase_mutex_;\n  int pre_sentinel_work_count_ ABSL_GUARDED_BY(phase_mutex_) = 0;\n  bool sentinel_received_ ABSL_GUARDED_BY(phase_mutex_) = false;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_source_loader_test.cc",
    "content": "#include \"loader/stages/chunk_source_loader.h\"\n\n#include <gtest/gtest.h>\n\n#include <filesystem>\n\n#include \"loader/stages/file_path_provider.h\"\n#include \"utils/queue.h\"\n\nnamespace lczero {\nnamespace training {\n\nnamespace {\n\ntemplate <typename T>\nclass PassthroughStage : public Stage {\n public:\n  explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}\n\n  void Start() override {}\n  void Stop() override {}\n  StageMetricProto FlushMetrics() override { return StageMetricProto(); }\n  QueueBase* GetOutput(std::string_view name = \"\") override {\n    (void)name;\n    return queue_;\n  }\n  void SetInputs(absl::Span<QueueBase* const> inputs) override {\n    if (!inputs.empty()) {\n      throw std::runtime_error(\"PassthroughStage expects no inputs\");\n    }\n  }\n\n private:\n  Queue<T>* queue_;\n};\n\n}  // namespace\n\nTEST(ChunkSourceLoaderTest, ProcessesFiles) {\n  Queue<FilePathProvider::File> input_queue(10);\n  ChunkSourceLoaderConfig config;\n  config.set_threads(1);\n  config.mutable_output()->set_queue_capacity(10);\n  ChunkSourceLoader feed(config);\n  feed.SetInputs({&input_queue});\n  feed.Start();\n\n  {\n    auto producer = input_queue.CreateProducer();\n    // Add a file with unsupported extension (should not create ChunkSource)\n    producer.Put(FilePathProvider::File{\n        .filepath =\n            std::filesystem::path(\"/test.txt\"),  // unsupported extension\n        .message_type = FilePathProvider::MessageType::kFile});\n  }  // Producer destroyed here, closing input queue\n\n  // Try to get output - there should be no valid ChunkSources for unsupported\n  // files\n  try {\n    while (true) {\n      auto output = feed.output_queue()->Get();\n      // If we get output, it means a ChunkSource was created, which shouldn't\n      // happen for unsupported files\n      FAIL() << \"Expected no output for unsupported file extension\";\n    }\n  } catch (const QueueClosedException&) {\n    // Expected: 
queue should be closed when input is done and no output\n    // produced\n    SUCCEED();\n  }\n}\n\nTEST(ChunkSourceLoaderTest, HandlesPhases) {\n  Queue<FilePathProvider::File> input_queue(10);\n  ChunkSourceLoaderConfig config;\n  config.set_threads(1);\n  config.mutable_output()->set_queue_capacity(10);\n  ChunkSourceLoader feed(config);\n  feed.SetInputs({&input_queue});\n  feed.Start();\n\n  {\n    auto producer = input_queue.CreateProducer();\n    // Test different phases - all should be passed through even if no\n    // ChunkSource is created\n    producer.Put(FilePathProvider::File{\n        .filepath = std::filesystem::path(\"/test1.gz\"),\n        .message_type = FilePathProvider::MessageType::kFile});\n\n    producer.Put(FilePathProvider::File{\n        .filepath = std::filesystem::path(\"/test2.gz\"),\n        .message_type = FilePathProvider::MessageType::kFile});\n  }  // Producer destroyed here, closing input queue\n\n  // Queue should eventually close when input is done\n  try {\n    while (true) {\n      feed.output_queue()->Get();\n    }\n  } catch (const QueueClosedException&) {\n    SUCCEED();\n  }\n}\n\nTEST(ChunkSourceLoaderTest, PassesThroughInitialScanComplete) {\n  Queue<FilePathProvider::File> input_queue(10);\n  ChunkSourceLoaderConfig config;\n  config.set_threads(1);\n  config.mutable_output()->set_queue_capacity(10);\n  ChunkSourceLoader feed(config);\n  feed.SetInputs({&input_queue});\n  feed.Start();\n\n  {\n    auto producer = input_queue.CreateProducer();\n    producer.Put(FilePathProvider::File{\n        .filepath = std::filesystem::path(\"\"),\n        .message_type = FilePathProvider::MessageType::kInitialScanComplete});\n  }  // Producer destroyed here, closing input queue\n\n  // Should get kInitialScanComplete in output with null ChunkSource\n  auto output = feed.output_queue()->Get();\n  EXPECT_EQ(output.message_type,\n            FilePathProvider::MessageType::kInitialScanComplete);\n  EXPECT_EQ(output.source, nullptr);\n\n 
 // Queue should be closed after the single message\n  try {\n    feed.output_queue()->Get();\n    FAIL() << \"Expected queue to be closed\";\n  } catch (const QueueClosedException&) {\n    SUCCEED();\n  }\n}\n\nTEST(ChunkSourceLoaderTest, SentinelBarrierWithMultipleThreads) {\n  Queue<FilePathProvider::File> input_queue(100);\n  ChunkSourceLoaderConfig config;\n  config.set_threads(4);\n  config.mutable_output()->set_queue_capacity(100);\n  ChunkSourceLoader feed(config);\n  feed.SetInputs({&input_queue});\n  feed.Start();\n\n  {\n    auto producer = input_queue.CreateProducer();\n    // Add files that will be processed before sentinel.\n    for (int i = 0; i < 20; ++i) {\n      producer.Put(FilePathProvider::File{\n          .filepath =\n              std::filesystem::path(\"/test\" + std::to_string(i) + \".txt\"),\n          .message_type = FilePathProvider::MessageType::kFile});\n    }\n    // Add sentinel.\n    producer.Put(FilePathProvider::File{\n        .filepath = std::filesystem::path(\"\"),\n        .message_type = FilePathProvider::MessageType::kInitialScanComplete});\n    // Add files that arrive after sentinel.\n    for (int i = 20; i < 30; ++i) {\n      producer.Put(FilePathProvider::File{\n          .filepath =\n              std::filesystem::path(\"/test\" + std::to_string(i) + \".txt\"),\n          .message_type = FilePathProvider::MessageType::kFile});\n    }\n  }  // Producer destroyed here, closing input queue\n\n  // Read all outputs and verify sentinel comes after all pre-sentinel files.\n  int files_before_sentinel = 0;\n  int files_after_sentinel = 0;\n  bool sentinel_seen = false;\n\n  try {\n    while (true) {\n      auto output = feed.output_queue()->Get();\n      if (output.message_type ==\n          FilePathProvider::MessageType::kInitialScanComplete) {\n        EXPECT_FALSE(sentinel_seen) << \"Sentinel should appear exactly once\";\n        sentinel_seen = true;\n      } else {\n        if (sentinel_seen) {\n          
files_after_sentinel++;\n        } else {\n          files_before_sentinel++;\n        }\n      }\n    }\n  } catch (const QueueClosedException&) {\n  }\n\n  // Verify sentinel was seen.\n  EXPECT_TRUE(sentinel_seen);\n  // All 20 pre-sentinel files should be before sentinel (unsupported, so 0).\n  EXPECT_EQ(files_before_sentinel, 0);\n  // All 10 post-sentinel files should be after sentinel (unsupported, so 0).\n  EXPECT_EQ(files_after_sentinel, 0);\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_source_splitter.cc",
    "content": "#include \"loader/stages/chunk_source_splitter.h\"\n\n#include <algorithm>\n#include <numeric>\n#include <utility>\n\n#include \"absl/algorithm/container.h\"\n#include \"absl/strings/str_cat.h\"\n#include \"loader/data_loader_metrics.h\"\n\nnamespace lczero {\nnamespace training {\n\nChunkSourceSplitter::ChunkSourceSplitter(\n    const ChunkSourceSplitterConfig& config)\n    : SingleInputStage<ChunkSourceSplitterConfig, InputType>(config) {\n  if (config.output().empty()) {\n    throw std::runtime_error(\"ChunkSourceSplitter requires at least 1 output.\");\n  }\n\n  // Validate parallel arrays have same size.\n  if (config.output_size() != config.weight_size()) {\n    throw std::runtime_error(absl::StrCat(\n        \"ChunkSourceSplitter output and weight arrays must have same size: \",\n        config.output_size(), \" vs \", config.weight_size()));\n  }\n\n  // Create output queues from parallel arrays.\n  outputs_.reserve(config.output_size());\n  for (size_t i = 0; i < static_cast<size_t>(config.output_size()); ++i) {\n    const auto& queue_cfg = config.output(static_cast<int>(i));\n    const uint64_t weight = i < static_cast<size_t>(config.weight_size())\n                                ? 
config.weight(static_cast<int>(i))\n                                : 1;\n\n    if (absl::c_any_of(outputs_, [&](const auto& existing_out) {\n          return existing_out->name == queue_cfg.name();\n        })) {\n      throw std::runtime_error(std::string(absl::StrCat(\n          \"Duplicate output name in ChunkSourceSplitter: \", queue_cfg.name())));\n    }\n\n    auto* out = outputs_\n                    .emplace_back(std::make_unique<Output>(\n                        queue_cfg.name(), weight, queue_cfg.queue_capacity(),\n                        ToOverflowBehavior(queue_cfg.overflow_behavior())))\n                    .get();\n    LOG(INFO) << \"ChunkSourceSplitter configured output '\" << out->name\n              << \"' weight=\" << out->weight\n              << \" capacity=\" << queue_cfg.queue_capacity();\n  }\n\n  // Precompute cumulative weights for fast assignment.\n  cumulative_.resize(outputs_.size());\n  std::transform_inclusive_scan(\n      outputs_.begin(), outputs_.end(), cumulative_.begin(),\n      std::plus<uint64_t>{},\n      [](const std::unique_ptr<Output>& out) { return out->weight; });\n\n  // Validate total weight is positive.\n  if (cumulative_.back() == 0) {\n    throw std::runtime_error(\n        \"ChunkSourceSplitter requires at least one output with positive \"\n        \"weight.\");\n  }\n}\n\nChunkSourceSplitter::~ChunkSourceSplitter() { Stop(); }\n\nvoid ChunkSourceSplitter::Start() {\n  LOG(INFO) << \"Starting ChunkSourceSplitter worker.\";\n  thread_pool_.Enqueue(\n      [this](std::stop_token stop_token) { Worker(stop_token); });\n}\n\nvoid ChunkSourceSplitter::Stop() {\n  if (thread_pool_.stop_token().stop_requested()) return;\n  LOG(INFO) << \"Stopping ChunkSourceSplitter.\";\n  thread_pool_.Shutdown();\n  for (auto& out : outputs_) out->queue.Close();\n}\n\nQueueBase* ChunkSourceSplitter::GetOutput(std::string_view name) {\n  auto iter = absl::c_find_if(\n      outputs_, [&](const auto& out) { return out->name == name; });\n  if 
(iter == outputs_.end()) {\n    throw std::runtime_error(\n        absl::StrCat(\"Unknown output '\", name, \"' for ChunkSourceSplitter.\"));\n  }\n  return &(*iter)->queue;\n}\n\nStageMetricProto ChunkSourceSplitter::FlushMetrics() {\n  StageMetricProto metric;\n  for (auto& out : outputs_) {\n    *metric.add_queue_metrics() = MetricsFromQueue(out->name, out->queue);\n  }\n  return metric;\n}\n\nvoid ChunkSourceSplitter::Worker(std::stop_token stop_token) {\n  // Create producers for each output in this thread.\n  std::vector<Queue<OutputType>::Producer> producers;\n  producers.reserve(outputs_.size());\n  for (auto& out : outputs_) {\n    producers.emplace_back(out->queue.CreateProducer());\n  }\n\n  try {\n    while (true) {\n      InputType item = input_queue()->Get(stop_token);\n      if (item.message_type ==\n          FilePathProvider::MessageType::kInitialScanComplete) {\n        // Broadcast to all outputs.\n        for (auto& prod : producers) {\n          prod.Put(\n              OutputType{.source = nullptr, .message_type = item.message_type},\n              stop_token);\n        }\n        continue;\n      }\n\n      // Share ownership of the ChunkSource with any produced views.\n      std::shared_ptr<ChunkSource> shared_source(std::move(item.source));\n\n      auto per_output_indices = BuildAssignments(shared_source);\n      // Emit only non-empty views, preserving original message type (kFile).\n      for (size_t i = 0; i < outputs_.size(); ++i) {\n        if (per_output_indices[i].empty()) continue;\n        auto view = std::make_unique<ChunkSourceView>(\n            shared_source, std::move(per_output_indices[i]));\n        producers[i].Put(\n            OutputType{.source = std::move(view),\n                       .message_type = FilePathProvider::MessageType::kFile},\n            stop_token);\n      }\n    }\n  } catch (const QueueClosedException&) {\n    // Input queue closed — producers will close queues automatically when\n    // destroyed if 
this thread holds the last producer.\n    LOG(INFO) << \"ChunkSourceSplitter worker exiting: input closed.\";\n  }\n}\n\nstd::vector<std::vector<uint32_t>> ChunkSourceSplitter::BuildAssignments(\n    const std::shared_ptr<ChunkSource>& source) {\n  const std::string sort_key = source->GetChunkSortKey();\n  const size_t n = source->GetChunkCount();\n\n  // Prepare one index container per output.\n  std::vector<std::vector<uint32_t>> indices(outputs_.size());\n\n  for (size_t i = 0; i < n; ++i) {\n    const uint64_t h =\n        static_cast<uint64_t>(absl::Hash<std::pair<std::string, size_t>>{}(\n            std::make_pair(sort_key, i)));\n    const uint64_t r = h % cumulative_.back();\n    // Find the output where cumulative[j-1] <= r < cumulative[j].\n    const auto it = std::upper_bound(cumulative_.begin(), cumulative_.end(), r);\n    const size_t idx = it - cumulative_.begin();\n    indices[idx].push_back(static_cast<uint32_t>(i));\n  }\n\n  return indices;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_source_splitter.h",
    "content": "#pragma once\n\n#include <memory>\n#include <stop_token>\n#include <string>\n#include <string_view>\n#include <utility>\n#include <vector>\n\n#include \"absl/base/thread_annotations.h\"\n#include \"absl/hash/hash.h\"\n#include \"absl/log/log.h\"\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"loader/chunk_source/chunk_source_view.h\"\n#include \"loader/stages/chunk_source_loader.h\"\n#include \"loader/stages/stage.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/queue.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Splits an incoming ChunkSource into several ChunkSourceViews based on a\n// deterministic hash of (sort_key, index). Emits to multiple named outputs.\nclass ChunkSourceSplitter\n    : public SingleInputStage<ChunkSourceSplitterConfig, ChunkSourceWithPhase> {\n public:\n  using InputType = ChunkSourceWithPhase;\n  using OutputType = ChunkSourceWithPhase;\n\n  explicit ChunkSourceSplitter(const ChunkSourceSplitterConfig& config);\n  ~ChunkSourceSplitter();\n\n  void Start() override;\n  void Stop() override;\n\n  StageMetricProto FlushMetrics() override;\n  QueueBase* GetOutput(std::string_view name = \"\") override;\n\n private:\n  struct Output {\n    std::string name;\n    uint64_t weight;\n    Queue<OutputType> queue;\n    Output(std::string_view name, uint64_t weight, size_t capacity,\n           OverflowBehavior overflow)\n        : name(name), weight(weight), queue(capacity, overflow) {}\n  };\n\n  void Worker(std::stop_token stop_token);\n\n  // Builds per-output indices given a source; uses absl::Hash on\n  // (sort_key, index) and weights to assign indices.\n  std::vector<std::vector<uint32_t>> BuildAssignments(\n      const std::shared_ptr<ChunkSource>& source);\n\n  std::vector<std::unique_ptr<Output>> outputs_;\n  std::vector<uint64_t> cumulative_;\n  ThreadPool thread_pool_{1};\n};\n\n}  // namespace training\n}  // 
namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_source_splitter_test.cc",
    "content": "#include \"loader/stages/chunk_source_splitter.h\"\n\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"absl/hash/hash.h\"\n#include \"gtest/gtest.h\"\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"loader/stages/chunk_source_loader.h\"\n#include \"loader/stages/stage.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"utils/queue.h\"\n\nnamespace lczero {\nnamespace training {\nnamespace {\n\n// Simple fixed-count chunk source for testing.\nclass FixedCountChunkSource : public ChunkSource {\n public:\n  FixedCountChunkSource(std::string sort_key, size_t count)\n      : key_(std::move(sort_key)), count_(count) {}\n\n private:\n  std::string GetChunkSortKey() const override { return key_; }\n  size_t GetChunkCount() const override { return count_; }\n  std::optional<std::vector<FrameType>> GetChunkData(size_t) override {\n    return std::vector<FrameType>{FrameType{}};\n  }\n\n  std::string key_;\n  size_t count_;\n};\n\ntemplate <typename T>\nclass PassthroughStage : public Stage {\n public:\n  explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}\n\n  void Start() override {}\n  void Stop() override {}\n  StageMetricProto FlushMetrics() override { return StageMetricProto(); }\n  QueueBase* GetOutput(std::string_view) override { return queue_; }\n  void SetInputs(absl::Span<QueueBase* const> inputs) override {\n    if (!inputs.empty()) {\n      throw std::runtime_error(\"PassthroughStage expects no inputs\");\n    }\n  }\n\n private:\n  Queue<T>* queue_;\n};\n\n}  // namespace\n\nTEST(ChunkSourceSplitterTest, SplitsByHashAndWeight) {\n  // Upstream queue.\n  auto input_queue = std::make_unique<Queue<ChunkSourceWithPhase>>(8);\n\n  // Configure splitter with two outputs A:1, B:2.\n  ChunkSourceSplitterConfig cfg;\n  auto* outA = cfg.add_output();\n  outA->set_name(\"A\");\n  outA->set_queue_capacity(8);\n  cfg.add_weight(1);\n  auto* outB = cfg.add_output();\n  
outB->set_name(\"B\");\n  outB->set_queue_capacity(8);\n  cfg.add_weight(2);\n\n  ChunkSourceSplitter splitter(cfg);\n  splitter.SetInputs({input_queue.get()});\n\n  splitter.Start();\n\n  // Send a source with known key and count.\n  const std::string key = \"skey\";\n  const size_t count = 100;\n  ChunkSourceWithPhase item;\n  item.source = std::make_unique<FixedCountChunkSource>(key, count);\n  item.message_type = FilePathProvider::MessageType::kFile;\n  auto producer = input_queue->CreateProducer();\n  producer.Put(std::move(item));\n  producer.Close();\n\n  // Compute expected assignment counts using the same hash/weights.\n  const uint64_t total_weight = 3;\n  uint64_t cumA = 1;  // [0]\n  size_t expectedA = 0;\n  size_t expectedB = 0;\n  for (size_t i = 0; i < count; ++i) {\n    const uint64_t h = static_cast<uint64_t>(\n        absl::Hash<std::pair<std::string, size_t>>{}(std::make_pair(key, i)));\n    const uint64_t r = h % total_weight;\n    if (r < cumA)\n      ++expectedA;\n    else\n      ++expectedB;\n  }\n\n  // Read outputs and verify view sizes.\n  auto* qa =\n      dynamic_cast<Queue<ChunkSourceWithPhase>*>(splitter.GetOutput(\"A\"));\n  auto* qb =\n      dynamic_cast<Queue<ChunkSourceWithPhase>*>(splitter.GetOutput(\"B\"));\n\n  ASSERT_NE(qa, nullptr);\n  ASSERT_NE(qb, nullptr);\n\n  auto msgA = qa->Get();\n  auto msgB = qb->Get();\n  ASSERT_NE(msgA.source, nullptr);\n  ASSERT_NE(msgB.source, nullptr);\n  EXPECT_EQ(msgA.source->GetChunkCount(), expectedA);\n  EXPECT_EQ(msgB.source->GetChunkCount(), expectedB);\n\n  splitter.Stop();\n}\n\nTEST(ChunkSourceSplitterTest, BroadcastsInitialScanComplete) {\n  auto input_queue = std::make_unique<Queue<ChunkSourceWithPhase>>(4);\n\n  ChunkSourceSplitterConfig cfg;\n  auto* outA = cfg.add_output();\n  outA->set_name(\"A\");\n  cfg.add_weight(1);\n  auto* outB = cfg.add_output();\n  outB->set_name(\"B\");\n  cfg.add_weight(1);\n\n  ChunkSourceSplitter splitter(cfg);\n  
splitter.SetInputs({input_queue.get()});\n  splitter.Start();\n\n  ChunkSourceWithPhase marker;\n  marker.source = nullptr;\n  marker.message_type = FilePathProvider::MessageType::kInitialScanComplete;\n  auto producer = input_queue->CreateProducer();\n  producer.Put(std::move(marker));\n  producer.Close();\n\n  auto* qa =\n      dynamic_cast<Queue<ChunkSourceWithPhase>*>(splitter.GetOutput(\"A\"));\n  auto* qb =\n      dynamic_cast<Queue<ChunkSourceWithPhase>*>(splitter.GetOutput(\"B\"));\n\n  auto m1 = qa->Get();\n  auto m2 = qb->Get();\n  EXPECT_EQ(m1.message_type,\n            FilePathProvider::MessageType::kInitialScanComplete);\n  EXPECT_EQ(m2.message_type,\n            FilePathProvider::MessageType::kInitialScanComplete);\n  EXPECT_EQ(m1.source, nullptr);\n  EXPECT_EQ(m2.source, nullptr);\n\n  splitter.Stop();\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_unpacker.cc",
    "content": "#include \"loader/stages/chunk_unpacker.h\"\n\n#include <absl/algorithm/container.h>\n#include <absl/container/flat_hash_set.h>\n#include <absl/log/check.h>\n#include <absl/log/log.h>\n#include <absl/random/random.h>\n#include <absl/random/seed_sequences.h>\n\n#include <algorithm>\n#include <cassert>\n#include <cstdint>\n#include <numeric>\n#include <random>\n#include <span>\n#include <utility>\n#include <vector>\n\n#include \"absl/numeric/int128.h\"\n#include \"absl/random/bit_gen_ref.h\"\n#include \"absl/random/random.h\"\n#include \"loader/data_loader_metrics.h\"\n#include \"loader/stages/position_sampling.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Deterministically partitions `n` positions into disjoint subsets of size\n// `~p*n`, returning the subset for a given `iteration`. While each selection\n// individually behaves like a Bernoulli sample with probability `p`, the\n// samples are correlated to ensure all positions are selected exactly once over\n// `1/p` iterations. 
To sample disjoint subsets from the same set of positions,\n// `gen` must be seeded identically for each call with a different `iteration`.\nstd::vector<uint32_t> PickSampledPositions(int32_t n, double p,\n                                           int32_t iteration,\n                                           absl::BitGen& gen) {\n  assert(p > 0.0 && p <= 1.0);\n  double carried_prob = p;\n\n  std::vector<uint32_t> result;\n  absl::flat_hash_set<int32_t> skip_next_round;\n\n  while (true) {\n    int32_t num_this_round = (1.0 - carried_prob) / p + 1;\n    double last_partial_prob = 1 - (carried_prob + (num_this_round - 1) * p);\n    const bool return_this_round = iteration < num_this_round;\n    absl::flat_hash_set<int32_t> skip_this_round(skip_next_round);\n    skip_next_round.clear();\n    for (int32_t i = 0; i < n; ++i) {\n      if (skip_this_round.contains(i)) continue;\n      const double toss = absl::Uniform<double>(gen, 0.0, 1.0);\n      const int32_t value = (toss - carried_prob) / p + 1;\n      if (value == iteration) result.push_back(static_cast<uint32_t>(i));\n      if (value >= num_this_round) {\n        skip_next_round.insert(static_cast<int32_t>(i));\n      }\n    }\n    if (return_this_round) return result;\n    iteration -= num_this_round;\n    carried_prob = p - last_partial_prob;\n  }\n}\n\n// Samples `k` indices from an infinite, deterministically-generated sequence,\n// after an initial `skip`. The sequence is formed by repeatedly shuffling\n// `[0..n-1]` and probabilistically selecting each index. 
Determinism is\n// achieved by the caller seeding `gen` identically across calls; the `skip`\n// offset then selects the next disjoint portion of the same sequence.\nstd::vector<uint32_t> SampleProbabilisticSequence(\n    uint64_t k, uint64_t skip, std::span<const float> probabilities,\n    absl::BitGen& gen) {\n  const size_t n = probabilities.size();\n  if (n == 0 || k == 0) return {};\n\n  uint64_t skipped_so_far = 0;\n  std::vector<uint32_t> v(n);\n  std::iota(v.begin(), v.end(), 0u);\n  std::vector<uint32_t> result;\n  result.reserve(k);\n\n  while (true) {\n    std::shuffle(v.begin(), v.end(), gen);\n    for (const uint32_t candidate : v) {\n      if (!absl::Bernoulli(gen, probabilities[candidate])) continue;\n      if (skipped_so_far < skip) {\n        skipped_so_far++;\n      } else {\n        result.push_back(candidate);\n        if (result.size() == k) return result;\n      }\n    }\n  }\n}\n\nnamespace {\n\nuint32_t GenerateRunSeed() {\n  absl::BitGen gen(absl::MakeSeedSeq());\n  return absl::Uniform<uint32_t>(gen);\n}\n\n}  // namespace\n\nChunkUnpacker::ChunkUnpacker(const ChunkUnpackerConfig& config)\n    : SingleInputStage<ChunkUnpackerConfig, InputType>(config),\n      config_(config),\n      run_seed_(GenerateRunSeed()),\n      primary_output_queue_(\n          config.output().queue_capacity(),\n          ToOverflowBehavior(config.output().overflow_behavior())),\n      thread_pool_(config.threads(), ThreadPoolOptions{}) {\n  const bool has_rate = config.has_position_sampling_rate();\n  const bool has_count = config.has_position_count();\n  const bool has_prefetch_count = config.has_prefetch_count();\n  const bool has_prefetch_output = config.has_prefetch_output();\n\n  CHECK(has_prefetch_count == has_prefetch_output)\n      << \"prefetch_count and prefetch_output must both be set or both unset.\";\n\n  if (has_prefetch_count) {\n    CHECK(has_count) << \"position_count must be set when using prefetch mode.\";\n    CHECK(!has_rate)\n        << \"position_sampling_rate cannot be used in 
prefetch mode.\";\n    CHECK(config.position_count() == 1)\n        << \"position_count must equal 1 in prefetch mode, got \"\n        << config.position_count();\n    prefetch_output_queue_.emplace(\n        config.prefetch_output().queue_capacity(),\n        ToOverflowBehavior(config.prefetch_output().overflow_behavior()));\n    if (config.output().name() == config.prefetch_output().name()) {\n      throw std::runtime_error(\n          absl::StrCat(\"ChunkUnpacker output names must be different, got: '\",\n                       config.output().name(), \"'\"));\n    }\n  } else {\n    CHECK(has_rate != has_count)\n        << \"Exactly one of position_sampling_rate or position_count must be \"\n           \"set.\";\n  }\n\n  LOG(INFO) << \"Initializing ChunkUnpacker with \" << config.threads()\n            << \" worker threads\";\n\n  // Initialize thread contexts but don't start worker threads yet.\n  thread_contexts_.reserve(config.threads());\n  for (size_t i = 0; i < config.threads(); ++i) {\n    thread_contexts_.push_back(std::make_unique<ThreadContext>());\n  }\n}\n\nChunkUnpacker::~ChunkUnpacker() { Stop(); }\n\nvoid ChunkUnpacker::Start() {\n  LOG(INFO) << \"Starting ChunkUnpacker worker threads.\";\n  for (size_t i = 0; i < thread_contexts_.size(); ++i) {\n    thread_pool_.Enqueue([this, i](std::stop_token stop_token) {\n      Worker(stop_token, thread_contexts_[i].get());\n    });\n  }\n}\n\nvoid ChunkUnpacker::Stop() {\n  if (thread_pool_.stop_token().stop_requested()) return;\n\n  LOG(INFO) << \"Stopping ChunkUnpacker.\";\n  thread_pool_.Shutdown();\n  primary_output_queue_.Close();\n  if (prefetch_output_queue_) prefetch_output_queue_->Close();\n  LOG(INFO) << \"ChunkUnpacker stopped.\";\n}\n\nQueueBase* ChunkUnpacker::GetOutput(std::string_view name) {\n  if (name == config_.output().name()) return &primary_output_queue_;\n  if (config_.has_prefetch_output() &&\n      name == config_.prefetch_output().name()) {\n    return &*prefetch_output_queue_;\n 
 }\n  std::string available = absl::StrCat(\"'\", config_.output().name(), \"'\");\n  if (config_.has_prefetch_output()) {\n    absl::StrAppend(&available, \", '\", config_.prefetch_output().name(), \"'\");\n  }\n  throw std::runtime_error(absl::StrCat(\"ChunkUnpacker unknown output '\", name,\n                                        \"'. Available outputs: \", available));\n}\n\nnamespace {\nstd::vector<float> FramesToProbabilities(std::span<const FrameType> frames,\n                                         const PositionSamplingConfig& config) {\n  std::vector<float> probabilities;\n  probabilities.reserve(frames.size());\n  absl::c_transform(frames, std::back_inserter(probabilities),\n                    [&](const FrameType& frame) {\n                      return ComputePositionSamplingWeight(frame, config);\n                    });\n  // Guard: c_max_element on an empty range is undefined behavior.\n  if (probabilities.empty()) return probabilities;\n  const float max_prob = *absl::c_max_element(probabilities);\n  if (max_prob > 0.0f) {\n    absl::c_transform(probabilities, probabilities.begin(),\n                      [max_prob](float p) { return p / max_prob; });\n  } else {\n    absl::c_fill(probabilities, 1.0f);\n  }\n  return probabilities;\n}\n}  // namespace\n\nvoid ChunkUnpacker::Worker(std::stop_token stop_token, ThreadContext* context) {\n  // Create a local producer for this worker thread.\n  auto primary_producer = primary_output_queue_.CreateProducer();\n  std::optional<decltype(prefetch_output_queue_->CreateProducer())>\n      prefetch_producer;\n  if (prefetch_output_queue_.has_value()) {\n    prefetch_producer.emplace(prefetch_output_queue_->CreateProducer());\n  }\n\n  try {\n    while (true) {\n      auto chunk = [&]() {\n        LoadMetricPauser pauser(context->load_metric_updater);\n        return input_queue()->Get(stop_token);\n      }();\n\n      absl::BitGen gen(\n          std::seed_seq{run_seed_, static_cast<uint32_t>(chunk.global_index)});\n      std::vector<uint32_t> positions;\n      if (config_.has_position_sampling_rate()) {\n        positions = 
PickSampledPositions(\n            static_cast<int32_t>(chunk.frames.size()),\n            config_.position_sampling_rate(), chunk.use_count, gen);\n      } else {\n        auto probabilities =\n            FramesToProbabilities(chunk.frames, config_.position_sampling());\n        positions = SampleProbabilisticSequence(\n            config_.position_count() + config_.prefetch_count(),\n            config_.position_count() * chunk.use_count, probabilities, gen);\n      }\n\n      if (config_.has_prefetch_count()) {\n        // Prefetch mode: output first position to primary, rest to prefetch.\n        if (!positions.empty()) {\n          LoadMetricPauser pauser(context->load_metric_updater);\n          primary_producer.Put(std::move(chunk.frames[positions[0]]),\n                               stop_token);\n        }\n        if (positions.size() > 1) {\n          CacheRequest cache_request;\n          cache_request.global_index = chunk.global_index;\n          cache_request.next_use = chunk.use_count + 1;\n          cache_request.items.reserve(positions.size() - 1);\n          for (size_t i = 1; i < positions.size(); ++i) {\n            cache_request.items.push_back(chunk.frames[positions[i]]);\n          }\n          LoadMetricPauser pauser(context->load_metric_updater);\n          prefetch_producer->Put(std::move(cache_request), stop_token);\n        }\n      } else {\n        // Normal mode: output all positions to primary.\n        for (uint32_t pos : positions) {\n          LoadMetricPauser pauser(context->load_metric_updater);\n          primary_producer.Put(std::move(chunk.frames[pos]), stop_token);\n        }\n      }\n    }\n  } catch (const QueueClosedException&) {\n    LOG(INFO) << \"ChunkUnpacker worker stopping, queue closed.\";\n  } catch (const QueueRequestCancelled&) {\n    LOG(INFO) << \"ChunkUnpacker worker stopping, request cancelled.\";\n  }\n}\n\nStageMetricProto ChunkUnpacker::FlushMetrics() {\n  StageMetricProto stage_metric;\n  
LoadMetricProto aggregated_load;\n  aggregated_load.set_name(\"load\");\n  for (const auto& context : thread_contexts_) {\n    UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());\n  }\n  *stage_metric.add_load_metrics() = std::move(aggregated_load);\n  *stage_metric.add_queue_metrics() =\n      MetricsFromQueue(config_.output().name(), primary_output_queue_);\n  if (prefetch_output_queue_.has_value()) {\n    *stage_metric.add_queue_metrics() = MetricsFromQueue(\n        config_.prefetch_output().name(), *prefetch_output_queue_);\n  }\n  return stage_metric;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_unpacker.h",
    "content": "// ABOUTME: Stage that unpacks chunks into FrameType frames.\n// ABOUTME: Converts stream of std::string chunks to FrameType stream.\n#pragma once\n\n#include <atomic>\n#include <cstdint>\n#include <memory>\n#include <stop_token>\n#include <vector>\n\n#include \"absl/random/random.h\"\n#include \"absl/types/optional.h\"\n#include \"loader/data_loader_metrics.h\"\n#include \"loader/frame_type.h\"\n#include \"loader/stages/stage.h\"\n#include \"loader/stages/training_chunk.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/queue.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Worker pool that unpacks chunks into frames.\n// Takes parsed TrainingChunk objects as input and outputs individual\n// FrameType frames.\nclass ChunkUnpacker\n    : public SingleInputStage<ChunkUnpackerConfig, TrainingChunk> {\n public:\n  using InputType = TrainingChunk;\n\n  explicit ChunkUnpacker(const ChunkUnpackerConfig& config);\n  ~ChunkUnpacker();\n\n  void Start() override;\n  void Stop() override;\n  QueueBase* GetOutput(std::string_view name) override;\n  StageMetricProto FlushMetrics() override;\n\n  Queue<FrameType>* output_queue() { return &primary_output_queue_; }\n\n private:\n  struct ThreadContext {\n    LoadMetricUpdater load_metric_updater;\n  };\n\n  void Worker(std::stop_token stop_token, ThreadContext* context);\n\n  const ChunkUnpackerConfig config_;\n  const uint32_t run_seed_;\n  Queue<FrameType> primary_output_queue_;\n  std::optional<Queue<CacheRequest>> prefetch_output_queue_;\n  // thread_contexts_ must be declared before thread_pool_ to ensure\n  // thread_pool_ is destroyed first (stopping threads before contexts).\n  std::vector<std::unique_ptr<ThreadContext>> thread_contexts_;\n  ThreadPool thread_pool_;\n};\n\nstd::vector<uint32_t> PickSampledPositions(int32_t n, double p,\n                                           int32_t iteration,\n              
                             absl::BitGen& gen);\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/chunk_unpacker_test.cc",
    "content": "#include \"loader/stages/chunk_unpacker.h\"\n\n#include <iterator>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"absl/algorithm/container.h\"\n#include \"absl/container/flat_hash_set.h\"\n#include \"absl/random/random.h\"\n#include \"absl/random/seed_sequences.h\"\n#include \"gtest/gtest.h\"\n#include \"libs/lc0/src/trainingdata/trainingdata_v6.h\"\n#include \"loader/stages/training_chunk.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"utils/queue.h\"\n\nnamespace lczero {\nnamespace training {\n\nnamespace {\n\ntemplate <typename T>\nclass PassthroughStage : public Stage {\n public:\n  explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}\n\n  void Start() override {}\n  void Stop() override {}\n  StageMetricProto FlushMetrics() override { return StageMetricProto(); }\n  QueueBase* GetOutput(std::string_view name = \"\") override {\n    (void)name;\n    return queue_;\n  }\n  void SetInputs(absl::Span<QueueBase* const> inputs) override {\n    if (!inputs.empty()) {\n      throw std::runtime_error(\"PassthroughStage expects no inputs\");\n    }\n  }\n\n private:\n  Queue<T>* queue_;\n};\n\n}  // namespace\n\nclass ChunkUnpackerTest : public ::testing::Test {\n protected:\n  void SetUp() override {\n    input_queue_ = std::make_unique<Queue<TrainingChunk>>(10);\n    config_.set_threads(1);\n    config_.mutable_output()->set_queue_capacity(10);\n    config_.set_position_sampling_rate(1.0f);\n  }\n\n  FrameType CreateTestFrame(uint32_t version) {\n    FrameType frame{};\n    frame.version = version;\n    frame.input_format = 3;\n    frame.root_q = 0.5f;\n    return frame;\n  }\n\n  TrainingChunk MakeChunk(std::vector<FrameType> frames,\n                          std::string sort_key = \"source\", size_t index = 0,\n                          uint32_t use = 0) {\n    TrainingChunk chunk;\n    chunk.sort_key = std::move(sort_key);\n    chunk.index_within_sort_key = index;\n    chunk.use_count = use;\n    
chunk.frames = std::move(frames);\n    return chunk;\n  }\n\n  std::unique_ptr<Queue<TrainingChunk>> input_queue_;\n  ChunkUnpackerConfig config_;\n};\n\nTEST_F(ChunkUnpackerTest, UnpacksSingleFrame) {\n  ChunkUnpacker unpacker(config_);\n  unpacker.SetInputs({input_queue_.get()});\n  unpacker.Start();\n\n  FrameType test_frame = CreateTestFrame(6);\n  auto producer = input_queue_->CreateProducer();\n  producer.Put(MakeChunk({test_frame}));\n  producer.Close();\n\n  auto output_frame = unpacker.output_queue()->Get();\n  EXPECT_EQ(output_frame.version, 6);\n  EXPECT_EQ(output_frame.input_format, 3);\n  EXPECT_EQ(output_frame.root_q, 0.5f);\n}\n\nTEST_F(ChunkUnpackerTest, UnpacksMultipleFrames) {\n  ChunkUnpacker unpacker(config_);\n  unpacker.SetInputs({input_queue_.get()});\n  unpacker.Start();\n\n  std::vector<FrameType> test_frames = {CreateTestFrame(6), CreateTestFrame(7),\n                                        CreateTestFrame(8)};\n  auto producer = input_queue_->CreateProducer();\n  producer.Put(MakeChunk(test_frames));\n  producer.Close();\n\n  std::vector<uint32_t> actual_versions;\n  actual_versions.reserve(test_frames.size());\n  for (size_t i = 0; i < test_frames.size(); ++i) {\n    auto output_frame = unpacker.output_queue()->Get();\n    actual_versions.push_back(output_frame.version);\n    EXPECT_EQ(output_frame.input_format, 3);\n    EXPECT_EQ(output_frame.root_q, 0.5f);\n  }\n\n  std::vector<uint32_t> expected_versions;\n  expected_versions.reserve(test_frames.size());\n  for (const auto& frame : test_frames) {\n    expected_versions.push_back(frame.version);\n  }\n\n  absl::c_sort(actual_versions);\n  absl::c_sort(expected_versions);\n  EXPECT_EQ(actual_versions, expected_versions);\n}\n\nTEST_F(ChunkUnpackerTest, UnpacksMultipleChunks) {\n  ChunkUnpacker unpacker(config_);\n  unpacker.SetInputs({input_queue_.get()});\n  unpacker.Start();\n\n  auto producer = input_queue_->CreateProducer();\n\n  // Send first chunk with 2 frames\n  
std::vector<FrameType> chunk1_frames = {CreateTestFrame(10),\n                                          CreateTestFrame(11)};\n  producer.Put(MakeChunk(chunk1_frames, \"source\", 0));\n\n  // Send second chunk with 1 frame\n  std::vector<FrameType> chunk2_frames = {CreateTestFrame(12)};\n  producer.Put(MakeChunk(chunk2_frames, \"source\", 1));\n\n  producer.Close();\n\n  // Verify all frames are output\n  std::vector<uint32_t> expected_versions = {10, 11, 12};\n  std::vector<uint32_t> actual_versions;\n  actual_versions.reserve(expected_versions.size());\n  for (size_t i = 0; i < expected_versions.size(); ++i) {\n    auto output_frame = unpacker.output_queue()->Get();\n    actual_versions.push_back(output_frame.version);\n    EXPECT_EQ(output_frame.input_format, 3);\n    EXPECT_EQ(output_frame.root_q, 0.5f);\n  }\n\n  absl::c_sort(actual_versions);\n  absl::c_sort(expected_versions);\n  EXPECT_EQ(actual_versions, expected_versions);\n}\n\nTEST_F(ChunkUnpackerTest, HandlesEmptyChunk) {\n  ChunkUnpacker unpacker(config_);\n  unpacker.SetInputs({input_queue_.get()});\n  unpacker.Start();\n\n  auto producer = input_queue_->CreateProducer();\n  TrainingChunk empty_chunk;\n  empty_chunk.sort_key = \"source\";\n  empty_chunk.index_within_sort_key = 0;\n  producer.Put(std::move(empty_chunk));\n  producer.Close();\n\n  // Should not produce any output frames, queue should close\n  EXPECT_THROW(unpacker.output_queue()->Get(), QueueClosedException);\n}\n\nTEST_F(ChunkUnpackerTest, HandlesQueueClosure) {\n  ChunkUnpacker unpacker(config_);\n  unpacker.SetInputs({input_queue_.get()});\n  unpacker.Start();\n\n  // Close input queue without sending data\n  input_queue_->Close();\n\n  // Output queue should eventually close\n  EXPECT_THROW(unpacker.output_queue()->Get(), QueueClosedException);\n}\n\nTEST(PickSampledPositionsTest, Deterministic) {\n  absl::BitGen gen1(absl::SeedSeq{42});\n  std::vector<uint32_t> result1 = PickSampledPositions(1000, 0.1, 5, gen1);\n\n  absl::BitGen 
gen2(absl::SeedSeq{42});\n  std::vector<uint32_t> result2 = PickSampledPositions(1000, 0.1, 5, gen2);\n\n  EXPECT_EQ(result1, result2);\n}\n\nTEST(PickSampledPositionsTest, FullBucketFirstRound) {\n  absl::BitGen gen(absl::SeedSeq{42});\n  const uint32_t n = 10000;\n  const double p = 0.1;\n  std::vector<uint32_t> result = PickSampledPositions(n, p, 0, gen);\n  // Expect size to be around n*p.\n  EXPECT_NEAR(result.size(), n * p, n * p * 0.25);\n}\n\nTEST(PickSampledPositionsTest, DisjointBuckets) {\n  absl::BitGen gen(absl::SeedSeq{42});\n  const uint32_t n = 1000;\n  const double p = 0.1;\n\n  std::vector<uint32_t> bucket1 = PickSampledPositions(n, p, 0, gen);\n  absl::c_sort(bucket1);\n\n  // The generator state is now changed. For the next bucket, we need a fresh\n  // one with the same seed to test the logic for a different iteration.\n  absl::BitGen gen2(absl::SeedSeq{42});\n  std::vector<uint32_t> bucket2 = PickSampledPositions(n, p, 1, gen2);\n  absl::c_sort(bucket2);\n\n  std::vector<uint32_t> intersection;\n  absl::c_set_intersection(bucket1, bucket2, std::back_inserter(intersection));\n\n  EXPECT_TRUE(intersection.empty());\n}\n\nTEST(PickSampledPositionsTest, PartialBucketElementsAreReturned) {\n  absl::BitGen gen(absl::SeedSeq{42});\n  const uint32_t n = 1000;\n  const double p = 0.8;  // remainder 0.2\n\n  // In round 1, elements with toss >= 0.8 are for iteration 1.\n  absl::BitGen gen1(absl::SeedSeq{42});\n  std::vector<uint32_t> expected_from_round1;\n  for (uint32_t i = 0; i < n; ++i) {\n    double toss = absl::Uniform<double>(gen1, 0.0, 1.0);\n    if (toss >= 0.8) {\n      expected_from_round1.push_back(i);\n    }\n  }\n  absl::c_sort(expected_from_round1);\n\n  std::vector<uint32_t> result = PickSampledPositions(n, p, 1, gen);\n  absl::c_sort(result);\n\n  // Check that all elements assigned in round 1 appear in the final result.\n  std::vector<uint32_t> intersection;\n  
absl::c_set_intersection(expected_from_round1, result,\n                           std::back_inserter(intersection));\n\n  EXPECT_GT(expected_from_round1.size(), 50);  // High probability for n=1000\n  EXPECT_EQ(intersection.size(), expected_from_round1.size());\n}\n\nTEST(PickSampledPositionsTest, PartialBucketCompletedSize) {\n  absl::BitGen gen(absl::SeedSeq{42});\n  const uint32_t n = 10000;\n  const double p = 0.8;  // remainder 0.2\n\n  std::vector<uint32_t> result = PickSampledPositions(n, p, 1, gen);\n\n  // Expect size to be around n*p.\n  EXPECT_NEAR(result.size(), n * p, n * p * 0.25);\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/file_path_provider.cc",
    "content": "#include \"loader/stages/file_path_provider.h\"\n\n#include <absl/cleanup/cleanup.h>\n#include <absl/container/flat_hash_set.h>\n#include <absl/log/check.h>\n#include <absl/log/log.h>\n#include <absl/synchronization/mutex.h>\n#include <sys/epoll.h>\n#include <unistd.h>\n\n#include <array>\n#include <cerrno>\n#include <chrono>\n#include <cstring>\n#include <filesystem>\n#include <stdexcept>\n#include <string_view>\n#include <thread>\n#include <utility>\n\n#include \"loader/data_loader_metrics.h\"\n#include \"proto/data_loader_config.pb.h\"\n\nnamespace lczero {\nnamespace training {\n\nnamespace {\n\nbool ShouldSkipName(std::string_view name) {\n  return !name.empty() && name.front() == '.';\n}\n\nbool ShouldSkipPathEntry(const FilePathProvider::Path& path) {\n  return ShouldSkipName(path.filename().string());\n}\n\n}  // namespace\n\nFilePathProvider::FilePathProvider(const FilePathProviderConfig& config)\n    : SingleOutputStage<File>(config.output()),\n      directory_(config.directory()),\n      producer_(output_queue()->CreateProducer()),\n      load_metric_updater_() {\n  LOG(INFO) << \"Initializing FilePathProvider for directory: \"\n            << config.directory();\n  inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK);\n  CHECK_NE(inotify_fd_, -1)\n      << \"Failed to initialize inotify: \" << strerror(errno);\n}\n\nFilePathProvider::~FilePathProvider() {\n  LOG(INFO) << \"FilePathProvider shutting down.\";\n  Stop();\n  if (inotify_fd_ != -1) close(inotify_fd_);\n  LOG(INFO) << \"FilePathProvider shutdown complete.\";\n}\n\nvoid FilePathProvider::SetInputs(absl::Span<QueueBase* const> inputs) {\n  if (!inputs.empty()) {\n    throw std::runtime_error(\n        \"FilePathProvider expects no inputs, but received \" +\n        std::to_string(inputs.size()));\n  }\n}\n\nvoid FilePathProvider::Start() {\n  LOG(INFO) << \"Starting FilePathProvider monitoring thread.\";\n  thread_pool_.Enqueue(\n      [this](std::stop_token stop_token) { 
Worker(stop_token); });\n}\n\nvoid FilePathProvider::Stop() {\n  if (stop_source_.stop_requested()) return;\n  LOG(INFO) << \"Stopping FilePathProvider.\";\n  LOG(INFO) << \"Stopping all watches...\";\n  for (const auto& [wd, path] : watch_descriptors_) {\n    inotify_rm_watch(inotify_fd_, wd);\n  }\n  watch_descriptors_.clear();\n  stop_source_.request_stop();\n  thread_pool_.Shutdown();\n  producer_.Close();\n}\n\nStageMetricProto FilePathProvider::FlushMetrics() {\n  StageMetricProto stage_metric;\n  auto load_metrics = load_metric_updater_.FlushMetrics();\n  load_metrics.set_name(\"load\");\n  *stage_metric.add_load_metrics() = std::move(load_metrics);\n  *stage_metric.add_queue_metrics() =\n      MetricsFromQueue(\"output\", *output_queue());\n  return stage_metric;\n}\n\nvoid FilePathProvider::AddDirectory(const Path& directory,\n                                    std::stop_token stop_token) {\n  ScanDirectoryWithWatch(directory, stop_token);\n\n  LOG(INFO) << \"FilePathProvider registered \" << directory\n            << \"; active watch descriptors: \" << watch_descriptors_.size();\n\n  // Signal that initial scan is complete\n  LOG(INFO) << \"FilePathProvider initial scan complete\";\n  producer_.Put(\n      {{.filepath = Path{}, .message_type = MessageType::kInitialScanComplete}},\n      stop_token);\n}\n\nvoid FilePathProvider::ScanDirectoryWithWatch(const Path& directory,\n                                              std::stop_token stop_token) {\n  // Step 1: Set up watch first\n  int wd = inotify_add_watch(inotify_fd_, directory.c_str(),\n                             IN_CLOSE_WRITE | IN_MOVED_TO | IN_CREATE |\n                                 IN_DELETE | IN_DELETE_SELF | IN_MOVE);\n  CHECK_NE(wd, -1) << \"Failed to add inotify watch for \" << directory << \": \"\n                   << strerror(errno);\n  watch_descriptors_[wd] = directory;\n\n  // Step 2: Scan directory non-recursively, remembering files and subdirs\n  std::vector<Path> files;\n  
std::vector<Path> subdirectories;\n  std::error_code ec;\n  auto iterator = std::filesystem::directory_iterator(directory, ec);\n  CHECK(!ec) << \"Failed to iterate directory \" << directory << \": \"\n             << ec.message();\n\n  for (const auto& entry : iterator) {\n    const Path entry_path = entry.path();\n    if (ShouldSkipPathEntry(entry_path)) continue;\n\n    if (entry.is_regular_file(ec) && !ec) {\n      files.push_back(entry_path);\n    } else if (entry.is_directory(ec) && !ec) {\n      subdirectories.push_back(entry_path);\n    }\n  }\n\n  const size_t initial_file_count = files.size();\n  const size_t subdirectory_count = subdirectories.size();\n  LOG(INFO) << \"FilePathProvider scanned \" << directory << \" discovering \"\n            << initial_file_count << \" file(s) and \" << subdirectory_count\n            << \" subdirectory(ies) before watch reconciliation.\";\n\n  // Send notifications for discovered files\n  constexpr size_t kBatchSize = 10000;\n  std::vector<File> batch;\n  batch.reserve(kBatchSize);\n\n  auto flush_batch = [&]() {\n    if (batch.empty()) return;\n    producer_.Put(batch, stop_token);\n    batch.clear();\n  };\n\n  for (const auto& filepath : files) {\n    batch.push_back(\n        {.filepath = filepath.string(), .message_type = MessageType::kFile});\n    if (batch.size() >= kBatchSize) flush_batch();\n  }\n\n  if (initial_file_count > 0) {\n    LOG(INFO) << \"FilePathProvider enqueued \" << initial_file_count\n              << \" file(s) from initial scan of \" << directory;\n  }\n\n  // Step 3: Read from watch descriptor, skipping already discovered files\n  ProcessWatchEventsForNewItems(files);\n\n  // Step 4: Clean the files vector to save memory\n  files.clear();\n\n  // Step 5: Recursively call for subdirectories\n  for (const auto& subdir : subdirectories) {\n    if (stop_token.stop_requested()) return;\n    ScanDirectoryWithWatch(subdir, stop_token);\n  }\n\n  // Flush any remaining files\n  
flush_batch();\n}\n\nvoid FilePathProvider::ProcessWatchEventsForNewItems(\n    const std::vector<Path>& known_files) {\n  // Create a set for fast lookup of already discovered files\n  absl::flat_hash_set<std::string> known_file_set;\n  for (const auto& file : known_files) {\n    known_file_set.insert(file.string());\n  }\n\n  // Process any events that may have occurred during scanning\n  std::array<char, 4096> buffer;\n  std::vector<File> new_files;\n\n  while (true) {\n    ssize_t length = read(inotify_fd_, buffer.data(), buffer.size());\n    if (length <= 0) break;  // No more events to process\n\n    ssize_t offset = 0;\n    while (offset < length) {\n      const struct inotify_event* event =\n          reinterpret_cast<const struct inotify_event*>(buffer.data() + offset);\n\n      const bool skip_entry = event->len > 0 && ShouldSkipName(event->name);\n\n      // Only process file creation/write events, skip already known files\n      if ((event->mask & (IN_CLOSE_WRITE | IN_MOVED_TO)) != 0 &&\n          event->len > 0 && !skip_entry) {\n        const Path directory(watch_descriptors_.at(event->wd));\n        Path filepath = directory / event->name;\n        std::string filepath_string = filepath.string();\n\n        // Only add if we haven't seen this file before\n        if (!known_file_set.contains(filepath_string)) {\n          new_files.push_back({.filepath = std::move(filepath_string),\n                               .message_type = MessageType::kFile});\n        }\n      }\n\n      offset += sizeof(struct inotify_event) + event->len;\n    }\n  }\n\n  // Send notifications for any new files discovered through watch events\n  if (!new_files.empty()) {\n    LOG(INFO) << \"FilePathProvider observed \" << new_files.size()\n              << \" new file(s) while reconciling race events.\";\n    producer_.Put(new_files);\n  }\n}\n\nvoid FilePathProvider::AddWatchRecursive(const Path& path) {\n  // Add watch for current directory\n  int wd = 
inotify_add_watch(inotify_fd_, path.c_str(),\n                             IN_CLOSE_WRITE | IN_MOVED_TO | IN_CREATE |\n                                 IN_DELETE | IN_DELETE_SELF | IN_MOVE);\n  CHECK_NE(wd, -1) << \"Failed to add inotify watch for \" << path << \": \"\n                   << strerror(errno);\n  watch_descriptors_[wd] = path;\n\n  // Recursively add watches for subdirectories\n  std::error_code ec;\n  auto iterator = std::filesystem::directory_iterator(path, ec);\n  CHECK(!ec) << \"Failed to iterate directory \" << path << \": \" << ec.message();\n\n  for (const auto& entry : iterator) {\n    const Path entry_path = entry.path();\n    if (ShouldSkipPathEntry(entry_path)) continue;\n    if (!entry.is_directory(ec) || ec) continue;\n    AddWatchRecursive(entry_path);\n  }\n}\n\nvoid FilePathProvider::RemoveWatchRecursive(const Path& base) {\n  absl::erase_if(watch_descriptors_, [&](const auto& pair) {\n    const auto& [wd, path] = pair;\n    const auto mismatch_iter = absl::c_mismatch(base, path).first;\n    // If path is not a subdirectory (or equal) of base, skip.\n    if (mismatch_iter != base.end()) return false;\n    inotify_rm_watch(inotify_fd_, wd);\n    return true;\n  });\n}\n\nvoid FilePathProvider::Worker(std::stop_token stop_token) {\n  // Perform directory scanning in background thread\n  AddDirectory(directory_, stop_token);\n\n  int epoll_fd = epoll_create1(EPOLL_CLOEXEC);\n  CHECK_NE(epoll_fd, -1) << \"Failed to create epoll fd: \" << strerror(errno);\n  absl::Cleanup epoll_cleanup([epoll_fd]() { close(epoll_fd); });\n\n  struct epoll_event event;\n  event.events = EPOLLIN;\n  event.data.fd = inotify_fd_;\n  CHECK_EQ(epoll_ctl(epoll_fd, EPOLL_CTL_ADD, inotify_fd_, &event), 0)\n      << \"Failed to add inotify fd to epoll: \" << strerror(errno);\n\n  while (!stop_token.stop_requested()) {\n    {\n      LoadMetricPauser pauser(load_metric_updater_);\n      std::this_thread::sleep_for(std::chrono::milliseconds(50));\n      if 
(stop_token.stop_requested()) {\n        pauser.DoNotResume();\n        break;\n      }\n    }\n\n    struct epoll_event event;\n    int nfds = epoll_wait(epoll_fd, &event, 1, 0);  // Non-blocking check\n    CHECK_NE(nfds, -1) << \"epoll_wait failed: \" << strerror(errno);\n    if (nfds == 0) continue;  // No events.\n\n    do {\n      assert(nfds == 1 && event.data.fd == inotify_fd_);\n      ProcessInotifyEvents(producer_, stop_token);\n      nfds = epoll_wait(epoll_fd, &event, 1, 0);\n    } while (nfds > 0);\n  }\n}\n\nvoid FilePathProvider::ProcessInotifyEvents(Queue<File>::Producer& producer,\n                                            std::stop_token stop_token) {\n  constexpr size_t kNotifyBatchSize = 10000;\n  std::vector<File> files;\n  std::array<char, 4096> buffer;\n\n  auto flush_batch = [&]() {\n    if (files.empty()) return;\n    producer.Put(files, stop_token);\n    files.clear();\n  };\n\n  while (true) {\n    ssize_t length = read(inotify_fd_, buffer.data(), buffer.size());\n    if (length <= 0) break;  // No more events to process\n\n    ssize_t offset = 0;\n    while (offset < length) {\n      const struct inotify_event* event =\n          reinterpret_cast<const struct inotify_event*>(buffer.data() + offset);\n      auto file = ProcessInotifyEvent(*event, stop_token);\n      if (file) files.push_back(*file);\n      if (files.size() >= kNotifyBatchSize) flush_batch();\n      offset += sizeof(struct inotify_event) + event->len;\n    }\n  }\n\n  flush_batch();  // Flush any remaining files in the batch\n}\n\nauto FilePathProvider::ProcessInotifyEvent(const struct inotify_event& event,\n                                           std::stop_token stop_token)\n    -> std::optional<File> {\n  if (event.mask & IN_IGNORED) return std::nullopt;\n\n  const Path directory(watch_descriptors_.at(event.wd));\n  const bool has_name = event.len > 0 && event.name[0] != '\\0';\n  const bool skip_entry = has_name && ShouldSkipName(event.name);\n  const Path filepath 
= has_name ? directory / event.name : directory;\n\n  // Handle different event types.\n  if ((event.mask & (IN_CLOSE_WRITE | IN_MOVED_TO)) != 0 && has_name &&\n      !skip_entry) {\n    // File finished writing or moved into the directory.\n    return File{.filepath = filepath, .message_type = MessageType::kFile};\n  }\n\n  constexpr uint32_t kDirCreateMask = IN_CREATE | IN_ISDIR;\n  constexpr uint32_t kDirDeleteMask = IN_DELETE | IN_ISDIR;\n  if ((event.mask & kDirCreateMask) == kDirCreateMask) {\n    if (!has_name || skip_entry) return std::nullopt;\n    ScanDirectoryWithWatch(filepath, stop_token);\n  } else if ((event.mask & kDirDeleteMask) == kDirDeleteMask) {\n    if (!has_name || skip_entry) return std::nullopt;\n    // Directory deleted: remove all watches for it and its subdirectories.\n    RemoveWatchRecursive(filepath);\n  } else if (event.mask & IN_DELETE_SELF) {\n    RemoveWatchRecursive(directory);\n  }\n\n  return std::nullopt;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/file_path_provider.h",
    "content": "#pragma once\n\n#include <absl/base/thread_annotations.h>\n#include <absl/container/flat_hash_map.h>\n#include <absl/log/log.h>\n#include <absl/synchronization/mutex.h>\n#include <sys/inotify.h>\n\n#include <filesystem>\n#include <functional>\n#include <optional>\n#include <span>\n#include <stop_token>\n#include <string>\n#include <vector>\n\n#include \"loader/stages/stage.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/metrics/load_metric.h\"\n#include \"utils/metrics/printer.h\"\n#include \"utils/metrics/statistics_metric.h\"\n#include \"utils/queue.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Message types for FilePathProvider output.\nenum class FilePathProviderMessageType {\n  kFile,                // File discovered (initial scan or inotify).\n  kInitialScanComplete  // Initial scan is complete (empty filepath).\n};\n\n// Output type for FilePathProvider.\nstruct FilePathProviderFile {\n  std::filesystem::path filepath;\n  FilePathProviderMessageType message_type;\n};\n\n// This class watches for new files in a directory (recursively) and notifies\n// registered observers when new files are either closed after writing or\n// renamed into the directory.\n// Uses a background thread to monitor the directory.\nclass FilePathProvider : public SingleOutputStage<FilePathProviderFile> {\n public:\n  using Path = std::filesystem::path;\n  using MessageType = FilePathProviderMessageType;\n  using File = FilePathProviderFile;\n\n  explicit FilePathProvider(const FilePathProviderConfig& config);\n  ~FilePathProvider();\n\n  // Starts monitoring the directory.\n  void Start() override;\n\n  // Closes the output queue, signaling completion.\n  void Stop() override;\n\n  // Returns current metrics and clears them.\n  StageMetricProto FlushMetrics() override;\n\n  // FilePathProvider has no inputs.\n  void SetInputs(absl::Span<QueueBase* const> inputs) override;\n\n private:\n  // Starts 
monitoring the directory.\n  void AddDirectory(const Path& directory, std::stop_token stop_token);\n\n  void Worker(std::stop_token stop_token);\n  void AddWatchRecursive(const Path& path);\n  void RemoveWatchRecursive(const Path& path);\n  void ScanDirectoryWithWatch(const Path& directory,\n                              std::stop_token stop_token);\n  void ProcessWatchEventsForNewItems(const std::vector<Path>& known_files);\n  void ProcessInotifyEvents(Queue<File>::Producer& producer,\n                            std::stop_token stop_token);\n  std::optional<File> ProcessInotifyEvent(const struct inotify_event& event,\n                                          std::stop_token stop_token);\n\n  int inotify_fd_;\n  // Watch descriptor to directory path.\n  absl::flat_hash_map<int, Path> watch_descriptors_;\n\n  Path directory_;  // Directory to monitor.\n  Queue<File>::Producer producer_;\n\n  LoadMetricUpdater load_metric_updater_;\n  std::stop_source stop_source_;\n  ThreadPool thread_pool_{1, ThreadPoolOptions{}, stop_source_};\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/file_path_provider_main.cc",
    "content": "#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h>\n#include <absl/log/log.h>\n\n#include <iostream>\n#include <string>\n#include <thread>\n\n#include \"loader/stages/file_path_provider.h\"\n\nABSL_FLAG(std::string, directory, \"\", \"Directory to monitor for files\");\n\nint main(int argc, char* argv[]) {\n  absl::ParseCommandLine(argc, argv);\n  absl::InitializeLog();\n  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);\n\n  std::string directory = absl::GetFlag(FLAGS_directory);\n  if (directory.empty()) {\n    std::cerr << \"Usage: \" << argv[0] << \" --directory=<directory>\"\n              << std::endl;\n    return 1;\n  }\n\n  LOG(INFO) << \"Starting to monitor directory: \" << directory;\n  lczero::training::FilePathProviderConfig config;\n  config.mutable_output()->set_queue_capacity(16);\n  config.set_directory(directory);\n  lczero::training::FilePathProvider file_path_provider(config);\n  file_path_provider.Start();\n\n  // Consumer thread to read from the queue.\n  std::thread consumer_thread([&file_path_provider]() {\n    auto* queue = file_path_provider.output_queue();\n    try {\n      while (true) {\n        auto file = queue->Get();\n        const char* type_str =\n            (file.message_type ==\n             lczero::training::FilePathProvider::MessageType::kFile)\n                ? \"File\"\n                : \"Initial scan complete\";\n        LOG(INFO) << type_str << \": \" << file.filepath;\n      }\n    } catch (const lczero::QueueClosedException&) {\n      LOG(INFO) << \"Queue closed, consumer thread exiting\";\n    }\n  });\n\n  LOG(INFO) << \"Monitoring for files... Press Enter to exit.\";\n  std::cin.get();\n\n  // Close the queue and wait for the consumer to finish.\n  file_path_provider.Stop();\n  consumer_thread.join();\n\n  return 0;\n}\n"
  },
  {
    "path": "csrc/loader/stages/file_path_provider_test.cc",
    "content": "#include \"loader/stages/file_path_provider.h\"\n\n#include <gtest/gtest.h>\n\n#include <chrono>\n#include <filesystem>\n#include <fstream>\n#include <string>\n#include <unordered_set>\n#include <vector>\n\nnamespace lczero {\nnamespace training {\n\nnamespace {\n\nFilePathProviderConfig MakeConfig(const std::filesystem::path& directory) {\n  FilePathProviderConfig config;\n  config.mutable_output()->set_queue_capacity(128);\n  config.set_directory(directory.string());\n  return config;\n}\n\nstd::string RelativeTo(const std::filesystem::path& base,\n                       const std::filesystem::path& target) {\n  return target.lexically_relative(base).generic_string();\n}\n\n}  // namespace\n\nclass FilePathProviderTest : public ::testing::Test {\n protected:\n  void SetUp() override {\n    test_dir_ =\n        std::filesystem::temp_directory_path() /\n        (\"file_path_provider_test_\" +\n         std::to_string(\n             std::chrono::steady_clock::now().time_since_epoch().count()));\n    std::filesystem::create_directories(test_dir_);\n  }\n\n  void TearDown() override {\n    if (std::filesystem::exists(test_dir_)) {\n      std::filesystem::remove_all(test_dir_);\n    }\n  }\n\n  void CreateFile(const std::filesystem::path& path,\n                  const std::string& content = \"payload\") {\n    std::filesystem::create_directories(path.parent_path());\n    std::ofstream file(path);\n    file << content;\n  }\n\n  void CreateDirectory(const std::filesystem::path& path) {\n    std::filesystem::create_directories(path);\n  }\n\n  std::vector<std::filesystem::path> DrainInitialScan(\n      Queue<FilePathProvider::File>* queue) {\n    std::vector<std::filesystem::path> files;\n    while (true) {\n      auto message = queue->Get();\n      if (message.message_type ==\n          FilePathProvider::MessageType::kInitialScanComplete) {\n        EXPECT_TRUE(message.filepath.empty());\n        break;\n      }\n      if (message.message_type != 
FilePathProvider::MessageType::kFile) {\n        ADD_FAILURE() << \"Unexpected message type in initial scan.\";\n        continue;\n      }\n      files.push_back(message.filepath);\n    }\n    return files;\n  }\n\n  FilePathProvider::File AwaitNextFile(Queue<FilePathProvider::File>* queue) {\n    while (true) {\n      auto message = queue->Get();\n      if (message.message_type == FilePathProvider::MessageType::kFile) {\n        return message;\n      }\n      if (message.message_type !=\n          FilePathProvider::MessageType::kInitialScanComplete) {\n        ADD_FAILURE()\n            << \"Unexpected message type while waiting for file notification.\";\n      }\n    }\n  }\n\n  FilePathProviderConfig Config() const { return MakeConfig(test_dir_); }\n\n  std::filesystem::path test_dir_;\n};\n\nTEST_F(FilePathProviderTest, ConstructorCreatesQueue) {\n  FilePathProvider provider(Config());\n  provider.Start();\n\n  auto* queue = provider.output_queue();\n  ASSERT_NE(queue, nullptr);\n  EXPECT_EQ(queue->Capacity(), 128);\n\n  auto message = queue->Get();\n  EXPECT_EQ(message.message_type,\n            FilePathProvider::MessageType::kInitialScanComplete);\n  EXPECT_TRUE(message.filepath.empty());\n\n  provider.Stop();\n}\n\nTEST_F(FilePathProviderTest, InitialScanFindsVisibleFiles) {\n  CreateFile(test_dir_ / \"file1.txt\");\n  CreateFile(test_dir_ / \"file2.txt\");\n  CreateFile(test_dir_ / \"sub\" / \"nested.txt\");\n\n  FilePathProvider provider(Config());\n  provider.Start();\n  auto* queue = provider.output_queue();\n\n  auto discovered = DrainInitialScan(queue);\n  std::unordered_set<std::string> relative_paths;\n  for (const auto& path : discovered) {\n    relative_paths.insert(RelativeTo(test_dir_, path));\n  }\n\n  EXPECT_EQ(relative_paths.size(), 3u);\n  EXPECT_TRUE(relative_paths.count(\"file1.txt\"));\n  EXPECT_TRUE(relative_paths.count(\"file2.txt\"));\n  EXPECT_TRUE(relative_paths.count(\"sub/nested.txt\"));\n\n  
provider.Stop();\n}\n\nTEST_F(FilePathProviderTest, InitialScanSkipsHiddenEntries) {\n  CreateFile(test_dir_ / \"visible.txt\");\n  CreateFile(test_dir_ / \".hidden_file\");\n  CreateFile(test_dir_ / \".hidden_dir\" / \"nested.txt\");\n  CreateFile(test_dir_ / \"visible_dir\" / \"child.txt\");\n\n  FilePathProvider provider(Config());\n  provider.Start();\n  auto* queue = provider.output_queue();\n\n  auto discovered = DrainInitialScan(queue);\n  std::unordered_set<std::string> relative_paths;\n  for (const auto& path : discovered) {\n    relative_paths.insert(RelativeTo(test_dir_, path));\n  }\n\n  EXPECT_TRUE(relative_paths.count(\"visible.txt\"));\n  EXPECT_TRUE(relative_paths.count(\"visible_dir/child.txt\"));\n  EXPECT_FALSE(relative_paths.count(\".hidden_file\"));\n  EXPECT_FALSE(relative_paths.count(\".hidden_dir/nested.txt\"));\n\n  provider.Stop();\n}\n\nTEST_F(FilePathProviderTest, DetectsNewVisibleFile) {\n  FilePathProvider provider(Config());\n  provider.Start();\n  auto* queue = provider.output_queue();\n  DrainInitialScan(queue);\n\n  CreateFile(test_dir_ / \"new_file.txt\");\n\n  auto message = AwaitNextFile(queue);\n  EXPECT_EQ(RelativeTo(test_dir_, message.filepath), \"new_file.txt\");\n\n  provider.Stop();\n}\n\nTEST_F(FilePathProviderTest, DetectsFilesInPreExistingSubdirectory) {\n  auto subdir = test_dir_ / \"subdir\";\n  CreateDirectory(subdir);\n\n  FilePathProvider provider(Config());\n  provider.Start();\n  auto* queue = provider.output_queue();\n  DrainInitialScan(queue);\n\n  CreateFile(subdir / \"from_subdir.txt\");\n\n  auto message = AwaitNextFile(queue);\n  EXPECT_EQ(RelativeTo(test_dir_, message.filepath), \"subdir/from_subdir.txt\");\n\n  provider.Stop();\n}\n\nTEST_F(FilePathProviderTest, IgnoresHiddenFileEvents) {\n  FilePathProvider provider(Config());\n  provider.Start();\n  auto* queue = provider.output_queue();\n  DrainInitialScan(queue);\n\n  CreateFile(test_dir_ / \".hidden_event.txt\");\n  CreateFile(test_dir_ / 
\"visible_after_hidden.txt\");\n\n  auto message = AwaitNextFile(queue);\n  EXPECT_EQ(RelativeTo(test_dir_, message.filepath),\n            \"visible_after_hidden.txt\");\n\n  provider.Stop();\n}\n\nTEST_F(FilePathProviderTest, SkipsHiddenDirectoryRecursion) {\n  FilePathProvider provider(Config());\n  provider.Start();\n  auto* queue = provider.output_queue();\n  DrainInitialScan(queue);\n\n  CreateDirectory(test_dir_ / \".hidden_dir\");\n  CreateFile(test_dir_ / \".hidden_dir\" / \"inner.txt\");\n  CreateFile(test_dir_ / \"outer.txt\");\n\n  auto message = AwaitNextFile(queue);\n  EXPECT_EQ(RelativeTo(test_dir_, message.filepath), \"outer.txt\");\n\n  provider.Stop();\n}\n\nTEST_F(FilePathProviderTest, HandlesEmptyDirectory) {\n  FilePathProvider provider(Config());\n  provider.Start();\n  auto* queue = provider.output_queue();\n\n  auto discovered = DrainInitialScan(queue);\n  EXPECT_TRUE(discovered.empty());\n\n  provider.Stop();\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/join_stage.cc",
    "content": "#include \"loader/stages/join_stage.h\"\n\n#include <absl/log/log.h>\n\n#include <stdexcept>\n\n#include \"loader/data_loader_metrics.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/metrics/load_metric.h\"\n\nnamespace lczero {\nnamespace training {\n\ntemplate <typename T>\nJoinStage<T>::JoinStage(const JoinPositionsConfig& config)\n    : SingleOutputStage<T>(config.output()) {}\n\ntemplate <typename T>\nJoinStage<T>::~JoinStage() {\n  Stop();\n}\n\ntemplate <typename T>\nvoid JoinStage<T>::SetInputs(absl::Span<QueueBase* const> inputs) {\n  input_queues_.clear();\n  for (QueueBase* base_queue : inputs) {\n    auto* typed_queue = dynamic_cast<Queue<T>*>(base_queue);\n    if (!typed_queue) throw std::runtime_error(\"Input queue type mismatch\");\n    input_queues_.push_back(typed_queue);\n  }\n}\n\ntemplate <typename T>\nvoid JoinStage<T>::Start() {\n  thread_contexts_.clear();\n  thread_pool_ = std::make_unique<ThreadPool>(input_queues_.size());\n  for (size_t i = 0; i < input_queues_.size(); ++i) {\n    thread_contexts_.push_back(std::make_unique<ThreadContext>());\n  }\n  for (size_t i = 0; i < input_queues_.size(); ++i) {\n    thread_pool_->Enqueue([this, i](std::stop_token stop_token) {\n      Worker(stop_token, input_queues_[i], thread_contexts_[i].get());\n    });\n  }\n}\n\ntemplate <typename T>\nvoid JoinStage<T>::Worker(std::stop_token stop_token, Queue<T>* input_queue,\n                          ThreadContext* context) {\n  auto producer = this->output_queue()->CreateProducer();\n  try {\n    while (true) {\n      auto item = [&]() {\n        LoadMetricPauser pauser(context->load_metric_updater);\n        return input_queue->Get(stop_token);\n      }();\n      producer.Put(std::move(item), stop_token);\n    }\n  } catch (const QueueClosedException&) {\n  }\n}\n\ntemplate <typename T>\nvoid JoinStage<T>::Stop() {\n  if (!thread_pool_ || thread_pool_->stop_token().stop_requested()) return;\n  LOG(INFO) << 
\"Stopping JoinStage.\";\n  thread_pool_->Shutdown();\n  this->output_queue()->Close();\n}\n\ntemplate <typename T>\nStageMetricProto JoinStage<T>::FlushMetrics() {\n  StageMetricProto metrics;\n  LoadMetricProto aggregated_load;\n  aggregated_load.set_name(\"load\");\n  for (const auto& context : thread_contexts_) {\n    UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());\n  }\n  *metrics.add_load_metrics() = std::move(aggregated_load);\n  *metrics.add_queue_metrics() =\n      MetricsFromQueue(\"output\", *this->output_queue());\n\n  return metrics;\n}\n\n// Explicit template instantiation for FrameType.\ntemplate class JoinStage<FrameType>;\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/join_stage.h",
    "content": "// ABOUTME: Stage that joins multiple input queues into a single output.\n// ABOUTME: Spawns one thread per input to read and forward items.\n#pragma once\n\n#include <atomic>\n#include <memory>\n#include <stop_token>\n#include <vector>\n\n#include \"loader/data_loader_metrics.h\"\n#include \"loader/frame_type.h\"\n#include \"loader/stages/stage.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/queue.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Template stage that joins multiple input queues into a single output.\n// Spawns one thread per input to consume and forward items.\ntemplate <typename T>\nclass JoinStage : public SingleOutputStage<T> {\n public:\n  using OutputType = T;\n\n  explicit JoinStage(const JoinPositionsConfig& config);\n  ~JoinStage();\n\n  void Start() override;\n  void Stop() override;\n  StageMetricProto FlushMetrics() override;\n  void SetInputs(absl::Span<QueueBase* const> inputs) override;\n\n private:\n  struct ThreadContext {\n    LoadMetricUpdater load_metric_updater;\n  };\n\n  void Worker(std::stop_token stop_token, Queue<T>* input_queue,\n              ThreadContext* context);\n\n  std::vector<Queue<T>*> input_queues_;\n  // thread_contexts_ must be declared before thread_pool_ to ensure\n  // thread_pool_ is destroyed first (stopping threads before contexts).\n  std::vector<std::unique_ptr<ThreadContext>> thread_contexts_;\n  std::unique_ptr<ThreadPool> thread_pool_;\n};\n\nusing JoinPositions = JoinStage<FrameType>;\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/join_stage_test.cc",
    "content": "#include \"loader/stages/join_stage.h\"\n\n#include <memory>\n#include <vector>\n\n#include \"absl/container/flat_hash_set.h\"\n#include \"gtest/gtest.h\"\n#include \"libs/lc0/src/trainingdata/trainingdata_v6.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"utils/queue.h\"\n\nnamespace lczero {\nnamespace training {\n\nclass JoinStageTest : public ::testing::Test {\n protected:\n  void SetUp() override { config_.mutable_output()->set_queue_capacity(100); }\n\n  FrameType CreateTestFrame(uint32_t version) {\n    FrameType frame{};\n    frame.version = version;\n    frame.input_format = 3;\n    frame.root_q = 0.5f;\n    return frame;\n  }\n\n  JoinPositionsConfig config_;\n};\n\nTEST_F(JoinStageTest, JoinsTwoInputs) {\n  auto input_queue_1 = std::make_unique<Queue<FrameType>>(10);\n  auto input_queue_2 = std::make_unique<Queue<FrameType>>(10);\n\n  JoinPositions join_stage(config_);\n  join_stage.SetInputs({input_queue_1.get(), input_queue_2.get()});\n  join_stage.Start();\n\n  auto producer_1 = input_queue_1->CreateProducer();\n  auto producer_2 = input_queue_2->CreateProducer();\n\n  producer_1.Put(CreateTestFrame(1));\n  producer_1.Put(CreateTestFrame(2));\n  producer_2.Put(CreateTestFrame(3));\n  producer_2.Put(CreateTestFrame(4));\n\n  absl::flat_hash_set<uint32_t> received_versions;\n  for (int i = 0; i < 4; ++i) {\n    auto frame = join_stage.output_queue()->Get();\n    received_versions.insert(frame.version);\n  }\n\n  producer_1.Close();\n  producer_2.Close();\n  join_stage.Stop();\n\n  EXPECT_EQ(received_versions.size(), 4u);\n  EXPECT_TRUE(received_versions.contains(1));\n  EXPECT_TRUE(received_versions.contains(2));\n  EXPECT_TRUE(received_versions.contains(3));\n  EXPECT_TRUE(received_versions.contains(4));\n}\n\nTEST_F(JoinStageTest, JoinsThreeInputs) {\n  auto input_queue_1 = std::make_unique<Queue<FrameType>>(10);\n  auto input_queue_2 = std::make_unique<Queue<FrameType>>(10);\n  auto input_queue_3 = 
std::make_unique<Queue<FrameType>>(10);\n\n  JoinPositions join_stage(config_);\n  join_stage.SetInputs(\n      {input_queue_1.get(), input_queue_2.get(), input_queue_3.get()});\n  join_stage.Start();\n\n  auto producer_1 = input_queue_1->CreateProducer();\n  auto producer_2 = input_queue_2->CreateProducer();\n  auto producer_3 = input_queue_3->CreateProducer();\n\n  producer_1.Put(CreateTestFrame(10));\n  producer_2.Put(CreateTestFrame(20));\n  producer_3.Put(CreateTestFrame(30));\n\n  absl::flat_hash_set<uint32_t> received_versions;\n  for (int i = 0; i < 3; ++i) {\n    auto frame = join_stage.output_queue()->Get();\n    received_versions.insert(frame.version);\n  }\n\n  producer_1.Close();\n  producer_2.Close();\n  producer_3.Close();\n  join_stage.Stop();\n\n  EXPECT_EQ(received_versions.size(), 3u);\n  EXPECT_TRUE(received_versions.contains(10));\n  EXPECT_TRUE(received_versions.contains(20));\n  EXPECT_TRUE(received_versions.contains(30));\n}\n\nTEST_F(JoinStageTest, HandlesEmptyInputs) {\n  auto input_queue_1 = std::make_unique<Queue<FrameType>>(10);\n  auto input_queue_2 = std::make_unique<Queue<FrameType>>(10);\n\n  JoinPositions join_stage(config_);\n  join_stage.SetInputs({input_queue_1.get(), input_queue_2.get()});\n  join_stage.Start();\n\n  auto producer_1 = input_queue_1->CreateProducer();\n  auto producer_2 = input_queue_2->CreateProducer();\n\n  producer_1.Close();\n  producer_2.Close();\n\n  auto maybe_frame = join_stage.output_queue()->MaybeGet();\n  EXPECT_FALSE(maybe_frame.has_value());\n\n  join_stage.Stop();\n}\n\nTEST_F(JoinStageTest, FlushesMetrics) {\n  auto input_queue = std::make_unique<Queue<FrameType>>(10);\n\n  JoinPositions join_stage(config_);\n  join_stage.SetInputs({input_queue.get()});\n  join_stage.Start();\n\n  auto producer = input_queue->CreateProducer();\n  producer.Put(CreateTestFrame(1));\n\n  auto frame = join_stage.output_queue()->Get();\n  EXPECT_EQ(frame.version, 1u);\n\n  producer.Close();\n  join_stage.Stop();\n\n  
auto metrics = join_stage.FlushMetrics();\n  EXPECT_EQ(metrics.load_metrics_size(), 1);\n  EXPECT_EQ(metrics.queue_metrics_size(), 1);\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/position_sampling.cc",
    "content": "#include \"loader/stages/position_sampling.h\"\n\n#include <algorithm>\n#include <cmath>\n\nnamespace lczero {\nnamespace training {\n\nfloat ComputePositionSamplingWeight(const FrameType& frame,\n                                    const PositionSamplingConfig& config) {\n  if (!config.has_diff_focus_q_weight() && !config.has_diff_focus_pol_scale()) {\n    return config.default_weight();\n  }\n  if (std::isnan(frame.orig_q)) return config.default_weight();\n  const float diff_q = std::abs(frame.best_q - frame.orig_q);\n  const float q_weight = config.diff_focus_q_weight();\n  const float pol_scale = config.diff_focus_pol_scale();\n  const float total =\n      (q_weight * diff_q + frame.policy_kld) / (q_weight + pol_scale);\n  return std::min(\n      std::pow(total * config.diff_focus_alpha() + config.diff_focus_beta(),\n               config.diff_focus_gamma()),\n      config.diff_focus_tau());\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/position_sampling.h",
    "content": "#pragma once\n\n#include \"loader/frame_type.h\"\n#include \"proto/data_loader_config.pb.h\"\n\nnamespace lczero {\nnamespace training {\n\nfloat ComputePositionSamplingWeight(const FrameType& frame,\n                                    const PositionSamplingConfig& config);\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/shuffling_chunk_pool.cc",
    "content": "#include \"loader/stages/shuffling_chunk_pool.h\"\n\n#include <absl/algorithm/container.h>\n#include <absl/base/thread_annotations.h>\n#include <absl/log/log.h>\n#include <absl/random/random.h>\n#include <absl/synchronization/mutex.h>\n\n#include <algorithm>\n#include <cassert>\n#include <chrono>\n#include <cmath>\n#include <cstring>\n#include <filesystem>\n#include <limits>\n#include <stdexcept>\n#include <thread>\n#include <utility>\n\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"loader/data_loader_metrics.h\"\n#include \"loader/stages/chunk_source_loader.h\"\n#include \"loader/stages/position_sampling.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\nthread_local absl::BitGen ShufflingChunkPool::bitgen_{absl::MakeSeedSeq()};\n\nShufflingChunkPool::ShufflingChunkPool(const ShufflingChunkPoolConfig& config)\n    : primary_output_name_(config.output().name()),\n      primary_output_queue_(\n          config.output().queue_capacity(),\n          ToOverflowBehavior(config.output().overflow_behavior())),\n      chunk_pool_size_(config.chunk_pool_size()),\n      config_(config),\n      source_ingestion_pool_(config.source_ingestion_threads(),\n                             ThreadPoolOptions{}, stop_source_),\n      chunk_loading_pool_(config.chunk_loading_threads(), ThreadPoolOptions{},\n                          stop_source_),\n      caching_pool_(config.has_cachehit_output() ? 
config.caching_threads() : 0,\n                    ThreadPoolOptions{}, stop_source_) {\n  if (config.has_cachehit_output()) {\n    cachehit_output_name_ = config.cachehit_output().name();\n    cachehit_output_queue_.emplace(\n        config.cachehit_output().queue_capacity(),\n        ToOverflowBehavior(config.cachehit_output().overflow_behavior()));\n    if (primary_output_name_ == *cachehit_output_name_) {\n      throw std::runtime_error(absl::StrCat(\n          \"ShufflingChunkPool output names must be different, got: '\",\n          primary_output_name_, \"'\"));\n    }\n  }\n  LOG(INFO) << \"Initializing ShufflingChunkPool with pool size \"\n            << config.chunk_pool_size();\n}\n\nShufflingChunkPool::~ShufflingChunkPool() { Stop(); }\n\nvoid ShufflingChunkPool::SetInputs(absl::Span<QueueBase* const> inputs) {\n  if (inputs.size() != 1 && inputs.size() != 2) {\n    throw std::runtime_error(absl::StrCat(\n        \"ShufflingChunkPool expects 1 or 2 inputs, got \", inputs.size()));\n  }\n  if (inputs.size() == 2 && !cachehit_output_queue_.has_value()) {\n    throw std::runtime_error(\n        \"ShufflingChunkPool received 2 inputs but cachehit_output is not \"\n        \"configured\");\n  }\n  if (inputs.size() == 1 && cachehit_output_queue_.has_value()) {\n    throw std::runtime_error(\n        \"ShufflingChunkPool has cachehit_output configured but received only \"\n        \"1 input\");\n  }\n  primary_input_queue_ = dynamic_cast<Queue<ChunkSourceWithPhase>*>(inputs[0]);\n  if (!primary_input_queue_) {\n    throw std::runtime_error(\"ShufflingChunkPool primary input type mismatch\");\n  }\n  if (inputs.size() == 2) {\n    cache_request_queue_ = dynamic_cast<Queue<CacheRequest>*>(inputs[1]);\n    if (!cache_request_queue_) {\n      throw std::runtime_error(\n          \"ShufflingChunkPool cache request input type mismatch\");\n    }\n  }\n}\n\nQueueBase* ShufflingChunkPool::GetOutput(std::string_view name) {\n  if (name == primary_output_name_) return 
&primary_output_queue_;\n  if (cachehit_output_name_.has_value() && name == *cachehit_output_name_) {\n    return &*cachehit_output_queue_;\n  }\n  std::string available = absl::StrCat(\"'\", primary_output_name_, \"'\");\n  if (cachehit_output_name_.has_value()) {\n    absl::StrAppend(&available, \", '\", *cachehit_output_name_, \"'\");\n  }\n  throw std::runtime_error(absl::StrCat(\"ShufflingChunkPool unknown output '\",\n                                        name,\n                                        \"'. Available outputs: \", available));\n}\n\nvoid ShufflingChunkPool::Start() {\n  LOG(INFO) << \"Starting ShufflingChunkPool initialization thread.\";\n  initialization_thread_ = std::jthread([this]() {\n    try {\n      LOG(INFO) << \"Starting ShufflingChunkPool with pool size \"\n                << config_.chunk_pool_size();\n      std::vector<std::unique_ptr<ChunkSource>> uninitialized_sources =\n          InitializeChunkSources();\n      ProcessInputFiles(std::move(uninitialized_sources));\n\n      // Start input processing worker that continuously processes new files.\n      for (size_t i = 0; i < source_ingestion_pool_.num_threads(); ++i) {\n        auto* context =\n            source_ingestion_thread_contexts_\n                .emplace_back(std::make_unique<SourceIngestionThreadContext>())\n                .get();\n        source_ingestion_pool_.Enqueue(\n            [this, context](std::stop_token stop_token) {\n              SourceIngestionWorker(stop_token, context);\n            });\n      }\n\n      // Start output workers after everything is fully initialized.\n      LOG(INFO) << \"ShufflingChunkPool initialization done, starting workers\";\n      for (size_t i = 0; i < chunk_loading_pool_.num_threads(); ++i) {\n        auto* context =\n            chunk_loading_thread_contexts_\n                .emplace_back(std::make_unique<ChunkLoadingThreadContext>())\n                .get();\n        chunk_loading_pool_.Enqueue(\n            [this, 
context](std::stop_token stop_token) {\n              OutputWorker(stop_token, context);\n            });\n      }\n\n      // Start caching workers if configured.\n      if (cachehit_output_queue_.has_value()) {\n        for (size_t i = 0; i < caching_pool_.num_threads(); ++i) {\n          auto* context =\n              caching_thread_contexts_\n                  .emplace_back(std::make_unique<CachingThreadContext>())\n                  .get();\n          caching_pool_.Enqueue([this, context](std::stop_token stop_token) {\n            CachingWorker(stop_token, context);\n          });\n        }\n      }\n    } catch (const QueueClosedException&) {\n      LOG(INFO) << \"ShufflingChunkPool initialization interrupted, input \"\n                   \"queue closed.\";\n      output_queue()->Close();\n    } catch (const std::exception& e) {\n      LOG(ERROR) << \"ShufflingChunkPool initialization failed: \" << e.what();\n      output_queue()->Close();\n    }\n  });\n}\n\nvoid ShufflingChunkPool::Stop() {\n  if (stop_source_.stop_requested()) return;\n\n  LOG(INFO) << \"Stopping ShufflingChunkPool.\";\n  stop_source_.request_stop();\n  if (initialization_thread_.joinable()) {\n    initialization_thread_.request_stop();\n    initialization_thread_.join();\n  }\n\n  source_ingestion_pool_.Shutdown();\n  chunk_loading_pool_.Shutdown();\n  if (cachehit_output_queue_) caching_pool_.Shutdown();\n  output_queue()->Close();\n  if (cachehit_output_queue_) cachehit_output_queue_->Close();\n  LOG(INFO) << \"ShufflingChunkPool stopped.\";\n}\n\nstd::vector<std::unique_ptr<ChunkSource>>\nShufflingChunkPool::InitializeChunkSources() {\n  std::vector<std::unique_ptr<ChunkSource>> uninitialized_sources;\n\n  // Read from input queue until kInitialScanComplete.\n  while (true) {\n    auto chunk_source_with_phase = input_queue()->Get();\n\n    if (chunk_source_with_phase.message_type ==\n        FilePathProvider::MessageType::kInitialScanComplete) {\n      LOG(INFO)\n          << 
\"ShufflingChunkPool received initial scan completion marker.\";\n      break;\n    }\n\n    if (chunk_source_with_phase.message_type ==\n        FilePathProvider::MessageType::kFile) {\n      // Add ChunkSource to uninitialized sources.\n      uninitialized_sources.push_back(\n          std::move(chunk_source_with_phase.source));\n    }\n  }\n\n  LOG(INFO) << \"ShufflingChunkPool initial directory walk produced \"\n            << uninitialized_sources.size() << \" chunk source candidate(s).\";\n\n  // Sort in descending order (newest first).\n  std::sort(uninitialized_sources.begin(), uninitialized_sources.end(),\n            [](const auto& a, const auto& b) {\n              return a->GetChunkSortKey() > b->GetChunkSortKey();\n            });\n  std::atomic<size_t> total_chunks = 0;\n  size_t sources_to_keep = 0;\n\n  // Process sources sequentially until we have enough chunks.\n  std::string current_anchor;\n  {\n    absl::MutexLock lock(&anchor_mutex_);\n    current_anchor = anchor_;\n  }\n\n  for (auto& source : uninitialized_sources) {\n    if (output_queue()->IsClosed()) {\n      LOG(INFO) << \"Output queue closed, stopping source ingestion.\";\n      break;\n    }\n    if (total_chunks >= chunk_pool_size_) break;\n\n    // Count chunks immediately; constructors have already prepared metadata.\n    const size_t chunk_count = source->GetChunkCount();\n    total_chunks += chunk_count;\n\n    // Count chunks since anchor during initial load.\n    if (source->GetChunkSortKey() > current_anchor) {\n      chunks_since_anchor_ += chunk_count;\n    }\n\n    LOG_EVERY_N_SEC(INFO, 4) << \"Loaded so far: \" << total_chunks.load()\n                             << \"; new: \" << chunks_since_anchor_;\n    ++sources_to_keep;\n  }\n\n  LOG(INFO) << \"ShufflingChunkPool indexed \" << total_chunks.load()\n            << \" chunk(s) across \" << sources_to_keep\n            << \" source(s) during startup.\";\n\n  if (total_chunks < chunk_pool_size_ && 
!output_queue()->IsClosed()) {\n    LOG(ERROR) << \"ShufflingChunkPool startup chunk requirement not met: \"\n               << total_chunks.load() << \" < \" << chunk_pool_size_;\n  }\n\n  // Trim the vector to only keep the sources we need.\n  uninitialized_sources.resize(sources_to_keep);\n  return uninitialized_sources;\n}\n\nvoid ShufflingChunkPool::ProcessInputFiles(\n    std::vector<std::unique_ptr<ChunkSource>> uninitialized_sources) {\n  // Initialize chunk sources from the initial scan.\n  size_t initial_window_sources = 0;\n  size_t initial_total_chunks = 0;\n  {\n    absl::MutexLock lock(&chunk_sources_mutex_);\n    size_t start_chunk_index = 0;\n    // Newest sources first, so we add in reverse order.\n    std::for_each(uninitialized_sources.rbegin(), uninitialized_sources.rend(),\n                  [this, &start_chunk_index](auto& source) {\n                    const size_t count = source->GetChunkCount();\n                    auto item = std::make_shared<ChunkSourceItem>();\n                    item->start_chunk_index = start_chunk_index;\n                    item->source = std::move(source);\n                    item->use_counts = std::vector<uint16_t>(count, 0);\n                    item->weight = std::vector<float>(count, -1.0f);\n                    item->cache = std::vector<std::unique_ptr<CacheNode>>(\n                        cachehit_output_queue_.has_value() ? 
count : 0);\n                    chunk_sources_.push_back(std::move(item));\n                    start_chunk_index +=\n                        chunk_sources_.back()->source->GetChunkCount();\n                  });\n\n    // Initialize stream shuffler with the initial bounds.\n    if (!chunk_sources_.empty()) {\n      size_t total_chunks = chunk_sources_.back()->start_chunk_index +\n                            chunk_sources_.back()->source->GetChunkCount();\n      // Set bounds to provide the last chunk_pool_size_ chunks.\n      size_t lower_bound =\n          total_chunks > chunk_pool_size_ ? total_chunks - chunk_pool_size_ : 0;\n      stream_shuffler_.SetLowerBound(lower_bound);\n      stream_shuffler_.SetUpperBound(total_chunks);\n      initial_total_chunks = total_chunks;\n    }\n    initial_window_sources = chunk_sources_.size();\n  }\n\n  LOG(INFO) << \"ShufflingChunkPool initial window ready with \"\n            << initial_window_sources << \" source(s) totaling \"\n            << initial_total_chunks << \" chunk(s).\";\n\n  // Log anchor and sources after initial scan completion.\n  {\n    absl::MutexLock anchor_lock(&anchor_mutex_);\n    LOG(INFO) << \"Current anchor: '\" << anchor_ << \"'\";\n\n    absl::MutexLock sources_lock(&chunk_sources_mutex_);\n    std::vector<std::shared_ptr<ChunkSourceItem>> sources_after_anchor;\n    for (const auto& item : chunk_sources_) {\n      if (item->source->GetChunkSortKey() > anchor_) {\n        sources_after_anchor.push_back(item);\n      }\n    }\n\n    LOG(INFO) << sources_after_anchor.size()\n              << \" chunk source(s) after anchor, \" << chunks_since_anchor_\n              << \" total chunks since anchor\";\n\n    const size_t to_log = std::min(sources_after_anchor.size(), size_t(20));\n    for (size_t i = 0; i < to_log; ++i) {\n      LOG(INFO) << \"  Source [\" << (i + 1) << \"/\" << sources_after_anchor.size()\n                << \"]: key='\"\n                << 
sources_after_anchor[i]->source->GetChunkSortKey()\n                << \"', chunks=\"\n                << sources_after_anchor[i]->source->GetChunkCount();\n    }\n  }\n\n  if (initial_total_chunks == 0) {\n    throw std::runtime_error(\n        \"ShufflingChunkPool requires at least one chunk during startup.\");\n  }\n}\n\nvoid ShufflingChunkPool::SourceIngestionWorker(\n    std::stop_token stop_token, SourceIngestionThreadContext* context) {\n  try {\n    while (true) {\n      auto chunk_source_with_phase = [&]() {\n        LoadMetricPauser pauser(context->load_metric_updater);\n        return input_queue()->Get(stop_token);\n      }();\n\n      if (chunk_source_with_phase.message_type ==\n          FilePathProvider::MessageType::kFile) {\n        // Ingest the new chunk source.\n        auto source = std::move(chunk_source_with_phase.source);\n        size_t chunk_count = source->GetChunkCount();\n        absl::MutexLock lock(&chunk_sources_mutex_);\n        chunks_since_anchor_ += chunk_count;\n        AddNewChunkSource(std::move(source));\n      }\n    }\n  } catch (const QueueClosedException&) {\n    LOG(INFO) << \"SourceIngestionWorker stopping, queue closed.\";\n  } catch (const QueueRequestCancelled&) {\n    LOG(INFO) << \"SourceIngestionWorker stopping, request cancelled.\";\n  }\n}\n\nvoid ShufflingChunkPool::OutputWorker(std::stop_token stop_token,\n                                      ChunkLoadingThreadContext* context) {\n  // Create a local producer for this worker\n  auto primary_producer = output_queue()->CreateProducer();\n  std::optional<decltype(cachehit_output_queue_->CreateProducer())>\n      cachehit_producer;\n  if (cachehit_output_queue_.has_value()) {\n    cachehit_producer.emplace(cachehit_output_queue_->CreateProducer());\n  }\n\n  try {\n    while (true) {\n      auto result = GetNextChunkData();\n      if (!result) {\n        if (output_queue()->IsClosed()) break;\n        continue;\n      }\n      LoadMetricPauser 
pauser(context->load_metric_updater);\n      if (std::holds_alternative<TrainingChunk>(*result)) {\n        primary_producer.Put(std::move(std::get<TrainingChunk>(*result)),\n                             stop_token);\n      } else {\n        cachehit_producer->Put(std::move(std::get<FrameType>(*result)),\n                               stop_token);\n      }\n    }\n  } catch (const QueueClosedException&) {\n    LOG(INFO) << \"OutputWorker stopping, queue closed.\";\n  } catch (const QueueRequestCancelled&) {\n    LOG(INFO) << \"OutputWorker stopping, request cancelled.\";\n  } catch (const std::exception& e) {\n    LOG(FATAL) << \"OutputWorker encountered an error: \" << e.what();\n  }\n}\n\nvoid ShufflingChunkPool::CachingWorker(std::stop_token stop_token,\n                                       CachingThreadContext* context) {\n  constexpr double kTheta = 0.99;\n  double remainder = 0.0;\n  double exponential_avg_probability = 1.0;\n  try {\n    while (true) {\n      auto cache_request = [&]() {\n        LoadMetricPauser pauser(context->load_metric_updater);\n        return cache_request_queue_->Get(stop_token);\n      }();\n\n      std::shared_ptr<ChunkSourceItem> source_item;\n      float max_weight = 0.0f;\n      size_t local_index = 0;\n      {\n        absl::MutexLock lock(&chunk_sources_mutex_);\n\n        // Find the chunk source containing this global index.\n        auto it = absl::c_lower_bound(\n            chunk_sources_, cache_request.global_index,\n            [](const auto& item, size_t chunk_idx) {\n              return item->start_chunk_index + item->source->GetChunkCount() <=\n                     chunk_idx;\n            });\n\n        if (it == chunk_sources_.end() ||\n            cache_request.global_index < (*it)->start_chunk_index) {\n          chunk_source_not_found_.fetch_add(1, std::memory_order_acq_rel);\n          continue;\n        }\n\n        source_item = *it;\n        max_weight = max_weight_;\n        local_index = 
cache_request.global_index - source_item->start_chunk_index;\n      }\n\n      absl::MutexLock item_lock(&source_item->mutex);\n      assert(local_index < source_item->use_counts.size());\n\n      // Check use_count match.\n      if (source_item->use_counts[local_index] != cache_request.next_use) {\n        mismatched_use_counts_.fetch_add(1, std::memory_order_acq_rel);\n        continue;\n      }\n\n      // Compute how many positions to cache.\n      const float weight = source_item->weight[local_index];\n      assert(weight >= 0.0f);\n      const double probability = ComputeHanseProbability(weight, max_weight);\n      exponential_avg_probability =\n          exponential_avg_probability * (1.0 - kTheta) + probability * kTheta;\n      const double n = (probability * config_.position_cache_size() /\n                        chunk_pool_size_ / exponential_avg_probability) +\n                       remainder;\n      remainder = n - std::floor(n);\n      const size_t positions_to_cache = static_cast<size_t>(std::floor(n));\n      // Traverse and extend the cache chain.\n      std::unique_ptr<CacheNode>* current = &source_item->cache[local_index];\n      for (size_t i = 0; i < positions_to_cache; ++i) {\n        if (*current) {\n          current = &(*current)->next;\n          continue;\n        }\n        if (i >= cache_request.items.size()) break;\n        auto node = std::make_unique<CacheNode>();\n        node->frame = cache_request.items[i];\n        *current = std::move(node);\n        current = &(*current)->next;\n        newly_cached_.fetch_add(1, std::memory_order_acq_rel);\n        cached_positions_.fetch_add(1, std::memory_order_acq_rel);\n      }\n\n      const size_t dropped =\n          cache_request.items.size() > positions_to_cache\n              ? 
cache_request.items.size() - positions_to_cache\n              : 0;\n      dropped_cache_positions_.fetch_add(dropped, std::memory_order_acq_rel);\n    }\n  } catch (const QueueClosedException&) {\n    LOG(INFO) << \"CachingWorker stopping, queue closed.\";\n  } catch (const QueueRequestCancelled&) {\n    LOG(INFO) << \"CachingWorker stopping, request cancelled.\";\n  }\n}\n\nstruct ShufflingChunkPool::ChunkData {\n  std::vector<FrameType> data;\n  std::string sort_key;\n  size_t local_index = 0;\n  size_t global_index = 0;\n  uint32_t use_count = 0;\n  std::shared_ptr<ChunkSourceItem> source_item;\n};\n\nstd::optional<std::variant<TrainingChunk, FrameType>>\nShufflingChunkPool::GetNextChunkData() {\n  while (true) {\n    ChunkData chunk_data;\n    const ChunkStatus status = GetChunkInfo(chunk_data);\n    if (status == ChunkStatus::kEnd) return std::nullopt;\n    if (status == ChunkStatus::kRetry) continue;\n\n    const bool hanse_enabled = config_.hanse_sampling_threshold() > 0;\n    if (hanse_enabled && !HanseAccept(chunk_data)) continue;\n\n    {\n      absl::MutexLock lock(&chunk_data.source_item->mutex);\n\n      assert(chunk_data.source_item->use_counts.size() > chunk_data.local_index);\n      chunk_data.use_count =\n          chunk_data.source_item->use_counts[chunk_data.local_index]++;\n\n      if (cachehit_output_queue_.has_value()) {\n        auto& cache_chain = chunk_data.source_item->cache[chunk_data.local_index];\n        if (cache_chain) {\n          cache_hits_.fetch_add(1, std::memory_order_acq_rel);\n          cached_positions_.fetch_sub(1, std::memory_order_acq_rel);\n          FrameType cached_frame = cache_chain->frame;\n          cache_chain = std::move(cache_chain->next);\n          return cached_frame;\n        }\n        cache_misses_.fetch_add(1, std::memory_order_acq_rel);\n      }\n    }\n\n    if (chunk_data.data.empty() && !LoadChunkData(chunk_data)) continue;\n\n    TrainingChunk chunk;\n    chunk.sort_key = 
std::move(chunk_data.sort_key);\n    chunk.index_within_sort_key = chunk_data.local_index;\n    chunk.use_count = chunk_data.use_count;\n    chunk.global_index = chunk_data.global_index;\n    chunk.frames = std::move(chunk_data.data);\n\n    return chunk;\n  }\n}\n\nbool ShufflingChunkPool::LoadChunkData(ChunkData& chunk_data) {\n  std::optional<std::vector<FrameType>> data =\n      chunk_data.source_item->source->GetChunkData(chunk_data.local_index);\n\n  if (!data || data->empty()) {\n    absl::MutexLock lock(&chunk_data.source_item->mutex);\n    chunk_data.source_item->dropped_chunks.insert(chunk_data.local_index);\n    dropped_chunks_metric_.fetch_add(1, std::memory_order_acq_rel);\n    return false;\n  }\n\n  chunk_data.data = std::move(*data);\n  return true;\n}\n\nShufflingChunkPool::ChunkStatus ShufflingChunkPool::GetChunkInfo(\n    ChunkData& out_chunk_data) {\n  std::shared_ptr<ChunkSourceItem> source_item;\n  {\n    absl::MutexLock lock(&chunk_sources_mutex_);\n    std::optional<size_t> chunk_index = stream_shuffler_.GetNextItem();\n\n    if (!chunk_index && !chunk_sources_.empty()) {\n      size_t total_chunks = chunk_sources_.back()->start_chunk_index +\n                            chunk_sources_.back()->source->GetChunkCount();\n      size_t lower_bound = total_chunks > chunk_pool_size_\n                               ? 
total_chunks - chunk_pool_size_\n                               : chunk_sources_.front()->start_chunk_index;\n      stream_shuffler_.Reset(lower_bound, total_chunks);\n      reshuffles_.fetch_add(1, std::memory_order_acq_rel);\n      chunk_index = stream_shuffler_.GetNextItem();\n    }\n\n    if (!chunk_index) return ChunkStatus::kEnd;\n\n    auto it =\n        absl::c_lower_bound(chunk_sources_, *chunk_index,\n                            [](const auto& item, size_t chunk_idx) {\n                              return item->start_chunk_index +\n                                         item->source->GetChunkCount() <=\n                                     chunk_idx;\n                            });\n\n    if (ABSL_PREDICT_FALSE(it == chunk_sources_.end() ||\n                           *chunk_index < (*it)->start_chunk_index)) {\n      LOG(WARNING) << \"Chunk index \" << *chunk_index\n                   << \" out of range for available chunk sources.\";\n      return ChunkStatus::kRetry;\n    }\n\n    source_item = *it;\n    out_chunk_data.local_index = *chunk_index - source_item->start_chunk_index;\n    out_chunk_data.sort_key = source_item->source->GetChunkSortKey();\n    out_chunk_data.global_index = *chunk_index;\n  }\n\n  {\n    absl::MutexLock lock(&source_item->mutex);\n    if (source_item->dropped_chunks.contains(out_chunk_data.local_index)) {\n      return ChunkStatus::kRetry;\n    }\n  }\n\n  out_chunk_data.source_item = std::move(source_item);\n  return ChunkStatus::kOk;\n}\n\ndouble ShufflingChunkPool::ComputeHanseProbability(float weight,\n                                                   float max_weight) const {\n  if (max_weight <= 0.0f) return 1.0;\n  return std::pow(weight / max_weight, config_.hanse_sampling_gamma());\n}\n\nfloat ShufflingChunkPool::ComputeChunkWeight(\n    absl::Span<const FrameType> frames) const {\n  return absl::c_accumulate(frames, 0.0f, [this](float sum, const auto& frame) {\n    return sum +\n           
ComputePositionSamplingWeight(frame, config_.position_sampling());\n  });\n}\n\nbool ShufflingChunkPool::HanseAccept(ChunkData& chunk_data) {\n  assert(chunk_data.source_item);\n  float weight = -1.0f;\n  {\n    absl::MutexLock lock(&chunk_data.source_item->mutex);\n    assert(chunk_data.source_item->weight.size() > chunk_data.local_index);\n    weight = chunk_data.source_item->weight[chunk_data.local_index];\n  }\n\n  if (weight < 0.0f) {\n    if (chunk_data.data.empty() && !LoadChunkData(chunk_data)) return false;\n    weight = ComputeChunkWeight(chunk_data.data);\n    {\n      absl::MutexLock pool_lock(&chunk_sources_mutex_);\n      max_weight_ = std::max(max_weight_, weight);\n      AddSample(chunk_weight_stats_, static_cast<double>(weight));\n    }\n    {\n      absl::MutexLock lock(&chunk_data.source_item->mutex);\n      const float cached_weight =\n          chunk_data.source_item->weight[chunk_data.local_index];\n      if (cached_weight < 0.0f) {\n        chunk_data.source_item->weight[chunk_data.local_index] = weight;\n      } else {\n        weight = cached_weight;\n      }\n    }\n    hanse_cache_misses_.fetch_add(1, std::memory_order_acq_rel);\n  } else {\n    hanse_cache_hits_.fetch_add(1, std::memory_order_acq_rel);\n  }\n\n  float max_weight = 0.0f;\n  {\n    absl::MutexLock lock(&chunk_sources_mutex_);\n    max_weight = max_weight_;\n  }\n  const double p = ComputeHanseProbability(weight, max_weight);\n  const double u = absl::Uniform<double>(bitgen_, 0.0, 1.0);\n  if (u >= p) {\n    hanse_rejected_.fetch_add(1, std::memory_order_acq_rel);\n    return false;\n  }\n  return true;\n}\n\nvoid ShufflingChunkPool::AddNewChunkSource(std::unique_ptr<ChunkSource> source)\n    ABSL_EXCLUSIVE_LOCKS_REQUIRED(chunk_sources_mutex_) {\n  // Add new chunk source to the end of the deque.\n  size_t old_upper_bound = 0;\n  if (!chunk_sources_.empty()) {\n    const auto& last_source = chunk_sources_.back();\n    old_upper_bound =\n        
last_source->start_chunk_index + last_source->source->GetChunkCount();\n  }\n\n  size_t count = source->GetChunkCount();\n  auto item = std::make_shared<ChunkSourceItem>();\n  item->start_chunk_index = old_upper_bound;\n  item->source = std::move(source);\n  item->use_counts = std::vector<uint16_t>(count, 0);\n  item->weight = std::vector<float>(count, -1.0f);\n  item->cache =\n      std::vector<std::unique_ptr<CacheNode>>(cachehit_output_queue_.has_value()\n                                                  ? count\n                                                  : 0);\n  chunk_sources_.push_back(std::move(item));\n\n  // Calculate current window bounds.\n  size_t new_upper_bound = chunk_sources_.back()->start_chunk_index +\n                           chunk_sources_.back()->source->GetChunkCount();\n\n  // Remove old chunks if window exceeds chunk_pool_size_.\n  while (!chunk_sources_.empty() && chunk_sources_.size() > 1) {\n    size_t window_start = chunk_sources_.front()->start_chunk_index +\n                          chunk_sources_.front()->source->GetChunkCount();\n    size_t window_size = new_upper_bound - window_start;\n\n    if (window_size < chunk_pool_size_) break;\n\n    // Count cached positions in the evicted source.\n    if (cachehit_output_queue_.has_value()) {\n      size_t evicted_cached = 0;\n      absl::MutexLock item_lock(&chunk_sources_.front()->mutex);\n      for (const auto& cache_chain : chunk_sources_.front()->cache) {\n        const CacheNode* node = cache_chain.get();\n        while (node) {\n          ++evicted_cached;\n          node = node->next.get();\n        }\n      }\n      cached_positions_.fetch_sub(evicted_cached, std::memory_order_acq_rel);\n    }\n\n    // Remove the oldest chunk source (front of deque).\n    chunk_sources_.pop_front();\n  }\n\n  // Update stream shuffler bounds with the sliding window.\n  size_t window_start = chunk_sources_.front()->start_chunk_index;\n  size_t new_lower_bound = new_upper_bound > 
chunk_pool_size_\n                               ? new_upper_bound - chunk_pool_size_\n                               : window_start;\n  stream_shuffler_.SetUpperBound(new_upper_bound);\n  stream_shuffler_.SetLowerBound(new_lower_bound);\n}\n\nStageMetricProto ShufflingChunkPool::FlushMetrics() {\n  StageMetricProto stage_metric;\n  // Aggregate source ingestion load metrics from all ingestion threads.\n  LoadMetricProto ingestion_load;\n  ingestion_load.set_name(\"source_ingestion\");\n  for (const auto& context : source_ingestion_thread_contexts_) {\n    UpdateFrom(ingestion_load, context->load_metric_updater.FlushMetrics());\n  }\n  *stage_metric.add_load_metrics() = std::move(ingestion_load);\n\n  // Aggregate chunk loading load metrics from all chunk loading threads.\n  LoadMetricProto chunk_loading_load;\n  chunk_loading_load.set_name(\"chunk_loading\");\n  for (const auto& context : chunk_loading_thread_contexts_) {\n    UpdateFrom(chunk_loading_load, context->load_metric_updater.FlushMetrics());\n  }\n  *stage_metric.add_load_metrics() = std::move(chunk_loading_load);\n\n  // Get chunk sources statistics and pool state.\n  {\n    absl::MutexLock lock(&chunk_sources_mutex_);\n    auto* chunk_sources_metric = stage_metric.add_gauge_metrics();\n    chunk_sources_metric->set_name(\"chunk_sources\");\n    chunk_sources_metric->set_value(\n        static_cast<uint64_t>(chunk_sources_.size()));\n\n    size_t upper = 0;\n    size_t current = 0;\n    if (!chunk_sources_.empty()) {\n      const auto& first = chunk_sources_.front();\n      const auto& last = chunk_sources_.back();\n      upper = last->start_chunk_index + last->source->GetChunkCount();\n      current = upper - first->start_chunk_index;\n    }\n\n    auto* current_chunks_metric = stage_metric.add_gauge_metrics();\n    current_chunks_metric->set_name(\"chunks_current\");\n    current_chunks_metric->set_value(static_cast<uint64_t>(current));\n    current_chunks_metric->set_capacity(\n        
static_cast<uint64_t>(chunk_pool_size_));\n\n    auto* total_chunks_metric = stage_metric.add_gauge_metrics();\n    total_chunks_metric->set_name(\"chunks_total\");\n    total_chunks_metric->set_value(static_cast<uint64_t>(upper));\n  }\n\n  // Get anchor-related metrics.\n  {\n    absl::MutexLock lock(&anchor_mutex_);\n    auto* chunks_since_anchor_metric = stage_metric.add_gauge_metrics();\n    chunks_since_anchor_metric->set_name(\"chunks_since_anchor\");\n    chunks_since_anchor_metric->set_value(chunks_since_anchor_);\n    stage_metric.set_anchor(anchor_);\n  }\n\n  auto* dropped_metric = stage_metric.add_count_metrics();\n  dropped_metric->set_name(\"dropped\");\n  dropped_metric->set_count(\n      dropped_chunks_metric_.exchange(0, std::memory_order_acq_rel));\n\n  // Hanse sampling and shuffler metrics.\n  {\n    auto* hits = stage_metric.add_count_metrics();\n    hits->set_name(\"hanse_cache_hits\");\n    hits->set_count(hanse_cache_hits_.exchange(0, std::memory_order_acq_rel));\n\n    auto* misses = stage_metric.add_count_metrics();\n    misses->set_name(\"hanse_cache_misses\");\n    misses->set_count(\n        hanse_cache_misses_.exchange(0, std::memory_order_acq_rel));\n\n    auto* rejected = stage_metric.add_count_metrics();\n    rejected->set_name(\"hanse_rejected\");\n    rejected->set_count(hanse_rejected_.exchange(0, std::memory_order_acq_rel));\n\n    auto* resh = stage_metric.add_count_metrics();\n    resh->set_name(\"reshuffles\");\n    resh->set_count(reshuffles_.exchange(0, std::memory_order_acq_rel));\n  }\n\n  // Position cache metrics.\n  if (cachehit_output_queue_.has_value()) {\n    LoadMetricProto caching_load;\n    caching_load.set_name(\"caching\");\n    for (const auto& context : caching_thread_contexts_) {\n      UpdateFrom(caching_load, context->load_metric_updater.FlushMetrics());\n    }\n    *stage_metric.add_load_metrics() = std::move(caching_load);\n\n    auto* cache_hits = stage_metric.add_count_metrics();\n    
cache_hits->set_name(\"cache_hits\");\n    cache_hits->set_count(cache_hits_.exchange(0, std::memory_order_acq_rel));\n\n    auto* cache_misses = stage_metric.add_count_metrics();\n    cache_misses->set_name(\"cache_misses\");\n    cache_misses->set_count(\n        cache_misses_.exchange(0, std::memory_order_acq_rel));\n\n    auto* mismatched = stage_metric.add_count_metrics();\n    mismatched->set_name(\"mismatched_use_counts\");\n    mismatched->set_count(\n        mismatched_use_counts_.exchange(0, std::memory_order_acq_rel));\n\n    auto* newly_cached = stage_metric.add_count_metrics();\n    newly_cached->set_name(\"newly_cached\");\n    newly_cached->set_count(\n        newly_cached_.exchange(0, std::memory_order_acq_rel));\n\n    auto* dropped = stage_metric.add_count_metrics();\n    dropped->set_name(\"dropped_cache_positions\");\n    dropped->set_count(\n        dropped_cache_positions_.exchange(0, std::memory_order_acq_rel));\n\n    auto* not_found = stage_metric.add_count_metrics();\n    not_found->set_name(\"chunk_source_not_found\");\n    not_found->set_count(\n        chunk_source_not_found_.exchange(0, std::memory_order_acq_rel));\n\n    auto* cached = stage_metric.add_gauge_metrics();\n    cached->set_name(\"cached_positions\");\n    cached->set_value(cached_positions_.load(std::memory_order_acquire));\n    cached->set_capacity(config_.position_cache_size());\n  }\n\n  {\n    absl::MutexLock lock(&chunk_sources_mutex_);\n    if (chunk_weight_stats_.count() > 0) {\n      chunk_weight_stats_.set_name(\"chunk_weight\");\n      UpdateFrom(*stage_metric.add_statistics_metrics(), chunk_weight_stats_);\n    }\n    chunk_weight_stats_.Clear();\n  }\n\n  *stage_metric.add_queue_metrics() =\n      MetricsFromQueue(primary_output_name_, *output_queue());\n  if (cachehit_output_queue_.has_value()) {\n    *stage_metric.add_queue_metrics() =\n        MetricsFromQueue(*cachehit_output_name_, *cachehit_output_queue_);\n  }\n  return 
stage_metric;\n}\n\nstd::pair<std::string, int> ShufflingChunkPool::ResetAnchor() {\n  absl::MutexLock anchor_lock(&anchor_mutex_);\n  absl::MutexLock sources_lock(&chunk_sources_mutex_);\n\n  if (chunk_sources_.empty()) {\n    int previous_count = chunks_since_anchor_.exchange(0);\n    return {anchor_, previous_count};\n  }\n\n  anchor_ = chunk_sources_.back()->source->GetChunkSortKey();\n  int previous_count = chunks_since_anchor_.exchange(0);\n  return {anchor_, previous_count};\n}\n\nint ShufflingChunkPool::ChunksSinceAnchor() { return chunks_since_anchor_; }\n\nstd::string ShufflingChunkPool::CurrentAnchor() {\n  absl::MutexLock lock(&anchor_mutex_);\n  return anchor_;\n}\n\nvoid ShufflingChunkPool::SetAnchor(std::string_view anchor) {\n  absl::MutexLock lock(&anchor_mutex_);\n  anchor_ = anchor;\n}\n\nstd::optional<StageControlResponse> ShufflingChunkPool::Control(\n    const StageControlRequest& request) {\n  if (!request.has_chunk_pool_request()) {\n    return std::nullopt;\n  }\n\n  const auto& chunk_request = request.chunk_pool_request();\n  StageControlResponse response;\n  auto* chunk_response = response.mutable_chunk_pool_response();\n\n  if (chunk_request.reset_chunk_anchor()) {\n    auto [anchor, chunks] = ResetAnchor();\n    chunk_response->set_chunk_anchor(anchor);\n    chunk_response->set_chunks_since_anchor(chunks);\n    return response;\n  }\n\n  if (chunk_request.has_set_chunk_anchor()) {\n    SetAnchor(chunk_request.set_chunk_anchor());\n    chunk_response->set_chunk_anchor(chunk_request.set_chunk_anchor());\n    chunk_response->set_chunks_since_anchor(ChunksSinceAnchor());\n    return response;\n  }\n\n  chunk_response->set_chunk_anchor(CurrentAnchor());\n  chunk_response->set_chunks_since_anchor(ChunksSinceAnchor());\n  return response;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/shuffling_chunk_pool.h",
    "content": "#pragma once\n\n#include <atomic>\n#include <cstdint>\n#include <deque>\n#include <filesystem>\n#include <memory>\n#include <optional>\n#include <stop_token>\n#include <string>\n#include <string_view>\n#include <thread>\n#include <utility>\n#include <variant>\n#include <vector>\n\n#include \"absl/base/thread_annotations.h\"\n#include \"absl/container/flat_hash_set.h\"\n#include \"absl/random/random.h\"\n#include \"absl/synchronization/mutex.h\"\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"loader/data_loader_metrics.h\"\n#include \"loader/stages/chunk_source_loader.h\"\n#include \"loader/stages/stage.h\"\n#include \"loader/stages/training_chunk.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/stage_control.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/metrics/load_metric.h\"\n#include \"utils/queue.h\"\n#include \"utils/stream_shuffler.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\nclass ShufflingChunkPool : public Stage {\n public:\n  explicit ShufflingChunkPool(const ShufflingChunkPoolConfig& config);\n  ~ShufflingChunkPool();\n\n  void Start() override;\n  void Stop() override;\n  void SetInputs(absl::Span<QueueBase* const> inputs) override;\n  QueueBase* GetOutput(std::string_view name) override;\n\n  StageMetricProto FlushMetrics() override;\n\n  std::optional<StageControlResponse> Control(\n      const StageControlRequest& request) override;\n\n  // Anchor management methods for tracking chunks since a specific point.\n  std::pair<std::string, int> ResetAnchor();\n  int ChunksSinceAnchor();\n  std::string CurrentAnchor();\n  void SetAnchor(std::string_view anchor);\n\n  Queue<ChunkSourceWithPhase>* input_queue() { return primary_input_queue_; }\n  Queue<TrainingChunk>* output_queue() { return &primary_output_queue_; }\n\n private:\n  struct CacheNode {\n    FrameType frame;\n    std::unique_ptr<CacheNode> next;\n  };\n\n  struct ChunkSourceItem {\n    mutable absl::Mutex mutex;\n    size_t start_chunk_index;\n  
  std::unique_ptr<ChunkSource> source;\n    absl::flat_hash_set<size_t> dropped_chunks ABSL_GUARDED_BY(mutex);\n    // Per-chunk counters and cached weights.\n    std::vector<uint16_t> use_counts ABSL_GUARDED_BY(mutex);\n    std::vector<float> weight ABSL_GUARDED_BY(mutex);\n    std::vector<std::unique_ptr<CacheNode>> cache ABSL_GUARDED_BY(mutex);\n  };\n\n  struct SourceIngestionThreadContext {\n    LoadMetricUpdater load_metric_updater;\n  };\n\n  struct ChunkLoadingThreadContext {\n    LoadMetricUpdater load_metric_updater;\n  };\n\n  struct CachingThreadContext {\n    LoadMetricUpdater load_metric_updater;\n  };\n\n  std::vector<std::unique_ptr<ChunkSource>> InitializeChunkSources();\n  void ProcessInputFiles(\n      std::vector<std::unique_ptr<ChunkSource>> uninitialized_sources);\n  void SourceIngestionWorker(std::stop_token stop_token,\n                             SourceIngestionThreadContext* context);\n  void OutputWorker(std::stop_token stop_token,\n                    ChunkLoadingThreadContext* context);\n  void CachingWorker(std::stop_token stop_token, CachingThreadContext* context);\n  void AddNewChunkSource(std::unique_ptr<ChunkSource> source)\n      ABSL_EXCLUSIVE_LOCKS_REQUIRED(chunk_sources_mutex_);\n  std::optional<std::variant<TrainingChunk, FrameType>> GetNextChunkData();\n\n  enum class ChunkStatus { kOk, kRetry, kEnd };\n  struct ChunkData;\n\n  ChunkStatus GetChunkInfo(ChunkData& out_chunk_data);\n  bool LoadChunkData(ChunkData& chunk_data);\n  bool HanseAccept(ChunkData& chunk_data);\n  float ComputeChunkWeight(absl::Span<const FrameType> frames) const;\n  double ComputeHanseProbability(float weight, float max_weight) const;\n\n  Queue<ChunkSourceWithPhase>* primary_input_queue_ = nullptr;\n  Queue<CacheRequest>* cache_request_queue_ = nullptr;\n  std::string primary_output_name_;\n  Queue<TrainingChunk> primary_output_queue_;\n  std::optional<std::string> cachehit_output_name_;\n  std::optional<Queue<FrameType>> cachehit_output_queue_;\n\n 
 const size_t chunk_pool_size_;\n  const ShufflingChunkPoolConfig config_;\n  // stop_source_ must be declared before ThreadPools that reference it.\n  std::stop_source stop_source_;\n  ThreadPool source_ingestion_pool_;\n  ThreadPool chunk_loading_pool_;\n  ThreadPool caching_pool_;\n\n  std::atomic<int64_t> dropped_chunks_metric_{0};\n\n  absl::Mutex chunk_sources_mutex_;\n  std::deque<std::shared_ptr<ChunkSourceItem>> chunk_sources_\n      ABSL_GUARDED_BY(chunk_sources_mutex_);\n  StreamShuffler stream_shuffler_ ABSL_GUARDED_BY(chunk_sources_mutex_);\n  float max_weight_ ABSL_GUARDED_BY(chunk_sources_mutex_) = 0.0f;\n  std::jthread initialization_thread_;\n  std::vector<std::unique_ptr<SourceIngestionThreadContext>>\n      source_ingestion_thread_contexts_;\n  std::vector<std::unique_ptr<ChunkLoadingThreadContext>>\n      chunk_loading_thread_contexts_;\n  std::vector<std::unique_ptr<CachingThreadContext>> caching_thread_contexts_;\n\n  // Anchor-related members for tracking chunks since a specific point.\n  absl::Mutex anchor_mutex_;\n  std::string anchor_ ABSL_GUARDED_BY(anchor_mutex_);\n  std::atomic<int> chunks_since_anchor_{0};\n\n  // Thread-local RNG for Hanse sampling.\n  static thread_local absl::BitGen bitgen_;\n\n  // Metrics counters.\n  std::atomic<uint64_t> hanse_cache_hits_{0};\n  std::atomic<uint64_t> hanse_cache_misses_{0};\n  std::atomic<uint64_t> hanse_rejected_{0};\n  std::atomic<uint64_t> reshuffles_{0};\n  std::atomic<uint64_t> cache_hits_{0};\n  std::atomic<uint64_t> cache_misses_{0};\n  std::atomic<uint64_t> mismatched_use_counts_{0};\n  std::atomic<uint64_t> newly_cached_{0};\n  std::atomic<uint64_t> dropped_cache_positions_{0};\n  std::atomic<uint64_t> chunk_source_not_found_{0};\n  std::atomic<uint64_t> cached_positions_{0};\n\n  StatisticsProtoDouble chunk_weight_stats_\n      ABSL_GUARDED_BY(chunk_sources_mutex_);\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/shuffling_chunk_pool_test.cc",
    "content": "// ABOUTME: Comprehensive unit tests for the ShufflingChunkPool class\n// ABOUTME: Tests chunk source management, output workers, and dynamic windowing\n\n#include \"loader/stages/shuffling_chunk_pool.h\"\n\n#include <absl/cleanup/cleanup.h>\n#include <absl/log/log.h>\n#include <gtest/gtest.h>\n\n#include <chrono>\n#include <cstdint>\n#include <cstring>\n#include <memory>\n#include <set>\n#include <string>\n#include <thread>\n#include <utility>\n#include <vector>\n\n#include \"loader/stages/training_chunk.h\"\n\nnamespace lczero {\nnamespace training {\n\nnamespace {\n\ntemplate <typename T>\nclass PassthroughStage : public Stage {\n public:\n  explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}\n\n  void Start() override {}\n  void Stop() override {}\n  StageMetricProto FlushMetrics() override { return StageMetricProto(); }\n  QueueBase* GetOutput(std::string_view name = \"\") override {\n    (void)name;\n    return queue_;\n  }\n  void SetInputs(absl::Span<QueueBase* const> inputs) override {\n    if (!inputs.empty()) {\n      throw std::runtime_error(\"PassthroughStage expects no inputs\");\n    }\n  }\n\n private:\n  Queue<T>* queue_;\n};\n\n}  // namespace\n\n// Mock ChunkSource for testing\nclass MockChunkSource : public ChunkSource {\n public:\n  MockChunkSource(const std::string& sort_key, size_t chunk_count)\n      : sort_key_(sort_key), chunk_count_(chunk_count) {}\n\n  std::string GetChunkSortKey() const override { return sort_key_; }\n  size_t GetChunkCount() const override { return chunk_count_; }\n\n  std::optional<std::vector<FrameType>> GetChunkData(size_t index) override {\n    if (index >= chunk_count_) {\n      throw std::out_of_range(\"Chunk index out of range\");\n    }\n    FrameType frame{};\n    frame.version = static_cast<uint32_t>(index);\n    frame.input_format = 3;\n    return std::vector<FrameType>{frame};\n  }\n\n private:\n  std::string sort_key_;\n  size_t chunk_count_;\n};\n\nclass InvalidChunkSource : 
public ChunkSource {\n public:\n  explicit InvalidChunkSource(std::string sort_key)\n      : sort_key_(std::move(sort_key)) {}\n\n  std::string GetChunkSortKey() const override { return sort_key_; }\n  size_t GetChunkCount() const override { return 2; }\n\n  std::optional<std::vector<FrameType>> GetChunkData(size_t index) override {\n    if (index >= 2) {\n      throw std::out_of_range(\"Chunk index out of range\");\n    }\n    if (index == 0) {\n      return std::nullopt;\n    }\n    FrameType frame{};\n    frame.version = 42;\n    return std::vector<FrameType>{frame};\n  }\n\n private:\n  std::string sort_key_;\n};\n\nclass ShufflingChunkPoolTest : public ::testing::Test {\n protected:\n  void SetUp() override {\n    input_queue_ = std::make_unique<Queue<ChunkSourceWithPhase>>(100);\n    input_producer_ = std::make_unique<Queue<ChunkSourceWithPhase>::Producer>(\n        input_queue_->CreateProducer());\n  }\n\n  void TearDown() override {\n    // Close the producer to close the queue\n    if (input_producer_) input_producer_.reset();\n  }\n\n  // Helper to add a mock chunk source to the input queue\n  void AddMockChunkSourceToQueue(const std::string& sort_key,\n                                 size_t chunk_count,\n                                 FilePathProvider::MessageType message_type =\n                                     FilePathProvider::MessageType::kFile) {\n    ChunkSourceWithPhase item;\n    item.source = std::make_unique<MockChunkSource>(sort_key, chunk_count);\n    item.message_type = message_type;\n    input_producer_->Put(std::move(item));\n  }\n\n  void MarkInitialScanComplete() {\n    ChunkSourceWithPhase item;\n    item.source = nullptr;  // No source for completion marker\n    item.message_type = FilePathProvider::MessageType::kInitialScanComplete;\n    input_producer_->Put(std::move(item));\n  }\n\n  void CloseInputQueue() {\n    if (input_producer_) input_producer_.reset();\n  }\n\n  ShufflingChunkPoolConfig MakeConfig(int chunk_pool_size,\n 
                                     int source_ingestion_threads = 1,\n                                      int loading_threads = 1,\n                                      int queue_capacity = 100) const {\n    ShufflingChunkPoolConfig config;\n    config.set_chunk_pool_size(chunk_pool_size);\n    config.set_source_ingestion_threads(source_ingestion_threads);\n    config.set_chunk_loading_threads(loading_threads);\n    config.mutable_output()->set_queue_capacity(queue_capacity);\n    return config;\n  }\n\n  std::unique_ptr<Queue<ChunkSourceWithPhase>> input_queue_;\n  std::unique_ptr<Queue<ChunkSourceWithPhase>::Producer> input_producer_;\n};\n\nTEST_F(ShufflingChunkPoolTest, ConstructorCreatesOutputQueue) {\n  // Add some mock chunk sources with enough chunks\n  AddMockChunkSourceToQueue(\"source1\", 50);\n  AddMockChunkSourceToQueue(\"source2\", 60);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(20);\n\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n\n  auto* output_queue = shuffling_chunk_pool.output_queue();\n\n  // Close input queue to stop input worker from waiting\n  CloseInputQueue();\n\n  EXPECT_NE(output_queue, nullptr);\n  EXPECT_EQ(output_queue->Capacity(), 100);\n\n  // Drain output queue to prevent workers from blocking\n  try {\n    while (output_queue->Size() > 0) {\n      output_queue->Get();\n    }\n  } catch (const QueueClosedException&) {\n    // Queue closed, that's fine\n  }\n}\n\nTEST_F(ShufflingChunkPoolTest, HandlesEmptyInputQueue) {\n  // Only mark scan complete, no chunk sources\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(20);\n\n  // Constructor should now succeed (initialization is asynchronous)\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n  shuffling_chunk_pool.Start();\n\n  // The initialization thread should handle the error case\n  auto* output_queue = 
shuffling_chunk_pool.output_queue();\n\n  // Close input queue to clean up\n  CloseInputQueue();\n\n  // Output queue should exist but be closed to signal startup failure when no\n  // chunks were found. A blocking Get() waits until the initialization thread\n  // closes the queue, so no sleep is needed here.\n  EXPECT_NE(output_queue, nullptr);\n  EXPECT_THROW(output_queue->Get(), QueueClosedException);\n  EXPECT_TRUE(output_queue->IsClosed());\n  EXPECT_EQ(output_queue->Size(), 0u);\n}\n\nTEST_F(ShufflingChunkPoolTest, FlushMetricsHandlesEmptyChunkSources) {\n  const int chunk_pool_size = 32;\n  auto config = MakeConfig(chunk_pool_size);\n\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n\n  auto metrics = shuffling_chunk_pool.FlushMetrics();\n  bool found_current = false;\n  bool found_total = false;\n  for (const auto& metric : metrics.gauge_metrics()) {\n    if (metric.name() == \"chunks_current\") {\n      found_current = true;\n      EXPECT_EQ(metric.value(), 0u);\n      EXPECT_EQ(metric.capacity(), static_cast<uint64_t>(chunk_pool_size));\n    } else if (metric.name() == \"chunks_total\") {\n      found_total = true;\n      EXPECT_EQ(metric.value(), 0u);\n    }\n  }\n\n  EXPECT_TRUE(found_current)\n      << \"FlushMetrics should emit chunks_current metric when empty.\";\n  EXPECT_TRUE(found_total)\n      << \"FlushMetrics should emit chunks_total metric when empty.\";\n}\n\nTEST_F(ShufflingChunkPoolTest, FlushMetricsReportsWindowAndTotalCounts) {\n  AddMockChunkSourceToQueue(\"initial\", 30);\n  MarkInitialScanComplete();\n\n  const int chunk_pool_size = 20;\n  ShufflingChunkPool shuffling_chunk_pool(MakeConfig(chunk_pool_size));\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n  shuffling_chunk_pool.Start();\n\n  auto* output_queue = shuffling_chunk_pool.output_queue();\n  output_queue->WaitForSizeAtLeast(1);\n\n  uint64_t current_count = 0;\n  uint64_t total_count = 0;\n  uint64_t current_capacity = 0;\n  
bool found_metrics = false;\n  for (int attempt = 0; attempt < 50 && !found_metrics; ++attempt) {\n    auto metrics = shuffling_chunk_pool.FlushMetrics();\n    bool has_current = false;\n    bool has_total = false;\n    for (const auto& metric : metrics.gauge_metrics()) {\n      if (metric.name() == \"chunks_current\") {\n        has_current = true;\n        current_count = metric.value();\n        current_capacity = metric.capacity();\n      } else if (metric.name() == \"chunks_total\") {\n        has_total = true;\n        total_count = metric.value();\n      }\n    }\n    if (has_current && has_total) {\n      found_metrics = true;\n      break;\n    }\n    std::this_thread::sleep_for(std::chrono::milliseconds(10));\n  }\n\n  ASSERT_TRUE(found_metrics)\n      << \"FlushMetrics should report both chunks_current and chunks_total.\";\n  EXPECT_EQ(current_count, 30u);\n  EXPECT_EQ(current_capacity, static_cast<uint64_t>(chunk_pool_size));\n  EXPECT_EQ(total_count, 30u);\n\n  CloseInputQueue();\n}\n\nTEST_F(ShufflingChunkPoolTest, ProcessesInitialScanChunkSources) {\n  // Create mock chunk sources with enough chunks\n  AddMockChunkSourceToQueue(\"source1\", 30);\n  AddMockChunkSourceToQueue(\"source2\", 40);\n  AddMockChunkSourceToQueue(\"source3\", 50);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(20);\n\n  // Test that constructor completes and processes mock chunk sources\n  EXPECT_NO_THROW({\n    ShufflingChunkPool shuffling_chunk_pool(config);\n\n    shuffling_chunk_pool.SetInputs({input_queue_.get()});\n    shuffling_chunk_pool.Start();\n\n    // Close input queue to stop input worker from waiting\n    CloseInputQueue();\n\n    auto* output_queue = shuffling_chunk_pool.output_queue();\n    EXPECT_NE(output_queue, nullptr);\n  });\n}\n\nTEST_F(ShufflingChunkPoolTest, OutputWorkerProducesChunks) {\n  // Create mock chunk sources\n  AddMockChunkSourceToQueue(\"source1\", 10,\n                            FilePathProvider::MessageType::kFile);\n  
AddMockChunkSourceToQueue(\"source2\", 15,\n                            FilePathProvider::MessageType::kFile);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(20);\n\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n  shuffling_chunk_pool.Start();\n\n  // Close input queue to stop input worker from waiting\n  CloseInputQueue();\n\n  auto* output_queue = shuffling_chunk_pool.output_queue();\n\n  // Wait for output workers to produce at least one chunk\n  output_queue->WaitForSizeAtLeast(1);\n\n  // Should have some chunks available\n  EXPECT_GT(output_queue->Size(), 0);\n\n  // Get a chunk and verify it's from our mock sources\n  auto chunk = output_queue->Get();\n  EXPECT_FALSE(chunk.frames.empty());\n  EXPECT_TRUE(chunk.sort_key == \"source1\" || chunk.sort_key == \"source2\");\n  EXPECT_EQ(chunk.frames.size(), 1);\n  EXPECT_EQ(chunk.frames.front().version,\n            static_cast<uint32_t>(chunk.index_within_sort_key));\n  EXPECT_EQ(chunk.use_count, 0u);\n}\n\nTEST_F(ShufflingChunkPoolTest, DropsInvalidChunks) {\n  ChunkSourceWithPhase invalid_source;\n  invalid_source.source =\n      std::make_unique<InvalidChunkSource>(\"invalid_source\");\n  invalid_source.message_type = FilePathProvider::MessageType::kFile;\n  input_producer_->Put(std::move(invalid_source));\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(2, /*source_ingestion_threads=*/1,\n                           /*loading_threads=*/1, /*queue_capacity=*/10);\n\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n  shuffling_chunk_pool.Start();\n\n  // Close input queue to stop input worker from waiting\n  CloseInputQueue();\n\n  auto* output_queue = shuffling_chunk_pool.output_queue();\n  output_queue->WaitForSizeAtLeast(1);\n  auto chunk = output_queue->Get();\n\n  EXPECT_EQ(chunk.sort_key, \"invalid_source\");\n  EXPECT_EQ(chunk.index_within_sort_key, 1);\n  
EXPECT_EQ(chunk.use_count, 0u);\n  ASSERT_EQ(chunk.frames.size(), 1);\n  EXPECT_EQ(chunk.frames.front().version, 42);\n\n  uint64_t dropped_latest = 0;\n  bool found_dropped = false;\n  for (int attempt = 0; attempt < 50 && !found_dropped; ++attempt) {\n    auto metrics = shuffling_chunk_pool.FlushMetrics();\n    for (const auto& metric : metrics.count_metrics()) {\n      if (metric.name() == \"dropped\" && metric.count() > 0) {\n        dropped_latest = metric.count();\n        found_dropped = true;\n        break;\n      }\n    }\n    if (!found_dropped) {\n      std::this_thread::sleep_for(std::chrono::milliseconds(10));\n    }\n  }\n  ASSERT_TRUE(found_dropped) << \"dropped chunk metrics should be reported\";\n  EXPECT_GE(dropped_latest, 1u);\n}\n\nTEST_F(ShufflingChunkPoolTest, NewChunkSourceProcessing) {\n  // Start with initial scan and one chunk source - use enough chunks to satisfy\n  // window\n  AddMockChunkSourceToQueue(\"initial\", 120);  // More chunks than window\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(20);\n\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n  shuffling_chunk_pool.Start();\n\n  // Verify chunks are being produced from initial sources\n  auto* output_queue = shuffling_chunk_pool.output_queue();\n  output_queue->WaitForSizeAtLeast(1);\n  EXPECT_NE(output_queue, nullptr);\n  EXPECT_GT(output_queue->Size(), 0);\n\n  // Add a new chunk source after initialization\n  AddMockChunkSourceToQueue(\"new_source\", 30,\n                            FilePathProvider::MessageType::kFile);\n\n  // Close input queue to stop input worker from waiting for more\n  CloseInputQueue();\n\n  // The chunk set should still be functional and continue producing chunks\n  // from both the initial and new sources\n  EXPECT_GT(output_queue->Size(), 0);\n}\n\nTEST_F(ShufflingChunkPoolTest, ChunkWindowManagement) {\n  // Create more chunks than the window size\n  
AddMockChunkSourceToQueue(\"source1\", 30);\n  AddMockChunkSourceToQueue(\"source2\", 30);\n  AddMockChunkSourceToQueue(\"source3\", 30);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(50);\n\n  // Should only keep sources that fit in the window\n  EXPECT_NO_THROW({\n    ShufflingChunkPool shuffling_chunk_pool(config);\n\n    shuffling_chunk_pool.SetInputs({input_queue_.get()});\n    shuffling_chunk_pool.Start();\n\n    // Close input queue to stop input worker from waiting\n    CloseInputQueue();\n\n    auto* output_queue = shuffling_chunk_pool.output_queue();\n    EXPECT_NE(output_queue, nullptr);\n  });\n}\n\n// Test that chunk sources are sorted internally by sort key.\nTEST_F(ShufflingChunkPoolTest, ChunkSorting) {\n  // Add chunk sources in non-sorted order (by sort key)\n  AddMockChunkSourceToQueue(\"source_b\", 20);\n  AddMockChunkSourceToQueue(\"source_a\", 25);\n  AddMockChunkSourceToQueue(\"source_c\", 30);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(70);\n\n  // ShufflingChunkPool should handle sorting internally (newest first)\n  EXPECT_NO_THROW({\n    ShufflingChunkPool shuffling_chunk_pool(config);\n\n    shuffling_chunk_pool.SetInputs({input_queue_.get()});\n    shuffling_chunk_pool.Start();\n\n    // Close input queue to stop input worker from waiting\n    CloseInputQueue();\n\n    auto* output_queue = shuffling_chunk_pool.output_queue();\n    EXPECT_NE(output_queue, nullptr);\n  });\n}\n\nTEST_F(ShufflingChunkPoolTest, StreamShufflerResetWhenExhausted) {\n  // Create a small chunk source to quickly exhaust the shuffler\n  AddMockChunkSourceToQueue(\"source1\", 3);  // Only 3 chunks for faster testing\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(3, /*source_ingestion_threads=*/1,\n                           /*loading_threads=*/1,\n                           /*queue_capacity=*/100);  // Large enough\n\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n  
shuffling_chunk_pool.Start();\n\n  auto* output_queue = shuffling_chunk_pool.output_queue();\n\n  // Collect chunks continuously and count total chunks received\n  struct ChunkRecord {\n    std::string sort_key;\n    size_t index;\n    uint32_t use_count;\n  };\n  std::vector<ChunkRecord> all_chunks_received;\n\n  // Wait for and collect chunks to test shuffler reset\n  for (size_t i = 0; i < 8; ++i) {\n    output_queue->WaitForSizeAtLeast(1);\n    auto chunk = output_queue->Get();\n    all_chunks_received.push_back(\n        {chunk.sort_key, chunk.index_within_sort_key, chunk.use_count});\n  }\n\n  std::set<std::pair<std::string, size_t>> unique_chunks;\n  bool seen_reuse = false;\n  for (const auto& record : all_chunks_received) {\n    unique_chunks.emplace(record.sort_key, record.index);\n    if (record.use_count > 0) {\n      seen_reuse = true;\n    }\n  }\n\n  // Close input queue to clean up\n  try {\n    CloseInputQueue();\n  } catch (const QueueClosedException&) {\n    // Already closed, that's fine\n  }\n\n  // We should see all 3 unique chunks from our source\n  EXPECT_EQ(unique_chunks.size(), 3) << \"Should see all unique chunks\";\n\n  // If reset works properly, we should receive more than 3 total chunks\n  // (since chunks will repeat after shuffler reset)\n  EXPECT_GT(all_chunks_received.size(), 3)\n      << \"Should get more than 3 chunks total due to shuffler reset, got \"\n      << all_chunks_received.size() << \" chunks\";\n  EXPECT_TRUE(seen_reuse)\n      << \"Expect at least one chunk to report a reuse count\";\n}\n\nTEST_F(ShufflingChunkPoolTest, HanseMetrics_NoRejection_CacheAndReshuffles) {\n  // Single chunk so we will continually reuse the same chunk.\n  AddMockChunkSourceToQueue(\"source1\", 1);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(1, /*source_ingestion_threads=*/1,\n                           /*loading_threads=*/1, /*queue_capacity=*/100);\n  // Enable Hanse sampling with p == 1 to avoid rejections.\n  
config.set_hanse_sampling_threshold(1);\n\n  ShufflingChunkPool pool(config);\n\n  pool.SetInputs({input_queue_.get()});\n  pool.Start();\n\n  auto* output_queue = pool.output_queue();\n  // Wait for multiple outputs to exercise cache hits and reshuffles.\n  output_queue->WaitForSizeAtLeast(3);\n  // Drain a few items.\n  for (int i = 0; i < 3; ++i) {\n    auto chunk = output_queue->Get();\n    EXPECT_EQ(chunk.frames.size(), 1u);\n  }\n\n  // Close input to avoid lingering.\n  CloseInputQueue();\n\n  // Flush metrics and validate Hanse counters and reshuffles.\n  auto metrics = pool.FlushMetrics();\n  uint64_t cache_hits = 0, cache_misses = 0, rejected = 0, reshuffles = 0;\n  for (const auto& m : metrics.count_metrics()) {\n    if (m.name() == \"hanse_cache_hits\") cache_hits = m.count();\n    if (m.name() == \"hanse_cache_misses\") cache_misses = m.count();\n    if (m.name() == \"hanse_rejected\") rejected = m.count();\n    if (m.name() == \"reshuffles\") reshuffles = m.count();\n  }\n\n  // First access computes and caches num_records => 1 miss, then hits.\n  EXPECT_EQ(cache_misses, 1u);\n  EXPECT_GE(cache_hits, 1u);\n  // With threshold=1 and one frame, p = 1 => no rejections.\n  EXPECT_EQ(rejected, 0u);\n  // Single chunk repeatedly consumed forces reshuffles.\n  EXPECT_GT(reshuffles, 0u);\n}\n\nTEST_F(ShufflingChunkPoolTest, ExplicitClose) {\n  // Create chunk sources\n  AddMockChunkSourceToQueue(\"source1\", 20);\n  AddMockChunkSourceToQueue(\"source2\", 30);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(40);\n\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n  shuffling_chunk_pool.Start();\n  auto* output_queue = shuffling_chunk_pool.output_queue();\n\n  // Wait for workers to produce some chunks\n  output_queue->WaitForSizeAtLeast(1);\n\n  // Verify output queue is working before close\n  EXPECT_GT(output_queue->Size(), 0);\n\n  // Explicitly stop the chunk set\n  
shuffling_chunk_pool.Stop();\n\n  // Drain all remaining items from the queue\n  while (output_queue->Size() > 0) {\n    output_queue->Get();\n  }\n\n  // Now the queue should be closed and empty, so Get() should throw\n  EXPECT_THROW(output_queue->Get(), QueueClosedException);\n\n  CloseInputQueue();\n}\n\nTEST_F(ShufflingChunkPoolTest, CloseStopsOutputWorkers) {\n  // Create chunk sources\n  AddMockChunkSourceToQueue(\"source1\", 15);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(15, /*source_ingestion_threads=*/1,\n                           /*loading_threads=*/2, /*queue_capacity=*/50);\n\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n  shuffling_chunk_pool.Start();\n  auto* output_queue = shuffling_chunk_pool.output_queue();\n\n  // Wait for workers to produce chunks\n  output_queue->WaitForSizeAtLeast(1);\n  size_t chunks_before_close = output_queue->Size();\n\n  // Stop the chunk set\n  shuffling_chunk_pool.Stop();\n\n  // Drain any remaining chunks from the queue\n  try {\n    while (output_queue->Size() > 0) {\n      output_queue->Get();\n    }\n  } catch (const QueueClosedException&) {\n    // Expected when queue is empty and closed\n  }\n\n  // Should have had chunks before close\n  EXPECT_GT(chunks_before_close, 0) << \"Should have had chunks before close\";\n\n  CloseInputQueue();\n}\n\nTEST_F(ShufflingChunkPoolTest, CloseIsIdempotent) {\n  // Create chunk sources\n  AddMockChunkSourceToQueue(\"source1\", 20);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(20);\n\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n  shuffling_chunk_pool.Start();\n\n  // Stop multiple times - should not crash or cause issues\n  EXPECT_NO_THROW(shuffling_chunk_pool.Stop());\n  EXPECT_NO_THROW(shuffling_chunk_pool.Stop());\n  EXPECT_NO_THROW(shuffling_chunk_pool.Stop());\n\n  
CloseInputQueue();\n}\n\nTEST_F(ShufflingChunkPoolTest, DestructorCallsClose) {\n  // Create chunk sources\n  AddMockChunkSourceToQueue(\"source1\", 20);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(20);\n\n  // Test that destructor calls Close() and properly shuts down\n  {\n    ShufflingChunkPool shuffling_chunk_pool(config);\n\n    shuffling_chunk_pool.SetInputs({input_queue_.get()});\n    shuffling_chunk_pool.Start();\n    auto* output_queue = shuffling_chunk_pool.output_queue();\n\n    // Wait for workers to produce some chunks\n    output_queue->WaitForSizeAtLeast(1);\n    EXPECT_GT(output_queue->Size(), 0);\n\n    // Close input queue before destructor to allow threads to finish\n    CloseInputQueue();\n\n    // ShufflingChunkPool destructor should be called here, which calls Stop()\n    // and waits for all threads to finish\n  }\n\n  // Test passes if destructor completes without hanging\n  // (we can't test the queue state after destruction since it's destroyed)\n}\n\nTEST_F(ShufflingChunkPoolTest, InputQueueClosureDoesNotCloseOutputQueue) {\n  // Create chunk sources\n  AddMockChunkSourceToQueue(\"source1\", 30);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(30);\n\n  ShufflingChunkPool shuffling_chunk_pool(config);\n\n  shuffling_chunk_pool.SetInputs({input_queue_.get()});\n  shuffling_chunk_pool.Start();\n  auto* output_queue = shuffling_chunk_pool.output_queue();\n\n  // Wait for workers to produce some chunks\n  output_queue->WaitForSizeAtLeast(1);\n  EXPECT_GT(output_queue->Size(), 0);\n\n  // Close input queue (simulating end of file discovery)\n  CloseInputQueue();\n\n  // Output queue should still be functional - workers should continue\n  // producing chunks from existing chunk sources\n\n  // Should still be able to get chunks (queue not closed)\n  EXPECT_NO_THROW(output_queue->Get());\n\n  // Explicitly stop to clean up\n  shuffling_chunk_pool.Stop();\n}\n\nTEST_F(ShufflingChunkPoolTest, BasicAnchorFunctionality) 
{\n  AddMockChunkSourceToQueue(\"source1\", 20);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(20);\n\n  ShufflingChunkPool pool(config);\n\n  pool.SetInputs({input_queue_.get()});\n  pool.Start();\n\n  // Test initial state\n  EXPECT_EQ(pool.ChunksSinceAnchor(), 0);\n  EXPECT_EQ(pool.CurrentAnchor(), \"\");\n\n  // Test SetAnchor and CurrentAnchor\n  pool.SetAnchor(\"test_anchor_key\");\n  EXPECT_EQ(pool.CurrentAnchor(), \"test_anchor_key\");\n  EXPECT_EQ(pool.ChunksSinceAnchor(), 0);  // Should still be 0\n\n  // Test setting different anchor\n  pool.SetAnchor(\"another_key\");\n  EXPECT_EQ(pool.CurrentAnchor(), \"another_key\");\n\n  CloseInputQueue();\n}\n\nTEST_F(ShufflingChunkPoolTest, ResetAnchor) {\n  AddMockChunkSourceToQueue(\"source1\", 20);\n  MarkInitialScanComplete();\n\n  auto config = MakeConfig(20);\n\n  ShufflingChunkPool pool(config);\n\n  pool.SetInputs({input_queue_.get()});\n  pool.Start();\n\n  // Wait for initialization to complete\n  pool.output_queue()->WaitForSizeAtLeast(1);\n\n  // Now test ResetAnchor\n  auto [anchor, count_before] = pool.ResetAnchor();\n  EXPECT_FALSE(anchor.empty());  // Should have the chunk key\n  EXPECT_EQ(pool.CurrentAnchor(), anchor);\n  EXPECT_EQ(pool.ChunksSinceAnchor(), 0);  // Should be reset to 0\n\n  CloseInputQueue();\n}\n\nTEST_F(ShufflingChunkPoolTest, AnchorCounterIncrement) {\n  // Don't mark initial scan complete yet - we'll add sources one by one\n\n  auto config = MakeConfig(20);\n\n  // Start with some initial sources and complete scan\n  AddMockChunkSourceToQueue(\"source1\", 20);\n  MarkInitialScanComplete();\n\n  ShufflingChunkPool pool(config);\n\n  pool.SetInputs({input_queue_.get()});\n  pool.Start();\n\n  // Set anchor to a key that won't match our new sources\n  pool.SetAnchor(\"non_matching_key\");\n\n  // Wait for initial load to complete\n  pool.output_queue()->WaitForSizeAtLeast(1);\n\n  // Now add new sources (these should increment the counter)\n  // Note: We can't add 
more sources after initial scan complete in the current\n  // setup, so we'll test the counter after the initial load.\n\n  int final_count = pool.ChunksSinceAnchor();\n\n  // Counter should have incremented during initial load since anchor doesn't\n  // match\n  EXPECT_GT(final_count, 0);\n  EXPECT_EQ(pool.CurrentAnchor(), \"non_matching_key\");  // Anchor unchanged\n\n  CloseInputQueue();\n}\n\nTEST_F(ShufflingChunkPoolTest, AnchorCounterResetDuringInitialLoad) {\n  // Test the special case where anchor is encountered during initial backward\n  // processing\n  AddMockChunkSourceToQueue(\"source_c\", 10);  // newest\n  AddMockChunkSourceToQueue(\"source_b\", 15);  // middle\n  AddMockChunkSourceToQueue(\"source_a\", 20);  // oldest\n\n  auto config = MakeConfig(45);\n\n  ShufflingChunkPool pool(config);\n\n  pool.SetInputs({input_queue_.get()});\n  pool.Start();\n\n  // Set anchor to middle source before marking scan complete\n  pool.SetAnchor(\"source_b\");\n\n  // Mark scan complete to trigger initial processing\n  MarkInitialScanComplete();\n\n  // Wait for initial load to complete\n  pool.output_queue()->WaitForSizeAtLeast(1);\n\n  int final_count = pool.ChunksSinceAnchor();\n\n  // Should only count chunks from source_c (10 chunks) since it is newer than\n  // the anchor.\n  EXPECT_EQ(final_count, 10);\n  EXPECT_EQ(pool.CurrentAnchor(), \"source_b\");\n\n  CloseInputQueue();\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/shuffling_frame_sampler.cc",
    "content": "#include \"loader/stages/shuffling_frame_sampler.h\"\n\n#include <utility>\n\n#include \"absl/algorithm/container.h\"\n#include \"absl/log/log.h\"\n#include \"absl/random/uniform_int_distribution.h\"\n#include \"loader/data_loader_metrics.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n\nnamespace lczero {\nnamespace training {\n\nShufflingFrameSampler::ShufflingFrameSampler(\n    const ShufflingFrameSamplerConfig& config)\n    : SingleInputStage<ShufflingFrameSamplerConfig, InputType>(config),\n      SingleOutputStage<OutputType>(config.output()),\n      reservoir_size_per_thread_(config.reservoir_size_per_thread()),\n      thread_pool_(config.threads(), ThreadPoolOptions{}) {\n  LOG(INFO) << \"Initializing ShufflingFrameSampler with \" << config.threads()\n            << \" threads, reservoir size \"\n            << config.reservoir_size_per_thread();\n\n  // Initialize thread contexts but don't start worker threads yet.\n  thread_contexts_.reserve(config.threads());\n  for (size_t i = 0; i < config.threads(); ++i) {\n    thread_contexts_.push_back(std::make_unique<ThreadContext>());\n  }\n}\n\nShufflingFrameSampler::~ShufflingFrameSampler() { Stop(); }\n\nvoid ShufflingFrameSampler::Start() {\n  LOG(INFO) << \"Starting ShufflingFrameSampler worker threads.\";\n  for (size_t i = 0; i < thread_contexts_.size(); ++i) {\n    thread_pool_.Enqueue([this, i](std::stop_token stop_token) {\n      Worker(stop_token, thread_contexts_[i].get());\n    });\n  }\n}\n\nvoid ShufflingFrameSampler::Stop() {\n  if (thread_pool_.stop_token().stop_requested()) return;\n\n  LOG(INFO) << \"Stopping ShufflingFrameSampler.\";\n  thread_pool_.Shutdown();\n  output_queue()->Close();\n  LOG(INFO) << \"ShufflingFrameSampler stopped.\";\n}\n\nvoid ShufflingFrameSampler::Worker(std::stop_token stop_token,\n                                   ThreadContext* context) {\n  // Create producer early so that if input queue closes during 
reservoir\n  // prefilling, the producer will be destroyed and close the output queue.\n  auto producer = output_queue()->CreateProducer();\n  absl::FixedArray<FrameType> reservoir(reservoir_size_per_thread_);\n\n  try {\n    // Phase 1: Prefill the reservoir.\n    LOG(INFO) << \"ShufflingFrameSampler worker prefilling reservoir\";\n    absl::c_generate(reservoir, [this, context, stop_token]() {\n      LoadMetricPauser pauser(context->load_metric_updater);\n      return input_queue()->Get(stop_token);\n    });\n\n    // Phase 2: Main sampling loop.\n    MainSamplingLoop(stop_token, reservoir, producer, context);\n  } catch (const QueueClosedException&) {\n    LOG(INFO) << \"ShufflingFrameSampler worker stopping, queue closed.\";\n  } catch (const QueueRequestCancelled&) {\n    LOG(INFO) << \"ShufflingFrameSampler worker stopping, request cancelled.\";\n  }\n}\n\nvoid ShufflingFrameSampler::MainSamplingLoop(\n    std::stop_token stop_token, absl::FixedArray<FrameType>& reservoir,\n    Queue<OutputType>::Producer& producer, ThreadContext* context) {\n  absl::uniform_int_distribution<size_t> dist(0, reservoir.size() - 1);\n  // Workers run this loop concurrently and absl::BitGen is not thread-safe,\n  // so use a generator local to this thread instead of a shared member.\n  absl::BitGen gen;\n\n  while (true) {\n    const size_t random_index = dist(gen);\n    {\n      LoadMetricPauser pauser(context->load_metric_updater);\n      producer.Put(std::move(reservoir[random_index]), stop_token);\n    }\n    {\n      LoadMetricPauser pauser(context->load_metric_updater);\n      reservoir[random_index] = input_queue()->Get(stop_token);\n    }\n  }\n}\n\nStageMetricProto ShufflingFrameSampler::FlushMetrics() {\n  StageMetricProto stage_metric;\n  LoadMetricProto aggregated_load;\n  aggregated_load.set_name(\"load\");\n  for (const auto& context : thread_contexts_) {\n    UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());\n  }\n  *stage_metric.add_load_metrics() = std::move(aggregated_load);\n  *stage_metric.add_queue_metrics() =\n      MetricsFromQueue(\"output\", *output_queue());\n  return stage_metric;\n}\n\n}  // 
namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/shuffling_frame_sampler.h",
    "content": "// ABOUTME: Stage that provides shuffled frames using reservoir sampling.\n// ABOUTME: Takes FrameType frames and outputs them in randomized order.\n#pragma once\n\n#include <atomic>\n#include <cstddef>\n#include <memory>\n#include <stop_token>\n#include <vector>\n\n#include \"absl/container/fixed_array.h\"\n#include \"absl/random/random.h\"\n#include \"loader/data_loader_metrics.h\"\n#include \"loader/frame_type.h\"\n#include \"loader/stages/stage.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/queue.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Stage that implements reservoir sampling for training frames.\n// Takes FrameType frames as input and outputs them in shuffled order\n// using the reservoir sampling algorithm.\nclass ShufflingFrameSampler\n    : public SingleInputStage<ShufflingFrameSamplerConfig, FrameType>,\n      public SingleOutputStage<FrameType> {\n public:\n  using InputType = FrameType;\n  using OutputType = FrameType;\n\n  explicit ShufflingFrameSampler(const ShufflingFrameSamplerConfig& config);\n  ~ShufflingFrameSampler();\n\n  void Start() override;\n  void Stop() override;\n  StageMetricProto FlushMetrics() override;\n\n private:\n  struct ThreadContext {\n    LoadMetricUpdater load_metric_updater;\n  };\n\n  void Worker(std::stop_token stop_token, ThreadContext* context);\n  void MainSamplingLoop(std::stop_token stop_token,\n                        absl::FixedArray<FrameType>& reservoir,\n                        Queue<OutputType>::Producer& producer,\n                        ThreadContext* context);\n\n  size_t reservoir_size_per_thread_;\n  absl::BitGen gen_;\n  // thread_contexts_ must be declared before thread_pool_ to ensure\n  // thread_pool_ is destroyed first (stopping threads before contexts).\n  std::vector<std::unique_ptr<ThreadContext>> thread_contexts_;\n  ThreadPool thread_pool_;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/shuffling_frame_sampler_test.cc",
    "content": "#include \"loader/stages/shuffling_frame_sampler.h\"\n\n#include <algorithm>\n#include <cstdint>\n#include <memory>\n#include <set>\n#include <vector>\n\n#include \"gtest/gtest.h\"\n#include \"libs/lc0/src/trainingdata/trainingdata_v6.h\"\n#include \"utils/queue.h\"\n\nnamespace lczero {\nnamespace training {\n\nclass ShufflingFrameSamplerTest : public ::testing::Test {\n protected:\n  void SetUp() override {\n    input_queue_ = std::make_unique<Queue<FrameType>>(100);\n    config_.set_reservoir_size_per_thread(10);  // Small size for testing.\n    config_.mutable_output()->set_queue_capacity(20);\n  }\n\n  FrameType CreateTestFrame(uint32_t version) {\n    FrameType frame{};\n    frame.version = version;\n    frame.input_format = 3;\n    frame.root_q = 0.5f;\n    return frame;\n  }\n\n  std::unique_ptr<Queue<FrameType>> input_queue_;\n  ShufflingFrameSamplerConfig config_;\n};\n\nTEST_F(ShufflingFrameSamplerTest, OutputsNoFramesWithSmallInput) {\n  ShufflingFrameSampler sampler(config_);\n  sampler.SetInputs({input_queue_.get()});\n  sampler.Start();\n\n  // Send 5 frames (fewer than the reservoir size).\n  auto producer = input_queue_->CreateProducer();\n  std::vector<uint32_t> input_versions = {1, 2, 3, 4, 5};\n  for (auto version : input_versions) {\n    producer.Put(CreateTestFrame(version));\n  }\n  producer.Close();\n\n  // Collect all output frames.\n  std::set<uint32_t> output_versions;\n  
try {\n    while (true) {\n      auto frame = sampler.output_queue()->Get();\n      output_versions.insert(frame.version);\n    }\n  } catch (const QueueClosedException&) {\n    // Expected when the queue is closed.\n  }\n\n  // With fewer inputs than reservoir size, no frames should be output\n  // (they remain in the reservoir).\n  EXPECT_EQ(output_versions.size(), 0);\n}\n\nTEST_F(ShufflingFrameSamplerTest, OutputsFramesWithLargeInput) {\n  ShufflingFrameSampler sampler(config_);\n  sampler.SetInputs({input_queue_.get()});\n  sampler.Start();\n\n  // Send 20 frames (more than the reservoir size of 10).\n  auto producer = input_queue_->CreateProducer();\n  std::vector<uint32_t> input_versions;\n  for (uint32_t i = 1; i <= 20; ++i) {\n    input_versions.push_back(i);\n    producer.Put(CreateTestFrame(i));\n  }\n  producer.Close();\n\n  // Collect all output frames.\n  std::set<uint32_t> output_versions;\n  try {\n    while (true) {\n      auto frame = sampler.output_queue()->Get();\n      output_versions.insert(frame.version);\n    }\n  } catch (const QueueClosedException&) {\n    // Expected when the queue is closed.\n  }\n\n  // Should output exactly 11 frames (10 during sampling + 1 final frame before\n  // the queue closes).\n  EXPECT_EQ(output_versions.size(), 11);\n\n  // All output frames should be from the input set.\n  for (auto version : output_versions) {\n    EXPECT_TRUE(std::find(input_versions.begin(), input_versions.end(),\n                          version) != input_versions.end());\n  }\n}\n\nTEST_F(ShufflingFrameSamplerTest, HandlesEmptyInput) {\n  ShufflingFrameSampler sampler(config_);\n  sampler.SetInputs({input_queue_.get()});\n  sampler.Start();\n\n  // Close the input queue without sending data.\n  input_queue_->Close();\n\n  // Should not output any frames.\n  EXPECT_THROW(sampler.output_queue()->Get(), QueueClosedException);\n}\n\nTEST_F(ShufflingFrameSamplerTest, HandlesExactReservoirSize) {\n  ShufflingFrameSampler sampler(config_);\n  
sampler.SetInputs({input_queue_.get()});\n  sampler.Start();\n\n  // Send exactly reservoir_size_per_thread frames.\n  auto producer = input_queue_->CreateProducer();\n  std::vector<uint32_t> input_versions;\n  for (uint32_t i = 1; i <= config_.reservoir_size_per_thread(); ++i) {\n    input_versions.push_back(i);\n    producer.Put(CreateTestFrame(i));\n  }\n  producer.Close();\n\n  // Collect all output frames.\n  std::set<uint32_t> output_versions;\n  try {\n    while (true) {\n      auto frame = sampler.output_queue()->Get();\n      output_versions.insert(frame.version);\n    }\n  } catch (const QueueClosedException&) {\n    // Expected when the queue is closed.\n  }\n\n  // With exactly reservoir size frames, 1 frame should be output\n  // (fills the reservoir, then the queue closes during the first sampling\n  // attempt).\n  EXPECT_EQ(output_versions.size(), 1);\n}\n\nTEST_F(ShufflingFrameSamplerTest, PreservesFrameData) {\n  config_.set_reservoir_size_per_thread(2);\n  ShufflingFrameSampler sampler(config_);\n  sampler.SetInputs({input_queue_.get()});\n  sampler.Start();\n\n  auto producer = input_queue_->CreateProducer();\n\n  // Create frames with specific data; we need more than the reservoir size.\n  FrameType frame1 = CreateTestFrame(100);\n  frame1.root_q = 0.1f;\n  frame1.input_format = 1;\n\n  FrameType frame2 = CreateTestFrame(200);\n  frame2.root_q = 0.2f;\n  frame2.input_format = 2;\n\n  FrameType frame3 = CreateTestFrame(300);\n  frame3.root_q = 0.3f;\n  frame3.input_format = 3;\n\n  producer.Put(frame1);\n  producer.Put(frame2);\n  producer.Put(frame3);  // Forces one frame out of the reservoir.\n  producer.Close();\n\n  // Verify frame data is preserved.\n  std::vector<FrameType> output_frames;\n  try {\n    while (true) {\n      output_frames.push_back(sampler.output_queue()->Get());\n    }\n  } catch (const QueueClosedException&) {\n    // Expected.\n  }\n\n  EXPECT_EQ(output_frames.size(), 2);\n  // Should be frames that were displaced from the reservoir during sampling.\n  
for (const auto& frame : output_frames) {\n    // Verify frame data is preserved.\n    if (frame.version == 100) {\n      EXPECT_EQ(frame.root_q, 0.1f);\n      EXPECT_EQ(frame.input_format, 1);\n    } else if (frame.version == 200) {\n      EXPECT_EQ(frame.root_q, 0.2f);\n      EXPECT_EQ(frame.input_format, 2);\n    }\n  }\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/simple_chunk_extractor.cc",
    "content": "#include \"loader/stages/simple_chunk_extractor.h\"\n\n#include <numeric>\n\n#include \"absl/algorithm/container.h\"\n#include \"absl/log/log.h\"\n\n#include \"loader/data_loader_metrics.h\"\n\nnamespace lczero {\nnamespace training {\n\nSimpleChunkExtractor::SimpleChunkExtractor(\n    const SimpleChunkExtractorConfig& config)\n    : SingleInputStage<SimpleChunkExtractorConfig, ChunkSourceWithPhase>(\n          config),\n      SingleOutputStage<TrainingChunk>(config.output()),\n      bitgen_(absl::MakeSeedSeq()) {}\n\nSimpleChunkExtractor::~SimpleChunkExtractor() { Stop(); }\n\nvoid SimpleChunkExtractor::Start() {\n  thread_pool_.Enqueue(\n      [this](std::stop_token stop_token) { Worker(stop_token); });\n}\n\nvoid SimpleChunkExtractor::Stop() {\n  if (thread_pool_.stop_token().stop_requested()) return;\n  LOG(INFO) << \"Stopping SimpleChunkExtractor.\";\n  thread_pool_.Shutdown();\n  output_queue()->Close();\n}\n\nvoid SimpleChunkExtractor::Worker(std::stop_token stop_token) {\n  auto producer = output_queue()->CreateProducer();\n\n  try {\n    while (true) {\n      auto item = input_queue()->Get(stop_token);\n      if (item.message_type != FilePathProvider::MessageType::kFile ||\n          !item.source) {\n        continue;\n      }\n\n      ProcessSource(producer, std::move(item.source), stop_token);\n    }\n  } catch (const QueueClosedException&) {\n    // Input queue closed; no more sources to process.\n  } catch (const QueueRequestCancelled&) {\n    // Stop was requested while a queue operation was blocked.\n  }\n}\n\nvoid SimpleChunkExtractor::ProcessSource(\n    Queue<TrainingChunk>::Producer& producer,\n    std::unique_ptr<ChunkSource> source, std::stop_token stop_token) {\n  const size_t chunk_count = source->GetChunkCount();\n  if (chunk_count == 0) return;\n\n  std::vector<size_t> indices(chunk_count);\n  std::iota(indices.begin(), indices.end(), 0);\n  absl::c_shuffle(indices, bitgen_);\n\n  const std::string sort_key = source->GetChunkSortKey();\n  for (size_t idx : indices) {\n    if (auto chunk = LoadChunk(*source, sort_key, idx)) {\n      producer.Put(std::move(*chunk), stop_token);\n      
++chunks_processed_;\n    }\n  }\n  ++sources_processed_;\n}\n\nstd::optional<TrainingChunk> SimpleChunkExtractor::LoadChunk(\n    ChunkSource& source, const std::string& sort_key, size_t index) {\n  auto data = source.GetChunkData(index);\n  if (!data || data->empty()) {\n    ++chunks_dropped_;\n    return std::nullopt;\n  }\n\n  TrainingChunk chunk;\n  chunk.sort_key = sort_key;\n  chunk.index_within_sort_key = index;\n  chunk.global_index = chunks_processed_;\n  chunk.use_count = 0;\n  chunk.frames = std::move(*data);\n\n  return chunk;\n}\n\nStageMetricProto SimpleChunkExtractor::FlushMetrics() {\n  StageMetricProto metric;\n  auto add_count = [&](const char* name, std::atomic<uint64_t>& counter) {\n    auto* m = metric.add_count_metrics();\n    m->set_name(name);\n    m->set_count(counter.exchange(0));\n  };\n\n  add_count(\"chunks_processed\", chunks_processed_);\n  add_count(\"chunks_dropped\", chunks_dropped_);\n  add_count(\"sources_processed\", sources_processed_);\n\n  *metric.add_queue_metrics() = MetricsFromQueue(\"output\", *output_queue());\n  return metric;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/simple_chunk_extractor.h",
    "content": "#pragma once\n\n#include <atomic>\n#include <memory>\n#include <optional>\n#include <stop_token>\n#include <string>\n#include <string_view>\n\n#include \"absl/random/random.h\"\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"loader/stages/chunk_source_loader.h\"\n#include \"loader/stages/stage.h\"\n#include \"loader/stages/training_chunk.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/queue.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Single-threaded stage that shuffles chunks within each source.\nclass SimpleChunkExtractor\n    : public SingleInputStage<SimpleChunkExtractorConfig, ChunkSourceWithPhase>,\n      public SingleOutputStage<TrainingChunk> {\n public:\n  explicit SimpleChunkExtractor(const SimpleChunkExtractorConfig& config);\n  ~SimpleChunkExtractor();\n\n  void Start() override;\n  void Stop() override;\n  StageMetricProto FlushMetrics() override;\n\n private:\n  void Worker(std::stop_token stop_token);\n  void ProcessSource(Queue<TrainingChunk>::Producer& producer,\n                     std::unique_ptr<ChunkSource> source,\n                     std::stop_token stop_token);\n  std::optional<TrainingChunk> LoadChunk(ChunkSource& source,\n                                         const std::string& sort_key,\n                                         size_t index);\n\n  std::atomic<uint64_t> chunks_processed_{0};\n  std::atomic<uint64_t> chunks_dropped_{0};\n  std::atomic<uint64_t> sources_processed_{0};\n  absl::BitGen bitgen_;\n  ThreadPool thread_pool_{1};\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/simple_chunk_extractor_test.cc",
    "content": "#include \"loader/stages/simple_chunk_extractor.h\"\n\n#include <gmock/gmock.h>\n#include <gtest/gtest.h>\n\n#include <algorithm>\n#include <memory>\n#include <optional>\n#include <vector>\n\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"loader/stages/chunk_source_loader.h\"\n#include \"loader/stages/file_path_provider.h\"\n#include \"loader/stages/stage.h\"\n#include \"loader/stages/training_chunk.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"utils/queue.h\"\n\nnamespace lczero {\nnamespace training {\nnamespace {\n\n// Mock chunk source for testing.\nclass MockChunkSource : public ChunkSource {\n public:\n  MockChunkSource(std::string sort_key, size_t chunk_count)\n      : sort_key_(std::move(sort_key)), chunk_count_(chunk_count) {\n    // Pre-generate chunk data.\n    for (size_t i = 0; i < chunk_count; ++i) {\n      chunks_.emplace_back(10);  // 10 frames per chunk.\n    }\n  }\n\n  std::string GetChunkSortKey() const override { return sort_key_; }\n  size_t GetChunkCount() const override { return chunk_count_; }\n\n  std::optional<std::vector<FrameType>> GetChunkData(size_t index) override {\n    if (index >= chunks_.size()) return std::nullopt;\n    return chunks_[index];\n  }\n\n private:\n  std::string sort_key_;\n  size_t chunk_count_;\n  std::vector<std::vector<FrameType>> chunks_;\n};\n\nclass SimpleChunkExtractorTest : public ::testing::Test {\n protected:\n  void SetUp() override {\n    // Create the input queue and wire it into the stage via SetInputs(),\n    // matching the Stage interface.\n    input_queue_ = std::make_unique<Queue<ChunkSourceWithPhase>>(10);\n\n    SimpleChunkExtractorConfig config;\n    config.mutable_output()->set_queue_capacity(10);\n\n    // Create the extractor stage.\n    shuffler_ = 
std::make_unique<SimpleChunkExtractor>(config);\n    shuffler_->SetInputs({input_queue_.get()});\n  }\n\n  void TearDown() override {\n    if (shuffler_) shuffler_->Stop();\n    input_queue_->Close();\n  }\n\n  std::unique_ptr<Queue<ChunkSourceWithPhase>> input_queue_;\n  std::unique_ptr<SimpleChunkExtractor> shuffler_;\n};\n\nTEST_F(SimpleChunkExtractorTest, ProcessesSingleSource) {\n  shuffler_->Start();\n\n  auto producer = input_queue_->CreateProducer();\n\n  // Send a chunk source with 5 chunks.\n  auto source = std::make_unique<MockChunkSource>(\"source1\", 5);\n  producer.Put({.source = std::move(source),\n                .message_type = FilePathProvider::MessageType::kFile});\n\n  // Close input to signal completion.\n  input_queue_->Close();\n\n  // Collect all output chunks.\n  auto* output = static_cast<Queue<TrainingChunk>*>(shuffler_->GetOutput());\n  std::vector<TrainingChunk> chunks;\n  while (true) {\n    try {\n      chunks.push_back(output->Get());\n    } catch (const QueueClosedException&) {\n      break;\n    }\n  }\n\n  // Should receive exactly 5 chunks.\n  EXPECT_EQ(chunks.size(), 5);\n\n  // All chunks should have the same sort_key.\n  for (const auto& chunk : chunks) {\n    EXPECT_EQ(chunk.sort_key, \"source1\");\n    EXPECT_EQ(chunk.frames.size(), 10);  // 10 frames per chunk.\n  }\n\n  // Check that all chunk indices are present (though order is shuffled).\n  std::vector<size_t> indices;\n  for (const auto& chunk : chunks) {\n    indices.push_back(chunk.index_within_sort_key);\n  }\n  
std::sort(indices.begin(), indices.end());\n  EXPECT_THAT(indices, ::testing::ElementsAre(0, 1, 2, 3, 4));\n}\n\nTEST_F(SimpleChunkExtractorTest, ProcessesMultipleSources) {\n  shuffler_->Start();\n\n  auto producer = input_queue_->CreateProducer();\n\n  // Send two chunk sources.\n  producer.Put({.source = std::make_unique<MockChunkSource>(\"source1\", 3),\n                .message_type = FilePathProvider::MessageType::kFile});\n  producer.Put({.source = std::make_unique<MockChunkSource>(\"source2\", 2),\n                .message_type = FilePathProvider::MessageType::kFile});\n\n  input_queue_->Close();\n\n  // Collect all output chunks.\n  auto* output = static_cast<Queue<TrainingChunk>*>(shuffler_->GetOutput());\n  std::vector<TrainingChunk> chunks;\n  while (true) {\n    try {\n      chunks.push_back(output->Get());\n    } catch (const QueueClosedException&) {\n      break;\n    }\n  }\n\n  // Should receive 3 + 2 = 5 chunks total.\n  EXPECT_EQ(chunks.size(), 5);\n\n  // Count chunks per source.\n  size_t source1_count = 0;\n  size_t source2_count = 0;\n  for (const auto& chunk : chunks) {\n    if (chunk.sort_key == \"source1\") {\n      ++source1_count;\n    } else if (chunk.sort_key == \"source2\") {\n      ++source2_count;\n    }\n  }\n\n  EXPECT_EQ(source1_count, 3);\n  EXPECT_EQ(source2_count, 2);\n}\n\nTEST_F(SimpleChunkExtractorTest, SkipsNonFileMessages) {\n  shuffler_->Start();\n\n  auto producer = input_queue_->CreateProducer();\n\n  // Send a non-file message.\n  producer.Put(\n      {.source = nullptr,\n       .message_type = FilePathProvider::MessageType::kInitialScanComplete});\n\n  // Send a file message.\n  producer.Put({.source = std::make_unique<MockChunkSource>(\"source1\", 2),\n                .message_type = FilePathProvider::MessageType::kFile});\n\n  input_queue_->Close();\n\n  // Collect all output chunks.\n  auto* output = static_cast<Queue<TrainingChunk>*>(shuffler_->GetOutput());\n  std::vector<TrainingChunk> chunks;\n  while (true) 
{\n    try {\n      chunks.push_back(output->Get());\n    } catch (const QueueClosedException&) {\n      break;\n    }\n  }\n\n  // Should only receive 2 chunks from the file message.\n  EXPECT_EQ(chunks.size(), 2);\n}\n\nTEST_F(SimpleChunkExtractorTest, MetricsAreRecorded) {\n  shuffler_->Start();\n\n  auto producer = input_queue_->CreateProducer();\n  producer.Put({.source = std::make_unique<MockChunkSource>(\"source1\", 3),\n                .message_type = FilePathProvider::MessageType::kFile});\n\n  input_queue_->Close();\n\n  // Wait for processing to complete.\n  auto* output = static_cast<Queue<TrainingChunk>*>(shuffler_->GetOutput());\n  while (true) {\n    try {\n      output->Get();\n    } catch (const QueueClosedException&) {\n      break;\n    }\n  }\n\n  // Flush metrics.\n  auto metrics = shuffler_->FlushMetrics();\n\n  EXPECT_GT(metrics.count_metrics_size(), 0);\n\n  // Check that chunks_processed metric exists.\n  bool found_chunks_processed = false;\n  for (const auto& metric : metrics.count_metrics()) {\n    if (metric.name() == \"chunks_processed\") {\n      EXPECT_EQ(metric.count(), 3);\n      found_chunks_processed = true;\n    }\n  }\n  EXPECT_TRUE(found_chunks_processed);\n}\n\n}  // namespace\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/stage.cc",
    "content": "#include \"loader/stages/stage.h\"\n\n#include \"absl/algorithm/container.h\"\n#include \"absl/strings/str_cat.h\"\n\nnamespace lczero {\nnamespace training {\n\nvoid StageRegistry::AddStage(std::string_view stage_name,\n                             std::unique_ptr<Stage> stage) {\n  if (absl::c_find_if(stages_, [&](const auto& pair) {\n        return pair.first == stage_name;\n      }) != stages_.end()) {\n    throw std::runtime_error(\n        absl::StrCat(\"Duplicate stage name detected: \", stage_name));\n  }\n\n  stages_.emplace_back(stage_name, std::move(stage));\n}\n\nQueueBase* StageRegistry::GetStageOutput(std::string_view stage_name) const {\n  auto [actual_stage_name, output_name] = [&stage_name]() {\n    size_t dot_pos = stage_name.find('.');\n    return dot_pos == std::string_view::npos\n               ? std::pair{stage_name, std::string_view{}}\n               : std::pair{stage_name.substr(0, dot_pos),\n                           stage_name.substr(dot_pos + 1)};\n  }();\n\n  auto it = absl::c_find_if(stages_, [&](const auto& pair) {\n    return pair.first == actual_stage_name;\n  });\n  return it != stages_.end() ? it->second->GetOutput(output_name) : nullptr;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/stage.h",
    "content": "#pragma once\n\n#include <cstddef>\n#include <memory>\n#include <optional>\n#include <stdexcept>\n#include <string>\n#include <string_view>\n#include <utility>\n#include <vector>\n\n#include \"absl/strings/str_cat.h\"\n#include \"absl/types/span.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/stage_control.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/queue.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Base interface implemented by all loader stages.\nclass Stage {\n public:\n  virtual ~Stage() = default;\n\n  // Starts background workers owned by the stage.\n  virtual void Start() = 0;\n\n  // Requests the stage to stop and join background work.\n  virtual void Stop() = 0;\n\n  // Flushes stage-specific metrics and returns a snapshot.\n  virtual StageMetricProto FlushMetrics() = 0;\n\n  // Returns the output queue for downstream stages.\n  virtual QueueBase* GetOutput(std::string_view name = \"\") = 0;\n\n  // Sets the input queues for this stage. Called after construction but before\n  // Start().\n  virtual void SetInputs(absl::Span<QueueBase* const> inputs) = 0;\n\n  // Handles control-plane messages specific to the stage.\n  virtual std::optional<StageControlResponse> Control(\n      const StageControlRequest& request) {\n    (void)request;\n    return std::nullopt;\n  }\n};\n\nclass StageRegistry {\n public:\n  // Registers a new stage with the given name and takes ownership of it.\n  void AddStage(std::string_view stage_name, std::unique_ptr<Stage> stage);\n\n  // Returns the output queue for the specified stage.\n  // If stage_name contains a dot (e.g., \"stage.output\"), splits it into stage\n  // name and output name, passing the output name to Stage::GetOutput().\n  // Returns nullptr if the stage is not found.\n  QueueBase* GetStageOutput(std::string_view stage_name) const;\n\n  template <typename StageT>\n  Queue<StageT>* GetTypedStageOutput(std::string_view stage_name) const {\n    QueueBase* raw_queue = GetStageOutput(stage_name);\n    
if (raw_queue == nullptr) return nullptr;\n    auto* typed_queue = dynamic_cast<Queue<StageT>*>(raw_queue);\n    if (!typed_queue) {\n      throw std::runtime_error(\n          absl::StrCat(\"Stage output type mismatch for stage: \", stage_name));\n    }\n    return typed_queue;\n  }\n\n  size_t size() const { return stages_.size(); }\n  const std::vector<std::pair<std::string, std::unique_ptr<Stage>>>& stages()\n      const {\n    return stages_;\n  }\n\n private:\n  std::vector<std::pair<std::string, std::unique_ptr<Stage>>> stages_;\n};\n\n// Helper to convert QueueConfig::OverflowBehavior to OverflowBehavior.\ninline OverflowBehavior ToOverflowBehavior(\n    QueueConfig::OverflowBehavior behavior) {\n  switch (behavior) {\n    case QueueConfig::BLOCK:\n      return OverflowBehavior::BLOCK;\n    case QueueConfig::DROP_NEW:\n      return OverflowBehavior::DROP_NEW;\n    case QueueConfig::KEEP_NEWEST:\n      return OverflowBehavior::KEEP_NEWEST;\n  }\n  throw std::runtime_error(absl::StrCat(\"Unknown OverflowBehavior value: \",\n                                        static_cast<int>(behavior)));\n}\n\n// Helper for stages that consume a single upstream queue.\ntemplate <typename ConfigT, typename InputT>\nclass SingleInputStage : virtual public Stage {\n public:\n  void SetInputs(absl::Span<QueueBase* const> inputs) override {\n    if (inputs.size() != 1) {\n      throw std::runtime_error(absl::StrCat(\n          \"SingleInputStage expects exactly 1 input, got \", inputs.size()));\n    }\n    auto* typed_queue = dynamic_cast<Queue<InputT>*>(inputs[0]);\n    if (!typed_queue) throw std::runtime_error(\"Input queue type mismatch\");\n    input_queue_ = typed_queue;\n  }\n\n protected:\n  explicit SingleInputStage(const ConfigT&) : input_queue_(nullptr) {}\n\n  Queue<InputT>* input_queue() { return input_queue_; }\n\n private:\n  Queue<InputT>* input_queue_;\n};\n\n// Helper for stages that produce a single output queue.\ntemplate <typename OutputT>\nclass 
SingleOutputStage : virtual public Stage {\n public:\n  Queue<OutputT>* output_queue() { return &output_queue_; }\n\n  QueueBase* GetOutput(std::string_view name = \"\") override {\n    if (name != output_name_) {\n      throw std::runtime_error(absl::StrCat(\"Output name '\", name,\n                                            \"' does not match configured '\",\n                                            output_name_, \"'\"));\n    }\n    return &output_queue_;\n  }\n\n protected:\n  explicit SingleOutputStage(const QueueConfig& config)\n      : output_name_(config.name()),\n        output_queue_(config.queue_capacity(),\n                      ToOverflowBehavior(config.overflow_behavior())) {}\n\n private:\n  std::string output_name_;\n  Queue<OutputT> output_queue_;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/stage_factory.cc",
    "content": "#include \"loader/stages/stage_factory.h\"\n\n#include <stdexcept>\n#include <string>\n\n#include \"loader/stages/chunk_rescorer.h\"\n#include \"loader/stages/chunk_source_loader.h\"\n#include \"loader/stages/chunk_source_splitter.h\"\n#include \"loader/stages/chunk_unpacker.h\"\n#include \"loader/stages/file_path_provider.h\"\n#include \"loader/stages/join_stage.h\"\n#include \"loader/stages/shuffling_chunk_pool.h\"\n#include \"loader/stages/shuffling_frame_sampler.h\"\n#include \"loader/stages/simple_chunk_extractor.h\"\n#include \"loader/stages/tensor_generator.h\"\n\nnamespace lczero {\nnamespace training {\n\nnamespace {\n\nint CountStageConfigs(const StageConfig& config) {\n  return static_cast<int>(config.has_file_path_provider()) +\n         static_cast<int>(config.has_chunk_source_loader()) +\n         static_cast<int>(config.has_shuffling_chunk_pool()) +\n         static_cast<int>(config.has_chunk_rescorer()) +\n         static_cast<int>(config.has_chunk_unpacker()) +\n         static_cast<int>(config.has_shuffling_frame_sampler()) +\n         static_cast<int>(config.has_tensor_generator()) +\n         static_cast<int>(config.has_chunk_source_splitter()) +\n         static_cast<int>(config.has_simple_chunk_extractor()) +\n         static_cast<int>(config.has_join_positions());\n}\n\n}  // namespace\n\nstd::unique_ptr<Stage> CreateStage(const StageConfig& config) {\n  if (CountStageConfigs(config) != 1) {\n    throw std::runtime_error(\n        \"StageConfig must have exactly one stage-specific config set.\");\n  }\n\n  if (config.has_file_path_provider()) {\n    return std::make_unique<FilePathProvider>(config.file_path_provider());\n  }\n  if (config.has_chunk_source_loader()) {\n    return std::make_unique<ChunkSourceLoader>(config.chunk_source_loader());\n  }\n  if (config.has_shuffling_chunk_pool()) {\n    return std::make_unique<ShufflingChunkPool>(config.shuffling_chunk_pool());\n  }\n  if (config.has_chunk_rescorer()) {\n    return 
std::make_unique<ChunkRescorer>(config.chunk_rescorer());\n  }\n  if (config.has_chunk_unpacker()) {\n    return std::make_unique<ChunkUnpacker>(config.chunk_unpacker());\n  }\n  if (config.has_shuffling_frame_sampler()) {\n    return std::make_unique<ShufflingFrameSampler>(\n        config.shuffling_frame_sampler());\n  }\n  if (config.has_tensor_generator()) {\n    return std::make_unique<TensorGenerator>(config.tensor_generator());\n  }\n  if (config.has_chunk_source_splitter()) {\n    return std::make_unique<ChunkSourceSplitter>(\n        config.chunk_source_splitter());\n  }\n  if (config.has_simple_chunk_extractor()) {\n    return std::make_unique<SimpleChunkExtractor>(\n        config.simple_chunk_extractor());\n  }\n  if (config.has_join_positions()) {\n    return std::make_unique<JoinPositions>(config.join_positions());\n  }\n\n  throw std::runtime_error(\n      \"StageConfig did not contain a recognized stage configuration.\");\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/stage_factory.h",
    "content": "#pragma once\n\n#include <memory>\n\n#include \"loader/stages/stage.h\"\n#include \"proto/data_loader_config.pb.h\"\n\nnamespace lczero {\nnamespace training {\n\nstd::unique_ptr<Stage> CreateStage(const StageConfig& config);\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/stage_factory_test.cc",
    "content": "#include \"loader/stages/stage_factory.h\"\n\n#include <gtest/gtest.h>\n\n#include <stdexcept>\n\nnamespace lczero {\nnamespace training {\n\nTEST(StageFactoryTest, CreatesFilePathProviderStage) {\n  StageConfig config;\n  config.mutable_file_path_provider()->set_directory(\".\");\n\n  auto stage = CreateStage(config);\n\n  ASSERT_NE(stage, nullptr);\n  EXPECT_NE(stage->GetOutput(), nullptr);\n}\n\nTEST(StageFactoryTest, ThrowsWhenNoStageConfigSet) {\n  StageConfig config;\n\n  EXPECT_THROW(CreateStage(config), std::runtime_error);\n}\n\nTEST(StageFactoryTest, ThrowsWhenMultipleStageConfigsSet) {\n  StageConfig config;\n  config.mutable_file_path_provider()->set_directory(\".\");\n  config.mutable_tensor_generator();\n\n  EXPECT_THROW(CreateStage(config), std::runtime_error);\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/tensor_generator.cc",
    "content": "// ABOUTME: Implementation of TensorGenerator stage for training pipeline.\n// ABOUTME: Converts V6TrainingData frames to batched tensors for training.\n\n#include \"loader/stages/tensor_generator.h\"\n\n#include <cstring>\n#include <limits>\n#include <memory>\n#include <utility>\n#include <vector>\n\n#include \"absl/algorithm/container.h\"\n#include \"absl/log/log.h\"\n#include \"loader/data_loader_metrics.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n\nnamespace lczero {\nnamespace training {\n\nTensorGenerator::TensorGenerator(const TensorGeneratorConfig& config)\n    : SingleInputStage<TensorGeneratorConfig, InputType>(config),\n      SingleOutputStage<OutputType>(config.output()),\n      batch_size_(config.batch_size()),\n      thread_pool_(config.threads(), ThreadPoolOptions{}) {\n  LOG(INFO) << \"Initializing TensorGenerator with \" << config.threads()\n            << \" threads, batch size \" << config.batch_size();\n\n  // Initialize thread contexts but don't start worker threads yet.\n  thread_contexts_.reserve(config.threads());\n  for (size_t i = 0; i < config.threads(); ++i) {\n    thread_contexts_.push_back(std::make_unique<ThreadContext>());\n  }\n}\n\nTensorGenerator::~TensorGenerator() { Stop(); }\n\nvoid TensorGenerator::Start() {\n  LOG(INFO) << \"Starting TensorGenerator worker threads.\";\n  for (size_t i = 0; i < thread_contexts_.size(); ++i) {\n    thread_pool_.Enqueue([this, i](std::stop_token stop_token) {\n      Worker(stop_token, thread_contexts_[i].get());\n    });\n  }\n}\n\nvoid TensorGenerator::Stop() {\n  if (thread_pool_.stop_token().stop_requested()) return;\n\n  LOG(INFO) << \"Stopping TensorGenerator.\";\n  thread_pool_.Shutdown();\n  output_queue()->Close();\n  LOG(INFO) << \"TensorGenerator stopped.\";\n}\n\nvoid TensorGenerator::Worker(std::stop_token stop_token,\n                             ThreadContext* context) {\n  auto producer = 
output_queue()->CreateProducer();\n  std::vector<FrameType> batch;\n  batch.reserve(batch_size_);\n\n  try {\n    while (true) {\n      // Collect frames for a batch.\n      batch.clear();\n      for (size_t i = 0; i < batch_size_; ++i) {\n        LoadMetricPauser pauser(context->load_metric_updater);\n        batch.push_back(input_queue()->Get(stop_token));\n      }\n\n      // Convert batch to tensors.\n      TensorTuple tensors = ConvertFramesToTensors(batch);\n      {\n        LoadMetricPauser pauser(context->load_metric_updater);\n        producer.Put(std::move(tensors), stop_token);\n      }\n    }\n  } catch (const QueueClosedException&) {\n    LOG(INFO) << \"TensorGenerator worker stopping, queue closed.\";\n  } catch (const QueueRequestCancelled&) {\n    LOG(INFO) << \"TensorGenerator worker stopping, request cancelled.\";\n  }\n}\n\nTensorTuple TensorGenerator::ConvertFramesToTensors(\n    const std::vector<FrameType>& frames) {\n  const size_t batch_size = frames.size();\n  constexpr size_t kNumPlanes = 112;\n  constexpr size_t kNumPolicyMoves = 1858;\n  constexpr size_t kNumValueTypes = 6;\n  constexpr size_t kValuesPerType = 3;\n\n  TensorTuple result;\n  result.reserve(3);\n\n  // Index 0: Input planes (batch_size, 112, 8, 8)\n  auto planes_tensor = std::make_unique<TypedTensor<float>>(\n      std::initializer_list<size_t>{batch_size, kNumPlanes, 8, 8});\n  ProcessPlanes(frames, *planes_tensor);\n  result.push_back(std::move(planes_tensor));\n\n  // Index 1: Probabilities (batch_size, 1858)\n  auto probs_tensor = std::make_unique<TypedTensor<float>>(\n      std::initializer_list<size_t>{batch_size, kNumPolicyMoves});\n  for (size_t i = 0; i < batch_size; ++i) {\n    auto probs_slice = probs_tensor->slice({static_cast<ssize_t>(i)});\n    std::memcpy(probs_slice.data(), frames[i].probabilities,\n                kNumPolicyMoves * sizeof(float));\n  }\n  result.push_back(std::move(probs_tensor));\n\n  // Index 2: Values (batch_size, 6, 3) with [q, d, m] 
for each type.\n  // [0]: result, [1]: best, [2]: played, [3]: orig, [4]: root, [5]: st\n  auto values_tensor =\n      std::make_unique<TypedTensor<float>>(std::initializer_list<size_t>{\n          batch_size, kNumValueTypes, kValuesPerType});\n  for (size_t i = 0; i < batch_size; ++i) {\n    const auto& frame = frames[i];\n    auto batch_slice = values_tensor->slice({static_cast<ssize_t>(i)});\n\n    // Index 0: result [result_q, result_d, plies_left]\n    auto result_slice = batch_slice.subspan(0 * kValuesPerType, kValuesPerType);\n    result_slice[0] = frame.result_q;\n    result_slice[1] = frame.result_d;\n    result_slice[2] = frame.plies_left;\n\n    // Index 1: best [best_q, best_d, best_m]\n    auto best_slice = batch_slice.subspan(1 * kValuesPerType, kValuesPerType);\n    best_slice[0] = frame.best_q;\n    best_slice[1] = frame.best_d;\n    best_slice[2] = frame.best_m;\n\n    // Index 2: played [played_q, played_d, played_m]\n    auto played_slice = batch_slice.subspan(2 * kValuesPerType, kValuesPerType);\n    played_slice[0] = frame.played_q;\n    played_slice[1] = frame.played_d;\n    played_slice[2] = frame.played_m;\n\n    // Index 3: orig [orig_q, orig_d, orig_m] (may be NaN)\n    auto orig_slice = batch_slice.subspan(3 * kValuesPerType, kValuesPerType);\n    orig_slice[0] = frame.orig_q;\n    orig_slice[1] = frame.orig_d;\n    orig_slice[2] = frame.orig_m;\n\n    // Index 4: root [root_q, root_d, root_m]\n    auto root_slice = batch_slice.subspan(4 * kValuesPerType, kValuesPerType);\n    root_slice[0] = frame.root_q;\n    root_slice[1] = frame.root_d;\n    root_slice[2] = frame.root_m;\n\n    // Index 5: st [q_st, d_st, NaN]\n    auto st_slice = batch_slice.subspan(5 * kValuesPerType, kValuesPerType);\n    st_slice[0] = frame.q_st;\n    st_slice[1] = frame.d_st;\n    st_slice[2] = std::numeric_limits<float>::quiet_NaN();\n  }\n  result.push_back(std::move(values_tensor));\n\n  return result;\n}\n\nvoid TensorGenerator::ProcessPlanes(const 
std::vector<FrameType>& frames,\n                                    TypedTensor<float>& planes_tensor) {\n  const size_t batch_size = frames.size();\n\n  for (size_t i = 0; i < batch_size; ++i) {\n    const auto& frame = frames[i];\n    auto batch_slice = planes_tensor.slice({static_cast<ssize_t>(i)});\n\n    // Process first 104 planes from frame.planes (each uint64_t represents 64\n    // bits).\n    for (ssize_t plane = 0; plane < 104; ++plane) {\n      auto plane_slice = batch_slice.subspan(plane * 64, 64);\n      uint64_t plane_bits = frame.planes[plane];\n\n      for (ssize_t square = 0; square < 64; ++square) {\n        // XOR with 7 remaps the index within each byte from 0..7 to 7..0.\n        plane_slice[square] =\n            static_cast<float>((plane_bits >> (square ^ 7)) & 1);\n      }\n    }\n\n    // Add 8 additional planes for metadata (planes 104-111).\n    const std::pair<ssize_t, float> meta_planes[] = {\n        {104, static_cast<float>(frame.castling_us_ooo)},\n        {105, static_cast<float>(frame.castling_us_oo)},\n        {106, static_cast<float>(frame.castling_them_ooo)},\n        {107, static_cast<float>(frame.castling_them_oo)},\n        {108, static_cast<float>(frame.side_to_move_or_enpassant)},\n        {109, static_cast<float>(frame.rule50_count) / 99.0f},\n        {110, 0.0f},  // All zeros (constant plane).\n        {111, 1.0f},  // All ones (constant plane).\n    };\n\n    for (const auto& [plane_num, value] : meta_planes) {\n      auto plane_slice = batch_slice.subspan(plane_num * 64, 64);\n      absl::c_fill(plane_slice, value);\n    }\n  }\n}\n\nStageMetricProto TensorGenerator::FlushMetrics() {\n  StageMetricProto stage_metric;\n  LoadMetricProto aggregated_load;\n  aggregated_load.set_name(\"load\");\n  for (const auto& context : thread_contexts_) {\n    UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());\n  }\n  *stage_metric.add_load_metrics() = std::move(aggregated_load);\n  
*stage_metric.add_queue_metrics() =\n      MetricsFromQueue(\"output\", *output_queue());\n  return stage_metric;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/tensor_generator.h",
    "content": "// ABOUTME: Stage that converts FrameType frames into tensor batches.\n// ABOUTME: Produces TensorTuple batches of tensors for the training pipeline.\n#pragma once\n\n#include <atomic>\n#include <cstddef>\n#include <memory>\n#include <stop_token>\n#include <vector>\n\n#include \"loader/data_loader.h\"\n#include \"loader/data_loader_metrics.h\"\n#include \"loader/frame_type.h\"\n#include \"loader/stages/stage.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/queue.h\"\n#include \"utils/tensor.h\"\n#include \"utils/thread_pool.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Worker pool that converts FrameType frames into tensor batches.\n// Takes individual FrameType frames as input and outputs TensorTuple\n// containing batched tensors in the format required for training.\nclass TensorGenerator\n    : public SingleInputStage<TensorGeneratorConfig, FrameType>,\n      public SingleOutputStage<TensorTuple> {\n public:\n  using InputType = FrameType;\n  using OutputType = TensorTuple;\n\n  explicit TensorGenerator(const TensorGeneratorConfig& config);\n  ~TensorGenerator();\n\n  void Start() override;\n  void Stop() override;\n  StageMetricProto FlushMetrics() override;\n\n private:\n  struct ThreadContext {\n    LoadMetricUpdater load_metric_updater;\n  };\n\n  void Worker(std::stop_token stop_token, ThreadContext* context);\n  TensorTuple ConvertFramesToTensors(const std::vector<FrameType>& frames);\n  void ProcessPlanes(const std::vector<FrameType>& frames,\n                     TypedTensor<float>& planes_tensor);\n\n  size_t batch_size_;\n  // thread_contexts_ must be declared before thread_pool_ to ensure\n  // thread_pool_ is destroyed first (stopping threads before contexts).\n  std::vector<std::unique_ptr<ThreadContext>> thread_contexts_;\n  ThreadPool thread_pool_;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/tensor_generator_test.cc",
    "content": "// ABOUTME: Unit tests for TensorGenerator stage in training pipeline.\n// ABOUTME: Tests tensor conversion, batching, and data format correctness.\n\n#include \"loader/stages/tensor_generator.h\"\n\n#include <cstring>\n#include <memory>\n#include <vector>\n\n#include \"gtest/gtest.h\"\n#include \"libs/lc0/src/trainingdata/trainingdata_v6.h\"\n#include \"utils/queue.h\"\n#include \"utils/tensor.h\"\n\nnamespace lczero {\nnamespace training {\n\nnamespace {\n\ntemplate <typename T>\nclass PassthroughStage : public Stage {\n public:\n  explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}\n\n  void Start() override {}\n  void Stop() override {}\n  StageMetricProto FlushMetrics() override { return StageMetricProto(); }\n  QueueBase* GetOutput(std::string_view name = \"\") override {\n    (void)name;\n    return queue_;\n  }\n  void SetInputs(absl::Span<QueueBase* const> inputs) override {\n    if (!inputs.empty()) {\n      throw std::runtime_error(\"PassthroughStage expects no inputs\");\n    }\n  }\n\n private:\n  Queue<T>* queue_;\n};\n\n}  // namespace\n\nclass TensorGeneratorTest : public ::testing::Test {\n protected:\n  void SetUp() override {\n    input_queue_ = std::make_unique<Queue<FrameType>>(100);\n    config_.set_batch_size(4);\n    config_.set_threads(1);\n    config_.mutable_output()->set_queue_capacity(10);\n  }\n\n  FrameType CreateTestFrame() {\n    FrameType frame{};\n    std::memset(&frame, 0, sizeof(frame));\n\n    frame.version = 6;\n    frame.input_format = 3;\n\n    // Fill probabilities with test values.\n    for (ssize_t i = 0; i < 1858; ++i) {\n      frame.probabilities[i] = static_cast<float>(i) / 1858.0f;\n    }\n\n    // Fill planes with test pattern.\n    for (ssize_t i = 0; i < 104; ++i) {\n      frame.planes[i] = 0x0F0F0F0F0F0F0F0FULL + i;  // Test pattern\n    }\n\n    // Set castling rights.\n    frame.castling_us_ooo = 1;\n    frame.castling_us_oo = 0;\n    frame.castling_them_ooo = 1;\n    
frame.castling_them_oo = 1;\n\n    // Set other fields.\n    frame.side_to_move_or_enpassant = 1;\n    frame.rule50_count = 50;\n\n    // Set Q and D values.\n    frame.result_q = 0.5f;\n    frame.result_d = 0.2f;\n    frame.best_q = 0.3f;\n    frame.best_d = 0.1f;\n    frame.best_m = 42.5f;\n    frame.plies_left = 42.5f;\n\n    return frame;\n  }\n\n  void VerifyTensorTuple(const TensorTuple& tensors,\n                         const std::vector<FrameType>& frames) {\n    const size_t batch_size = frames.size();\n\n    // Verify tuple has 3 elements\n    ASSERT_EQ(tensors.size(), 3);\n\n    // Verify input tensor: (batch_size, 112, 8, 8)\n    const auto* planes_tensor =\n        dynamic_cast<const TypedTensor<float>*>(tensors[0].get());\n    ASSERT_NE(planes_tensor, nullptr);\n    EXPECT_EQ(planes_tensor->shape().size(), 4);\n    EXPECT_EQ(planes_tensor->shape()[0], batch_size);\n    EXPECT_EQ(planes_tensor->shape()[1], 112);\n    EXPECT_EQ(planes_tensor->shape()[2], 8);\n    EXPECT_EQ(planes_tensor->shape()[3], 8);\n\n    // Verify probabilities tensor: (batch_size, 1858)\n    const auto* probs_tensor =\n        dynamic_cast<const TypedTensor<float>*>(tensors[1].get());\n    ASSERT_NE(probs_tensor, nullptr);\n    EXPECT_EQ(probs_tensor->shape().size(), 2);\n    EXPECT_EQ(probs_tensor->shape()[0], batch_size);\n    EXPECT_EQ(probs_tensor->shape()[1], 1858);\n\n    // Verify values tensor: (batch_size, 6, 3)\n    const auto* values_tensor =\n        dynamic_cast<const TypedTensor<float>*>(tensors[2].get());\n    ASSERT_NE(values_tensor, nullptr);\n    EXPECT_EQ(values_tensor->shape().size(), 3);\n    EXPECT_EQ(values_tensor->shape()[0], batch_size);\n    EXPECT_EQ(values_tensor->shape()[1], 6);\n    EXPECT_EQ(values_tensor->shape()[2], 3);\n  }\n\n  void VerifyTensorData(const TensorTuple& tensors,\n                        const std::vector<FrameType>& frames) {\n    const size_t batch_size = frames.size();\n    const auto* planes_tensor =\n        
dynamic_cast<const TypedTensor<float>*>(tensors[0].get());\n    const auto* probs_tensor =\n        dynamic_cast<const TypedTensor<float>*>(tensors[1].get());\n    const auto* values_tensor =\n        dynamic_cast<const TypedTensor<float>*>(tensors[2].get());\n\n    for (size_t i = 0; i < batch_size; ++i) {\n      const auto& frame = frames[i];\n\n      // Verify probabilities data.\n      auto probs_slice = probs_tensor->slice({static_cast<ssize_t>(i)});\n      for (ssize_t j = 0; j < 1858; ++j) {\n        EXPECT_FLOAT_EQ(probs_slice[j], frame.probabilities[j]);\n      }\n\n      // Verify values tensor [batch, 6, 3] with raw q/d/m values.\n      // Index 0: result (q=0.5, d=0.2, plies_left=42.5)\n      auto values_slice = values_tensor->slice({static_cast<ssize_t>(i)});\n      EXPECT_FLOAT_EQ(values_slice[0 * 3 + 0], 0.5f);   // result_q\n      EXPECT_FLOAT_EQ(values_slice[0 * 3 + 1], 0.2f);   // result_d\n      EXPECT_FLOAT_EQ(values_slice[0 * 3 + 2], 42.5f);  // plies_left\n\n      // Index 1: best (q=0.3, d=0.1, m=42.5)\n      EXPECT_FLOAT_EQ(values_slice[1 * 3 + 0], 0.3f);   // best_q\n      EXPECT_FLOAT_EQ(values_slice[1 * 3 + 1], 0.1f);   // best_d\n      EXPECT_FLOAT_EQ(values_slice[1 * 3 + 2], 42.5f);  // best_m\n\n      // Verify planes data - check first few planes and meta planes.\n      auto planes_slice = planes_tensor->slice({static_cast<ssize_t>(i)});\n\n      // Check first plane (plane 0).\n      uint64_t expected_plane_0 = 0x0F0F0F0F0F0F0F0FULL;\n      for (ssize_t square = 0; square < 64; ++square) {\n        // XOR with 7 mirrors the converter's per-byte bit order.\n        float expected =\n            static_cast<float>((expected_plane_0 >> (square ^ 7)) & 1);\n        EXPECT_FLOAT_EQ(planes_slice[square], expected);\n      }\n\n      // Check meta planes.\n      // Plane 104: castling_us_ooo = 1\n      for (ssize_t square = 104 * 64; square < 105 * 64; ++square) {\n        EXPECT_FLOAT_EQ(planes_slice[square], 1.0f);\n      }\n\n      // Plane 105: castling_us_oo = 0\n      for (ssize_t square = 105 * 64; square < 
106 * 64; ++square) {\n        EXPECT_FLOAT_EQ(planes_slice[square], 0.0f);\n      }\n\n      // Plane 109: rule50_count = 50, should be 50/99\n      for (ssize_t square = 109 * 64; square < 110 * 64; ++square) {\n        EXPECT_FLOAT_EQ(planes_slice[square], 50.0f / 99.0f);\n      }\n\n      // Plane 110: all zeros\n      for (ssize_t square = 110 * 64; square < 111 * 64; ++square) {\n        EXPECT_FLOAT_EQ(planes_slice[square], 0.0f);\n      }\n\n      // Plane 111: all ones\n      for (ssize_t square = 111 * 64; square < 112 * 64; ++square) {\n        EXPECT_FLOAT_EQ(planes_slice[square], 1.0f);\n      }\n    }\n  }\n\n  std::unique_ptr<Queue<FrameType>> input_queue_;\n  TensorGeneratorConfig config_;\n};\n\nTEST_F(TensorGeneratorTest, GeneratesCorrectTensorShapes) {\n  TensorGenerator generator(config_);\n  generator.SetInputs({input_queue_.get()});\n  generator.Start();\n\n  auto producer = input_queue_->CreateProducer();\n  std::vector<FrameType> frames;\n  for (size_t i = 0; i < config_.batch_size(); ++i) {\n    frames.push_back(CreateTestFrame());\n    producer.Put(frames.back());\n  }\n  producer.Close();\n\n  auto tensors = generator.output_queue()->Get();\n  VerifyTensorTuple(tensors, frames);\n}\n\nTEST_F(TensorGeneratorTest, GeneratesCorrectTensorData) {\n  TensorGenerator generator(config_);\n  generator.SetInputs({input_queue_.get()});\n  generator.Start();\n\n  auto producer = input_queue_->CreateProducer();\n  std::vector<FrameType> frames;\n  for (size_t i = 0; i < config_.batch_size(); ++i) {\n    frames.push_back(CreateTestFrame());\n    producer.Put(frames.back());\n  }\n  producer.Close();\n\n  auto tensors = generator.output_queue()->Get();\n  VerifyTensorTuple(tensors, frames);\n  VerifyTensorData(tensors, frames);\n}\n\nTEST_F(TensorGeneratorTest, HandlesMultipleBatches) {\n  TensorGenerator generator(config_);\n  generator.SetInputs({input_queue_.get()});\n  generator.Start();\n\n  auto producer = input_queue_->CreateProducer();\n\n  // 
Send two full batches.\n  std::vector<FrameType> all_frames;\n  for (ssize_t batch = 0; batch < 2; ++batch) {\n    for (size_t i = 0; i < config_.batch_size(); ++i) {\n      auto frame = CreateTestFrame();\n      frame.version = batch * 1000 + i;  // Unique version for each frame\n      all_frames.push_back(frame);\n      producer.Put(frame);\n    }\n  }\n  producer.Close();\n\n  // Get first batch.\n  auto tensors1 = generator.output_queue()->Get();\n  std::vector<FrameType> batch1_frames(\n      all_frames.begin(), all_frames.begin() + config_.batch_size());\n  VerifyTensorTuple(tensors1, batch1_frames);\n\n  // Get second batch.\n  auto tensors2 = generator.output_queue()->Get();\n  std::vector<FrameType> batch2_frames(\n      all_frames.begin() + config_.batch_size(), all_frames.end());\n  VerifyTensorTuple(tensors2, batch2_frames);\n\n  // No more batches should be available.\n  EXPECT_THROW(generator.output_queue()->Get(), QueueClosedException);\n}\n\nTEST_F(TensorGeneratorTest, HandlesDifferentBatchSizes) {\n  config_.set_batch_size(2);\n  TensorGenerator generator(config_);\n  generator.SetInputs({input_queue_.get()});\n  generator.Start();\n\n  auto producer = input_queue_->CreateProducer();\n  std::vector<FrameType> frames;\n  for (size_t i = 0; i < config_.batch_size(); ++i) {\n    frames.push_back(CreateTestFrame());\n    producer.Put(frames.back());\n  }\n  producer.Close();\n\n  auto tensors = generator.output_queue()->Get();\n  VerifyTensorTuple(tensors, frames);\n}\n\nTEST_F(TensorGeneratorTest, HandlesEmptyInput) {\n  TensorGenerator generator(config_);\n  generator.SetInputs({input_queue_.get()});\n  generator.Start();\n\n  // Close input queue without sending data.\n  input_queue_->Close();\n\n  // Should not output any tensors.\n  EXPECT_THROW(generator.output_queue()->Get(), QueueClosedException);\n}\n\nTEST_F(TensorGeneratorTest, VerifiesPlanesConversion) {\n  config_.set_batch_size(1);\n  TensorGenerator generator(config_);\n  
generator.SetInputs({input_queue_.get()});\n  generator.Start();\n\n  auto producer = input_queue_->CreateProducer();\n\n  FrameType frame = CreateTestFrame();\n  // Set specific bit pattern for plane 0.\n  frame.planes[0] = 0xAAAAAAAAAAAAAAAAULL;  // Alternating bits\n  // Set specific values for meta planes.\n  frame.castling_us_ooo = 1;\n  frame.castling_us_oo = 0;\n  frame.rule50_count = 75;\n\n  producer.Put(frame);\n  producer.Close();\n\n  auto tensors = generator.output_queue()->Get();\n  const auto* planes_tensor =\n      dynamic_cast<const TypedTensor<float>*>(tensors[0].get());\n\n  auto planes_slice = planes_tensor->slice({0});\n\n  // Verify plane 0 bit conversion.\n  for (ssize_t square = 0; square < 64; ++square) {\n    // XOR with 7 mirrors the converter's per-byte bit order.\n    float expected =\n        static_cast<float>((0xAAAAAAAAAAAAAAAAULL >> (square ^ 7)) & 1);\n    EXPECT_FLOAT_EQ(planes_slice[square], expected)\n        << \"Mismatch at square \" << square;\n  }\n\n  // Verify rule50_count conversion: 75/99.\n  for (ssize_t square = 109 * 64; square < 110 * 64; ++square) {\n    EXPECT_FLOAT_EQ(planes_slice[square], 75.0f / 99.0f);\n  }\n}\n\nTEST_F(TensorGeneratorTest, VerifiesQDConversion) {\n  config_.set_batch_size(1);\n  TensorGenerator generator(config_);\n  generator.SetInputs({input_queue_.get()});\n  generator.Start();\n\n  auto producer = input_queue_->CreateProducer();\n\n  FrameType frame = CreateTestFrame();\n  // Test specific Q/D values.\n  frame.result_q = 0.4f;\n  frame.result_d = 0.3f;\n  frame.best_q = -0.2f;\n  frame.best_d = 0.1f;\n\n  producer.Put(frame);\n  producer.Close();\n\n  auto tensors = generator.output_queue()->Get();\n  const auto* values_tensor =\n      dynamic_cast<const TypedTensor<float>*>(tensors[2].get());\n\n  auto values_slice = values_tensor->slice({0});\n\n  // Verify result values: q=0.4, d=0.3 (raw values, no WDL conversion)\n  EXPECT_FLOAT_EQ(values_slice[0 * 3 + 0], 0.4f);  // result_q\n  EXPECT_FLOAT_EQ(values_slice[0 * 3 + 1], 0.3f);  // result_d\n\n  
// Verify best values: q=-0.2, d=0.1 (raw values, no WDL conversion)\n  EXPECT_FLOAT_EQ(values_slice[1 * 3 + 0], -0.2f);  // best_q\n  EXPECT_FLOAT_EQ(values_slice[1 * 3 + 1], 0.1f);   // best_d\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/loader/stages/training_chunk.h",
    "content": "// ABOUTME: Plain data structures passed between chunk loader stages.\n// ABOUTME: Defines TrainingChunk and CacheRequest.\n#pragma once\n\n#include <cstddef>\n#include <cstdint>\n#include <string>\n#include <vector>\n\n#include \"loader/frame_type.h\"\n\nnamespace lczero {\nnamespace training {\n\nstruct TrainingChunk {\n  std::vector<FrameType> frames;\n  std::string sort_key;\n  size_t index_within_sort_key = 0;\n  size_t global_index = 0;\n  uint32_t use_count = 0;\n};\n\nstruct CacheRequest {\n  size_t global_index = 0;\n  uint16_t next_use = 0;\n  std::vector<FrameType> items;\n};\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/tools/dump_chunk_main.cc",
    "content": "#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h>\n#include <absl/log/log.h>\n#include <absl/strings/str_format.h>\n#include <absl/strings/string_view.h>\n#include <zlib.h>\n\n#include <algorithm>\n#include <cstddef>\n#include <cstdint>\n#include <iostream>\n#include <iterator>\n#include <string>\n\n#include \"trainingdata/trainingdata_v6.h\"\n#include \"utils/training_data_printer.h\"\n\nABSL_FLAG(std::string, chunk_path, \"\", \"Path to the chunk file (.gz) to dump.\");\nABSL_FLAG(int64_t, max_entries, -1,\n          \"Maximum number of entries to print. -1 prints all entries.\");\nABSL_FLAG(int64_t, float_values_per_line, 8,\n          \"Number of floating point values per output line.\");\nABSL_FLAG(int64_t, plane_values_per_line, 4,\n          \"Number of plane values per output line.\");\n\nnamespace lczero {\nnamespace training {\n\nnamespace {\n\nusing ::lczero::training::FrameType;\nusing ::lczero::training::PrintTrainingDataEntry;\n\nvoid DumpChunk(const std::string& path, int64_t max_entries,\n               int64_t float_per_line, int64_t plane_per_line) {\n  gzFile file = gzopen(path.c_str(), \"rb\");\n  if (file == nullptr) {\n    LOG(FATAL) << \"Failed to open chunk file: \" << path;\n  }\n\n  size_t index = 0;\n  while (true) {\n    FrameType entry;\n    const int bytes_read = gzread(file, &entry, sizeof(entry));\n    if (bytes_read == 0) {\n      break;\n    }\n    if (bytes_read < 0) {\n      int errnum = 0;\n      const char* error_message = gzerror(file, &errnum);\n      gzclose(file);\n      LOG(FATAL) << \"Error while reading chunk: \" << error_message;\n    }\n    if (bytes_read != sizeof(entry)) {\n      gzclose(file);\n      LOG(FATAL) << \"Unexpected chunk size. 
Expected \" << sizeof(entry)\n                 << \" bytes, got \" << bytes_read << \".\";\n    }\n\n    const std::string header = absl::StrFormat(\"Entry %zu:\", index);\n    PrintTrainingDataEntry(entry, header, float_per_line, plane_per_line);\n    ++index;\n\n    if (max_entries >= 0 && static_cast<int64_t>(index) >= max_entries) {\n      break;\n    }\n  }\n\n  gzclose(file);\n  LOG(INFO) << \"Printed \" << index << \" entries.\";\n}\n\n}  // namespace\n\n}  // namespace training\n}  // namespace lczero\n\nint main(int argc, char** argv) {\n  absl::ParseCommandLine(argc, argv);\n  absl::InitializeLog();\n  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);\n\n  const std::string chunk_path = absl::GetFlag(FLAGS_chunk_path);\n  if (chunk_path.empty()) {\n    LOG(FATAL) << \"--chunk_path flag is required.\";\n  }\n\n  const int64_t max_entries = absl::GetFlag(FLAGS_max_entries);\n  const int64_t float_per_line = absl::GetFlag(FLAGS_float_values_per_line);\n  const int64_t plane_per_line = absl::GetFlag(FLAGS_plane_values_per_line);\n\n  lczero::training::DumpChunk(chunk_path, max_entries, float_per_line,\n                              plane_per_line);\n  return 0;\n}\n"
  },
  {
    "path": "csrc/tools/filter_chunks_main.cc",
    "content": "#include <absl/algorithm/container.h>\n#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h>\n#include <absl/log/log.h>\n#include <absl/strings/ascii.h>\n#include <absl/strings/match.h>\n#include <absl/strings/numbers.h>\n#include <absl/strings/str_cat.h>\n#include <absl/strings/str_split.h>\n#include <absl/strings/strip.h>\n#include <absl/types/span.h>\n#include <zlib.h>\n\n#include <algorithm>\n#include <cstdint>\n#include <cstring>\n#include <filesystem>\n#include <limits>\n#include <memory>\n#include <optional>\n#include <string>\n#include <string_view>\n#include <thread>\n#include <utility>\n#include <vector>\n\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"loader/chunk_source/tar_chunk_source.h\"\n#include \"trainingdata/trainingdata_v6.h\"\n\nABSL_FLAG(std::string, input_dir, \".\", \"Directory to scan for .tar files.\");\nABSL_FLAG(std::string, output_dir, \".\",\n          \"Directory where matching chunks will be written.\");\nABSL_FLAG(std::string, plane_values, \"\",\n          \"Comma separated list of plane values (decimal or hex).\");\n\nnamespace {\n\nnamespace fs = std::filesystem;\n\nusing ::lczero::training::ChunkSource;\nusing ::lczero::training::ChunkSourceLoaderConfig;\nusing ::lczero::training::FrameType;\nusing ::lczero::training::TarChunkSource;\n\nstd::vector<fs::path> CollectTarFiles(const fs::path& directory) {\n  std::vector<fs::path> files;\n  for (const auto& entry : fs::directory_iterator(directory)) {\n    const fs::path& path = entry.path();\n    if (entry.is_regular_file() && path.extension() == \".tar\") {\n      files.push_back(path);\n    }\n  }\n  absl::c_sort(files, [](const fs::path& lhs, const fs::path& rhs) {\n    return lhs.filename() < rhs.filename();\n  });\n  return files;\n}\n\nstd::vector<uint64_t> ParsePlaneValues(absl::string_view value_list) {\n  std::vector<uint64_t> result;\n  if (value_list.empty()) {\n    
LOG(FATAL) << \"--plane_values flag must not be empty.\";\n  }\n  for (absl::string_view token :\n       absl::StrSplit(value_list, ',', absl::SkipWhitespace())) {\n    token = absl::StripAsciiWhitespace(token);\n    if (token.empty()) continue;\n\n    uint64_t value = 0;\n    if (absl::StartsWithIgnoreCase(token, \"0x\")) {\n      const absl::string_view hex_part = token.substr(2);\n      if (hex_part.empty() || !absl::SimpleHexAtoi(hex_part, &value)) {\n        LOG(FATAL) << \"Invalid hex plane value: \" << token;\n      }\n    } else if (!absl::SimpleAtoi(token, &value)) {\n      LOG(FATAL) << \"Invalid decimal plane value: \" << token;\n    }\n    result.push_back(value);\n  }\n  if (result.empty()) {\n    LOG(FATAL) << \"No plane values were parsed.\";\n  }\n  return result;\n}\n\nbool PlanesMatch(const FrameType& entry, absl::Span<const uint64_t> expected) {\n  if (expected.size() > std::size(entry.planes)) return false;\n  const size_t bytes = expected.size() * sizeof(uint64_t);\n  return std::memcmp(entry.planes, expected.data(), bytes) == 0;\n}\n\nstd::optional<size_t> FindMatchingFrameIndex(\n    const std::vector<FrameType>& chunk, absl::Span<const uint64_t> expected) {\n  for (size_t frame = 0; frame < chunk.size(); ++frame) {\n    if (PlanesMatch(chunk[frame], expected)) return frame;\n  }\n  return std::nullopt;\n}\n\nvoid WriteChunk(const fs::path& output_dir, absl::string_view base_name,\n                size_t index, size_t frame_index,\n                const std::vector<FrameType>& chunk) {\n  fs::create_directories(output_dir);\n  const fs::path output_path =\n      output_dir / absl::StrCat(base_name, \"_\", index, \"_\", frame_index, \".gz\");\n\n  gzFile file = gzopen(output_path.string().c_str(), \"wb\");\n  if (file == nullptr) {\n    LOG(FATAL) << \"Failed to open output file: \" << output_path.string();\n  }\n\n  size_t remaining = chunk.size() * sizeof(FrameType);\n  const char* data = reinterpret_cast<const char*>(chunk.data());\n  while 
(remaining > 0) {\n    const unsigned int to_write = static_cast<unsigned int>(\n        std::min<size_t>(remaining, std::numeric_limits<unsigned int>::max()));\n    const int written = gzwrite(file, data, to_write);\n    if (written == 0) {\n      int errnum = 0;\n      const char* error_message = gzerror(file, &errnum);\n      gzclose(file);\n      LOG(FATAL) << \"Failed to write chunk: \" << error_message;\n    }\n    data += written;\n    remaining -= static_cast<size_t>(written);\n  }\n\n  if (gzclose(file) != Z_OK) {\n    LOG(FATAL) << \"Failed to close output file: \" << output_path.string();\n  }\n\n  LOG(INFO) << \"Wrote matching chunk to \" << output_path.string();\n}\n\nvoid ProcessTar(const fs::path& tar_path, const fs::path& output_dir,\n                absl::Span<const uint64_t> expected_planes) {\n  std::unique_ptr<ChunkSource> source = std::make_unique<TarChunkSource>(\n      tar_path, ChunkSourceLoaderConfig::V6TrainingData);\n\n  const std::string base_name = tar_path.stem().string();\n  size_t written_count = 0;\n\n  for (size_t index = 0, total = source->GetChunkCount(); index < total;\n       ++index) {\n    const std::optional<std::vector<FrameType>> chunk =\n        source->GetChunkData(index);\n    if (!chunk) {\n      LOG(WARNING) << \"Skipping unreadable chunk \" << index << \" in \"\n                   << tar_path.string();\n      continue;\n    }\n\n    const std::optional<size_t> match =\n        FindMatchingFrameIndex(*chunk, expected_planes);\n    if (!match) continue;\n\n    WriteChunk(output_dir, base_name, index, *match, *chunk);\n    ++written_count;\n  }\n\n  LOG(INFO) << \"Finished processing \" << tar_path.string() << \": wrote \"\n            << written_count << \" chunk(s).\";\n}\n\n}  // namespace\n\nint main(int argc, char** argv) {\n  absl::InitializeLog();\n  absl::ParseCommandLine(argc, argv);\n  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);\n\n  const fs::path input_dir(absl::GetFlag(FLAGS_input_dir));\n  
const fs::path output_dir(absl::GetFlag(FLAGS_output_dir));\n  const std::string plane_values = absl::GetFlag(FLAGS_plane_values);\n\n  if (!fs::exists(input_dir) || !fs::is_directory(input_dir)) {\n    LOG(FATAL) << \"Input directory does not exist: \" << input_dir.string();\n  }\n\n  const std::vector<uint64_t> expected_planes = ParsePlaneValues(plane_values);\n\n  fs::create_directories(output_dir);\n\n  const std::vector<fs::path> tar_files = CollectTarFiles(input_dir);\n  const absl::Span<const uint64_t> expected_span(expected_planes);\n\n  std::vector<std::thread> workers;\n  workers.reserve(tar_files.size());\n\n  for (const auto& tar_path : tar_files) {\n    workers.emplace_back([tar_path, output_dir, expected_span]() {\n      LOG(INFO) << \"Processing tar file: \" << tar_path.string();\n      try {\n        ProcessTar(tar_path, output_dir, expected_span);\n      } catch (const std::exception& e) {\n        LOG(WARNING) << \"Failed to process tar file \" << tar_path << \": \"\n                     << e.what();\n      }\n    });\n  }\n\n  for (auto& worker : workers) {\n    worker.join();\n  }\n\n  return 0;\n}\n"
  },
  {
    "path": "csrc/tools/position_weight_stats_main.cc",
    "content": "#include <absl/algorithm/container.h>\n#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h>\n#include <absl/log/log.h>\n#include <absl/strings/str_format.h>\n\n#include <algorithm>\n#include <cstdint>\n#include <cstdio>\n#include <filesystem>\n#include <iostream>\n#include <memory>\n#include <numeric>\n#include <optional>\n#include <string>\n#include <vector>\n\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"loader/chunk_source/tar_chunk_source.h\"\n#include \"loader/stages/position_sampling.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"trainingdata/trainingdata_v6.h\"\n#include \"utils/training_data_printer.h\"\n\nABSL_FLAG(std::string, input_dir, \"\", \"Directory to scan for .tar files.\");\nABSL_FLAG(float, q_weight, 6.0, \"Value for diff_focus_q_weight.\");\nABSL_FLAG(float, pol_scale, 3.5, \"Value for diff_focus_pol_scale.\");\n\nnamespace {\n\nnamespace fs = std::filesystem;\n\nusing ::lczero::training::ChunkSource;\nusing ::lczero::training::ChunkSourceLoaderConfig;\nusing ::lczero::training::ComputePositionSamplingWeight;\nusing ::lczero::training::FrameType;\nusing ::lczero::training::PositionSamplingConfig;\nusing ::lczero::training::PrintTrainingDataEntry;\nusing ::lczero::training::TarChunkSource;\n\nstruct WeightedPosition {\n  FrameType data;\n  float weight;\n};\n\nstd::vector<fs::path> CollectTarFiles(const fs::path& directory) {\n  std::vector<fs::path> files;\n  for (const auto& entry : fs::directory_iterator(directory)) {\n    const fs::path& path = entry.path();\n    if (entry.is_regular_file() && path.extension() == \".tar\") {\n      files.push_back(path);\n    }\n  }\n  absl::c_sort(files, [](const fs::path& lhs, const fs::path& rhs) {\n    return lhs.filename() < rhs.filename();\n  });\n  return files;\n}\n\nstd::vector<float> CollectWeights(const fs::path& tar_path,\n                                  const 
PositionSamplingConfig& config,\n                                  WeightedPosition* max_weighted) {\n  std::vector<float> weights;\n  std::unique_ptr<ChunkSource> source = std::make_unique<TarChunkSource>(\n      tar_path, ChunkSourceLoaderConfig::V6TrainingData);\n\n  const size_t total = source->GetChunkCount();\n  for (size_t index = 0; index < total; ++index) {\n    if (index % 1000 == 0) {\n      LOG(INFO) << absl::StreamFormat(\"  Progress: %zu/%zu chunks (%.1f%%)\",\n                                      index, total, 100.0 * index / total);\n    }\n\n    const std::optional<std::vector<FrameType>> chunk =\n        source->GetChunkData(index);\n    if (!chunk) {\n      LOG(WARNING) << \"Skipping unreadable chunk \" << index << \" in \"\n                   << tar_path.string();\n      continue;\n    }\n\n    if (chunk->empty()) continue;\n\n    for (const auto& entry : *chunk) {\n      const float weight = ComputePositionSamplingWeight(entry, config);\n      weights.push_back(weight);\n      if (max_weighted && weight > max_weighted->weight) {\n        max_weighted->data = entry;\n        max_weighted->weight = weight;\n      }\n    }\n  }\n\n  return weights;\n}\n\nvoid PrintHistogram(const std::vector<float>& sorted_weights) {\n  if (sorted_weights.empty()) return;\n\n  constexpr int kBuckets = 50;\n  constexpr int kMaxWidth = 60;\n  const float min_val = sorted_weights.front();\n  const float max_val = sorted_weights.back();\n  const float range = max_val - min_val;\n\n  if (range == 0.0f) {\n    LOG(INFO) << \"\\nHistogram: All weights are identical (\" << min_val << \")\";\n    return;\n  }\n\n  std::vector<int> buckets(kBuckets, 0);\n  for (float weight : sorted_weights) {\n    int bucket = static_cast<int>((weight - min_val) / range * (kBuckets - 1));\n    bucket = std::clamp(bucket, 0, kBuckets - 1);\n    ++buckets[bucket];\n  }\n\n  const int max_count = *absl::c_max_element(buckets);\n  if (max_count == 0) return;\n\n  LOG(INFO) << 
\"\\nHistogram:\";\n  for (int bucket = 0; bucket < kBuckets; ++bucket) {\n    if (buckets[bucket] == 0) continue;\n\n    const float bucket_start = min_val + range * bucket / kBuckets;\n    const float bucket_end = min_val + range * (bucket + 1) / kBuckets;\n    const int width = (buckets[bucket] * kMaxWidth + max_count / 2) / max_count;\n\n    std::string bar;\n    for (int i = 0; i < width; ++i) bar += \"█\";\n\n    LOG(INFO) << absl::StreamFormat(\"[%.4f-%.4f) │%s (%d)\", bucket_start,\n                                    bucket_end, bar, buckets[bucket]);\n  }\n}\n\nvoid PrintPercentiles(const std::vector<float>& sorted_weights) {\n  if (sorted_weights.empty()) return;\n\n  LOG(INFO) << \"\\nPercentiles:\";\n  for (int p = 0; p <= 100; ++p) {\n    const size_t idx = (sorted_weights.size() - 1) * p / 100;\n    LOG(INFO) << absl::StreamFormat(\"  %3d%%: %.6f\", p, sorted_weights[idx]);\n  }\n}\n\nvoid PrintStatistics(const std::vector<float>& weights) {\n  if (weights.empty()) {\n    LOG(INFO) << \"No weights collected.\";\n    return;\n  }\n\n  const double sum = std::accumulate(weights.begin(), weights.end(), 0.0);\n  const double mean = sum / weights.size();\n\n  LOG(INFO) << \"\\nStatistics:\";\n  LOG(INFO) << \"  Total positions: \" << weights.size();\n  LOG(INFO) << absl::StreamFormat(\"  Mean: %.6f\", mean);\n}\n\n}  // namespace\n\nint main(int argc, char** argv) {\n  absl::InitializeLog();\n  absl::ParseCommandLine(argc, argv);\n  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);\n\n  const std::string input_dir_str = absl::GetFlag(FLAGS_input_dir);\n  const float q_weight = absl::GetFlag(FLAGS_q_weight);\n  const float pol_scale = absl::GetFlag(FLAGS_pol_scale);\n\n  if (input_dir_str.empty()) {\n    LOG(FATAL) << \"--input_dir must be specified.\";\n  }\n\n  const fs::path input_dir(input_dir_str);\n  if (!fs::exists(input_dir) || !fs::is_directory(input_dir)) {\n    LOG(FATAL) << \"Input directory does not exist: \" << input_dir.string();\n  
}\n\n  PositionSamplingConfig config;\n  config.set_diff_focus_q_weight(q_weight);\n  config.set_diff_focus_pol_scale(pol_scale);\n\n  const std::vector<fs::path> tar_files = CollectTarFiles(input_dir);\n  LOG(INFO) << \"Found \" << tar_files.size() << \" tar file(s).\";\n\n  WeightedPosition max_weighted = {{}, 0.0f};\n  std::vector<float> all_weights;\n  for (const auto& tar_path : tar_files) {\n    LOG(INFO) << \"Processing: \" << tar_path.string();\n    std::vector<float> weights =\n        CollectWeights(tar_path, config, &max_weighted);\n    all_weights.insert(all_weights.end(), weights.begin(), weights.end());\n    LOG(INFO) << \"  Collected \" << weights.size() << \" position(s).\";\n  }\n\n  absl::c_sort(all_weights);\n\n  PrintStatistics(all_weights);\n  PrintPercentiles(all_weights);\n  PrintHistogram(all_weights);\n\n  if (max_weighted.weight > 0.0f) {\n    const std::string header = absl::StrFormat(\n        \"\\nPosition with highest weight (%.6f):\", max_weighted.weight);\n    PrintTrainingDataEntry(max_weighted.data, header, 8, 4);\n  }\n\n  return 0;\n}\n"
  },
  {
    "path": "csrc/tools/rescore_chunk_main.cc",
    "content": "#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h>\n#include <absl/log/log.h>\n\n#include <cstdint>\n#include <filesystem>\n#include <string>\n#include <vector>\n\n#include \"chess/board.h\"\n#include \"proto/data_loader_config.pb.h\"\n#include \"syzygy/syzygy.h\"\n#include \"trainingdata/reader.h\"\n#include \"trainingdata/rescorer.h\"\n#include \"trainingdata/trainingdata_v6.h\"\n#include \"trainingdata/writer.h\"\n#include \"utils/exception.h\"\n\nABSL_FLAG(std::string, chunk_path, \"\",\n          \"Path to the chunk file (.gz) that should be rescored.\");\nABSL_FLAG(std::string, syzygy_paths, \"\",\n          \"Comma-separated list of Syzygy tablebase directories.\");\nABSL_FLAG(double, dist_temp, 1.0,\n          \"Policy temperature applied during rescoring.\");\nABSL_FLAG(double, dist_offset, 0.0, \"Policy offset applied during rescoring.\");\nABSL_FLAG(double, dtz_boost, 0.0,\n          \"DTZ boost applied during policy adjustments.\");\nABSL_FLAG(int, new_input_format, -1,\n          \"Optional conversion target for input format (-1 keeps original).\");\nABSL_FLAG(\n    double, deblunder_threshold, -1.0,\n    \"Threshold for policy deblundering adjustments (negative to disable).\");\nABSL_FLAG(\n    double, deblunder_width, -1.0,\n    \"Width controlling smoothing around threshold (negative to disable).\");\n\nnamespace {\n\nnamespace fs = std::filesystem;\nusing ::lczero::training::ChunkRescorerConfig;\n\nstd::vector<lczero::V6TrainingData> ReadChunkFrames(const fs::path& path) {\n  std::vector<lczero::V6TrainingData> frames;\n  lczero::TrainingDataReader reader(path.string());\n  lczero::V6TrainingData frame;\n  while (reader.ReadChunk(&frame)) {\n    frames.push_back(frame);\n  }\n  return frames;\n}\n\nvoid WriteChunkFrames(const fs::path& path,\n                      const std::vector<lczero::V6TrainingData>& frames) {\n  lczero::TrainingDataWriter 
writer(path.string());\n  for (const auto& frame : frames) {\n    writer.WriteChunk(frame);\n  }\n  writer.Finalize();\n}\n\nfs::path BuildOutputPath(const fs::path& input_path) {\n  fs::path directory = input_path.parent_path();\n  fs::path stem = input_path.stem();\n  fs::path filename = stem;\n  filename += \"_rescored.gz\";\n  return directory / filename;\n}\n\n}  // namespace\n\nint main(int argc, char** argv) {\n  absl::InitializeLog();\n  absl::ParseCommandLine(argc, argv);\n  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);\n\n  const std::string chunk_path_flag = absl::GetFlag(FLAGS_chunk_path);\n  if (chunk_path_flag.empty()) {\n    LOG(FATAL) << \"--chunk_path flag is required.\";\n  }\n  const fs::path chunk_path(chunk_path_flag);\n\n  ChunkRescorerConfig config;\n\n  const std::string syzygy_paths_flag = absl::GetFlag(FLAGS_syzygy_paths);\n  if (!syzygy_paths_flag.empty()) {\n    config.set_syzygy_paths(syzygy_paths_flag);\n  }\n  config.set_dist_temp(static_cast<float>(absl::GetFlag(FLAGS_dist_temp)));\n  config.set_dist_offset(static_cast<float>(absl::GetFlag(FLAGS_dist_offset)));\n  config.set_dtz_boost(static_cast<float>(absl::GetFlag(FLAGS_dtz_boost)));\n  config.set_new_input_format(\n      static_cast<int32_t>(absl::GetFlag(FLAGS_new_input_format)));\n\n  const double deblunder_threshold_flag =\n      absl::GetFlag(FLAGS_deblunder_threshold);\n  const double deblunder_width_flag = absl::GetFlag(FLAGS_deblunder_width);\n  if (deblunder_threshold_flag >= 0.0 && deblunder_width_flag >= 0.0) {\n    config.set_deblunder_threshold(\n        static_cast<float>(deblunder_threshold_flag));\n    config.set_deblunder_width(static_cast<float>(deblunder_width_flag));\n  } else if (deblunder_threshold_flag >= 0.0 || deblunder_width_flag >= 0.0) {\n    LOG(FATAL) << \"Both --deblunder_threshold and --deblunder_width must be \"\n               << \"set to non-negative values together.\";\n  }\n\n  LOG(INFO) << \"Reading chunk from \" << 
chunk_path.string();\n  std::vector<lczero::V6TrainingData> frames;\n  try {\n    frames = ReadChunkFrames(chunk_path);\n  } catch (const lczero::Exception& exception) {\n    LOG(FATAL) << \"Failed to read chunk: \" << exception.what();\n  }\n  LOG(INFO) << \"Loaded \" << frames.size() << \" frame(s) from chunk.\";\n  if (frames.empty()) {\n    LOG(WARNING) << \"Chunk contains no frames; writing empty output.\";\n    try {\n      WriteChunkFrames(BuildOutputPath(chunk_path), frames);\n    } catch (const lczero::Exception& exception) {\n      LOG(FATAL) << \"Failed to write rescored chunk: \" << exception.what();\n    }\n    return 0;\n  }\n\n  lczero::InitializeMagicBitboards();\n\n  if (config.has_deblunder_threshold() && config.has_deblunder_width()) {\n    lczero::RescorerDeblunderSetup(config.deblunder_threshold(),\n                                   config.deblunder_width());\n  }\n\n  lczero::SyzygyTablebase tablebase;\n  if (!config.syzygy_paths().empty()) {\n    LOG(INFO) << \"Initializing Syzygy tablebases from '\"\n              << config.syzygy_paths() << \"'.\";\n    const std::string syzygy_paths(config.syzygy_paths());\n    if (!tablebase.init(syzygy_paths)) {\n      LOG(WARNING) << \"Failed to initialize Syzygy tablebases.\";\n    }\n  }\n\n  LOG(INFO) << \"Rescoring chunk with dist_temp=\" << config.dist_temp()\n            << \", dist_offset=\" << config.dist_offset()\n            << \", dtz_boost=\" << config.dtz_boost()\n            << \", new_input_format=\" << config.new_input_format() << \".\";\n\n  try {\n    frames = lczero::RescoreTrainingData(\n        std::move(frames), &tablebase, config.dist_temp(), config.dist_offset(),\n        config.dtz_boost(), config.new_input_format());\n  } catch (const lczero::Exception& exception) {\n    LOG(FATAL) << \"Failed to rescore chunk: \" << exception.what();\n  }\n\n  const fs::path output_path = BuildOutputPath(chunk_path);\n  LOG(INFO) << \"Writing rescored chunk to \" << output_path.string();\n  
try {\n    WriteChunkFrames(output_path, frames);\n  } catch (const lczero::Exception& exception) {\n    LOG(FATAL) << \"Failed to write rescored chunk: \" << exception.what();\n  }\n  LOG(INFO) << \"Completed rescoring of chunk.\";\n\n  return 0;\n}\n"
  },
  {
    "path": "csrc/tools/result_distribution_main.cc",
    "content": "#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h>\n#include <absl/log/log.h>\n#include <absl/strings/string_view.h>\n#include <absl/synchronization/mutex.h>\n\n#include <algorithm>\n#include <atomic>\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <cstring>\n#include <filesystem>\n#include <fstream>\n#include <iostream>\n#include <optional>\n#include <ostream>\n#include <string>\n#include <thread>\n#include <utility>\n#include <vector>\n\n#include \"loader/chunk_source/tar_chunk_source.h\"\n#include \"trainingdata/trainingdata_v6.h\"\n\nABSL_FLAG(std::string, output_csv, \"\",\n          \"Destination CSV file. Writes to stdout if empty.\");\nABSL_FLAG(int, num_threads, 0,\n          \"Number of worker threads. Defaults to hardware concurrency.\");\n\nnamespace {\n\nnamespace fs = std::filesystem;\n\nusing ::lczero::training::ChunkSourceLoaderConfig;\nusing ::lczero::training::FrameType;\nusing ::lczero::training::TarChunkSource;\n\nenum class ChunkResult { kWin, kDraw, kLoss };\n\nstruct ResultCounts {\n  uint64_t wins = 0;\n  uint64_t draws = 0;\n  uint64_t losses = 0;\n};\n\nclass CsvWriter {\n public:\n  CsvWriter(std::ostream* output, absl::Mutex* mutex)\n      : output_(output), mutex_(mutex) {}\n\n  void Write(absl::string_view basename, const ResultCounts& counts) const {\n    absl::MutexLock lock(mutex_);\n    *output_ << basename << ',' << counts.wins << ',' << counts.draws << ','\n             << counts.losses << '\\n';\n    output_->flush();\n  }\n\n private:\n  std::ostream* output_;\n  absl::Mutex* mutex_;\n};\n\nstd::ostream& SelectOutput(const fs::path& output_path,\n                           std::ofstream& file_stream) {\n  if (output_path.empty()) return std::cout;\n\n  file_stream.open(output_path, std::ios::out | std::ios::trunc);\n  if (!file_stream) {\n    LOG(FATAL) << \"Failed to open output file: \" << output_path.string();\n  }\n 
 return file_stream;\n}\n\nconstexpr float kFloatTolerance = 1e-6f;\n\nstd::optional<ChunkResult> DetermineChunkResult(absl::string_view chunk_payload,\n                                                size_t chunk_index,\n                                                const fs::path& tar_path) {\n  if (chunk_payload.size() < sizeof(FrameType)) {\n    LOG(WARNING) << \"Chunk \" << chunk_index << \" in \" << tar_path.string()\n                 << \" is too small.\";\n    return std::nullopt;\n  }\n\n  FrameType frame;\n  std::memcpy(&frame, chunk_payload.data(), sizeof(frame));\n\n  if (std::fabs(frame.result_d - 1.0f) <= kFloatTolerance) {\n    return ChunkResult::kDraw;\n  }\n  if (std::fabs(frame.result_d) > kFloatTolerance) {\n    LOG(WARNING) << \"Chunk \" << chunk_index << \" in \" << tar_path.string()\n                 << \" has unexpected result_d=\" << frame.result_d << '.';\n    return std::nullopt;\n  }\n\n  const bool side_to_move = frame.side_to_move_or_enpassant != 0;\n  if (std::fabs(frame.result_q - 1.0f) <= kFloatTolerance) {\n    return side_to_move ? ChunkResult::kLoss : ChunkResult::kWin;\n  }\n  if (std::fabs(frame.result_q + 1.0f) <= kFloatTolerance) {\n    return side_to_move ? 
ChunkResult::kWin : ChunkResult::kLoss;\n  }\n\n  LOG(WARNING) << \"Chunk \" << chunk_index << \" in \" << tar_path.string()\n               << \" has unexpected result_q=\" << frame.result_q << '.';\n  return std::nullopt;\n}\n\nResultCounts CountResultsInTar(const fs::path& tar_path) {\n  ResultCounts counts;\n\n  TarChunkSource source(tar_path, ChunkSourceLoaderConfig::V6TrainingData);\n\n  const size_t chunk_count = source.GetChunkCount();\n  for (size_t index = 0; index < chunk_count; ++index) {\n    const std::optional<std::string> chunk =\n        source.GetChunkPrefix(index, sizeof(FrameType));\n    if (!chunk) {\n      LOG(WARNING) << \"Skipping unreadable chunk \" << index << \" in \"\n                   << tar_path.string();\n      continue;\n    }\n\n    const std::optional<ChunkResult> result =\n        DetermineChunkResult(*chunk, index, tar_path);\n    if (!result) continue;\n\n    switch (*result) {\n      case ChunkResult::kWin:\n        ++counts.wins;\n        break;\n      case ChunkResult::kDraw:\n        ++counts.draws;\n        break;\n      case ChunkResult::kLoss:\n        ++counts.losses;\n        break;\n    }\n  }\n\n  return counts;\n}\n\n}  // namespace\n\nint main(int argc, char** argv) {\n  absl::InitializeLog();\n  std::vector<char*> positional = absl::ParseCommandLine(argc, argv);\n  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);\n\n  if (positional.size() <= 1) {\n    LOG(FATAL) << \"Provide at least one .tar file as a positional argument.\";\n  }\n\n  const fs::path output_path(absl::GetFlag(FLAGS_output_csv));\n  std::ofstream file_stream;\n  std::ostream& output = SelectOutput(output_path, file_stream);\n  absl::Mutex output_mutex;\n  const CsvWriter writer(&output, &output_mutex);\n\n  std::vector<fs::path> tar_files;\n  tar_files.reserve(positional.size() - 1);\n  for (size_t i = 1; i < positional.size(); ++i) {\n    tar_files.emplace_back(positional[i]);\n  }\n\n  const int num_threads_flag = 
absl::GetFlag(FLAGS_num_threads);\n  size_t worker_count = 0;\n  if (num_threads_flag > 0) {\n    worker_count = static_cast<size_t>(num_threads_flag);\n  } else {\n    const unsigned int hw_threads = std::thread::hardware_concurrency();\n    worker_count = hw_threads > 0 ? static_cast<size_t>(hw_threads) : 1;\n  }\n  worker_count = std::max<size_t>(1, std::min(worker_count, tar_files.size()));\n\n  std::atomic<size_t> next_index(0);\n  std::vector<std::thread> workers;\n  workers.reserve(worker_count);\n\n  for (size_t worker_id = 0; worker_id < worker_count; ++worker_id) {\n    workers.emplace_back([&tar_files, &next_index, &writer]() {\n      while (true) {\n        const size_t index = next_index.fetch_add(1, std::memory_order_relaxed);\n        if (index >= tar_files.size()) break;\n        const fs::path& tar_path = tar_files[index];\n        LOG(INFO) << \"Processing tar file: \" << tar_path.string();\n        try {\n          const ResultCounts counts = CountResultsInTar(tar_path);\n          writer.Write(tar_path.filename().string(), counts);\n        } catch (const std::exception& exception) {\n          LOG(WARNING) << \"Failed to process tar file \" << tar_path.string()\n                       << \": \" << exception.what();\n        }\n      }\n    });\n  }\n\n  for (auto& worker : workers) {\n    worker.join();\n  }\n\n  return 0;\n}\n"
  },
  {
    "path": "csrc/tools/startpos_policy_distribution_main.cc",
    "content": "#include <absl/algorithm/container.h>\n#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h>\n#include <absl/log/log.h>\n#include <absl/strings/string_view.h>\n#include <absl/types/span.h>\n\n#include <array>\n#include <cstdint>\n#include <cstring>\n#include <filesystem>\n#include <fstream>\n#include <iostream>\n#include <memory>\n#include <optional>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"loader/chunk_source/tar_chunk_source.h\"\n#include \"trainingdata/trainingdata_v6.h\"\n\nABSL_FLAG(std::string, input_dir, \".\", \"Directory to scan for .tar files.\");\nABSL_FLAG(std::string, output_csv, \"\",\n          \"Destination CSV file. Writes to stdout if empty.\");\n\nnamespace {\n\nnamespace fs = std::filesystem;\n\nusing ::lczero::training::ChunkSource;\nusing ::lczero::training::ChunkSourceLoaderConfig;\nusing ::lczero::training::FrameType;\nusing ::lczero::training::TarChunkSource;\n\nconstexpr std::array<uint64_t, 16> kStartPositionPlanes = {\n    0x000000000000ff00ull, 0x0000000000000042ull, 0x0000000000000024ull,\n    0x0000000000000081ull, 0x0000000000000010ull, 0x0000000000000008ull,\n    0x00ff000000000000ull, 0x4200000000000000ull, 0x2400000000000000ull,\n    0x8100000000000000ull, 0x1000000000000000ull, 0x0800000000000000ull,\n    0x0000000000000000ull, 0x0000000000000000ull, 0x0000000000000000ull,\n    0x0000000000000000ull};\n\nusing PolicyProbe = std::pair<int, absl::string_view>;\n\nconstexpr std::array<PolicyProbe, 20> kPolicyProbes = {\n    {{378, \"g2g4\"}, {346, \"f2f3\"}, {34, \"b1a3\"},  {161, \"g1h3\"},\n     {403, \"h2h4\"}, {351, \"f2f4\"}, {234, \"b2b4\"}, {207, \"a2a4\"},\n     {288, \"d2d3\"}, {204, \"a2a3\"}, {259, \"c2c3\"}, {36, \"b1c3\"},\n     {400, \"h2h3\"}, {230, \"b2b3\"}, {322, \"e2e4\"}, {317, \"e2e3\"},\n     {374, \"g2g3\"}, {264, \"c2c4\"}, {159, \"g1f3\"}, 
{293, \"d2d4\"}}};\n\nbool MatchesStartPosition(const FrameType& data) {\n  return absl::c_equal(\n      kStartPositionPlanes,\n      absl::Span<const uint64_t>(data.planes, kStartPositionPlanes.size()));\n}\n\nstd::vector<fs::path> CollectTarFiles(const fs::path& directory) {\n  std::vector<fs::path> files;\n  for (const auto& entry : fs::directory_iterator(directory)) {\n    const fs::path& path = entry.path();\n    if (entry.is_regular_file() && path.extension() == \".tar\") {\n      files.push_back(path);\n    }\n  }\n  absl::c_sort(files, [](const fs::path& lhs, const fs::path& rhs) {\n    return lhs.filename() < rhs.filename();\n  });\n  return files;\n}\n\nvoid WriteHeader(std::ostream& output) {\n  output << \"file,index\";\n  for (const auto& probe : kPolicyProbes) output << ',' << probe.second;\n  output << '\\n';\n}\n\nvoid WriteRow(std::ostream& output, absl::string_view sort_key, size_t index,\n              const FrameType& data) {\n  output << sort_key << ',' << index;\n  for (const auto& probe : kPolicyProbes) {\n    output << ',' << data.probabilities[probe.first];\n  }\n  output << '\\n';\n}\n\nvoid ProcessTarFile(const fs::path& tar_path, std::ostream& output) {\n  std::unique_ptr<ChunkSource> source = std::make_unique<TarChunkSource>(\n      tar_path, ChunkSourceLoaderConfig::V6TrainingData);\n  const std::string sort_key = source->GetChunkSortKey();\n\n  for (size_t i = 0, count = source->GetChunkCount(); i < count; ++i) {\n    const std::optional<std::vector<FrameType>> chunk = source->GetChunkData(i);\n    if (!chunk || chunk->empty()) continue;\n\n    const FrameType& entry = chunk->front();\n    if (!MatchesStartPosition(entry)) continue;\n\n    WriteRow(output, sort_key, i, entry);\n  }\n}\n\nstd::ostream& SelectOutput(const fs::path& output_path,\n                           std::ofstream& file_stream) {\n  if (output_path.empty()) return std::cout;\n  file_stream.open(output_path, std::ios::out | std::ios::trunc);\n  if (!file_stream) {\n 
   LOG(FATAL) << \"Failed to open output file: \" << output_path.string();\n  }\n  return file_stream;\n}\n\n}  // namespace\n\nint main(int argc, char** argv) {\n  absl::InitializeLog();\n  absl::ParseCommandLine(argc, argv);\n  absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);\n\n  const fs::path input_dir(absl::GetFlag(FLAGS_input_dir));\n  const fs::path output_path(absl::GetFlag(FLAGS_output_csv));\n\n  if (!fs::is_directory(input_dir)) {\n    LOG(FATAL) << \"Input directory does not exist: \" << input_dir.string();\n  }\n\n  std::ofstream file_stream;\n  std::ostream& output = SelectOutput(output_path, file_stream);\n\n  WriteHeader(output);\n\n  for (const auto& tar_path : CollectTarFiles(input_dir)) {\n    LOG(INFO) << \"Processing tar file: \" << tar_path.string();\n    try {\n      ProcessTarFile(tar_path, output);\n    } catch (const std::exception& e) {\n      LOG(WARNING) << \"Failed to process tar file \" << tar_path.string()\n                   << \": \" << e.what();\n    }\n  }\n\n  return 0;\n}\n"
  },
  {
    "path": "csrc/utils/gz.cc",
    "content": "#include \"utils/gz.h\"\n\n#include <absl/log/log.h>\n#include <zlib.h>\n\n#include <array>\n#include <stdexcept>\n\nnamespace lczero {\nnamespace training {\n\nstd::string GunzipBuffer(std::string_view buffer) {\n  z_stream strm = {};\n  int ret = inflateInit2(&strm, 16 + MAX_WBITS);\n  if (ret != Z_OK) {\n    throw GunzipError(\"Failed to initialize zlib inflate\");\n  }\n\n  strm.avail_in = buffer.size();\n  strm.next_in = reinterpret_cast<Bytef*>(const_cast<char*>(buffer.data()));\n\n  constexpr size_t kChunkSize = 16384;\n  std::string output;\n  std::array<char, kChunkSize> temp_buffer;\n\n  do {\n    strm.avail_out = kChunkSize;\n    strm.next_out = reinterpret_cast<Bytef*>(temp_buffer.data());\n\n    ret = inflate(&strm, Z_NO_FLUSH);\n    if (ret == Z_STREAM_ERROR || ret == Z_NEED_DICT || ret == Z_DATA_ERROR ||\n        ret == Z_MEM_ERROR) {\n      inflateEnd(&strm);\n      throw GunzipError(\"zlib inflate error\");\n    }\n\n    size_t bytes_written = kChunkSize - strm.avail_out;\n    output.append(temp_buffer.begin(), temp_buffer.begin() + bytes_written);\n  } while (strm.avail_out == 0);\n\n  inflateEnd(&strm);\n\n  if (ret != Z_STREAM_END) {\n    throw GunzipError(\"Incomplete gzip decompression\");\n  }\n\n  return output;\n}\n\n}  // namespace training\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/gz.h",
    "content": "#pragma once\n\n#include <stdexcept>\n#include <string>\n#include <string_view>\n\nnamespace lczero {\nnamespace training {\n\nclass GunzipError : public std::runtime_error {\n public:\n  using std::runtime_error::runtime_error;\n};\n\nstd::string GunzipBuffer(std::string_view buffer);\n\n}  // namespace training\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/metrics/exponential_aggregator.h",
    "content": "#pragma once\n\n#include <absl/strings/str_cat.h>\n#include <absl/synchronization/mutex.h>\n\n#include <bit>\n#include <chrono>\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <cstdlib>\n#include <functional>\n#include <optional>\n#include <string>\n#include <tuple>\n#include <type_traits>\n#include <utility>\n#include <vector>\n\nnamespace lczero {\n\n// 1 second period is exact. Other periods are powers of two.\nenum class TimePeriod {\n  kEmpty = -128,\n  k1Millisecond = -10,\n  k2Milliseconds,\n  k4Milliseconds,\n  k8Milliseconds,\n  k16Milliseconds,\n  k31Milliseconds,\n  k63Milliseconds,\n  k125Milliseconds,\n  k250Milliseconds,\n  k500Milliseconds,\n  k1Second /* = 0 */,\n  k2Seconds,\n  k4Seconds,\n  k8Seconds,\n  k16Seconds,\n  k32Seconds,\n  k1Minute,\n  k2Minutes,\n  k4Minutes,\n  k9Minutes,\n  k17Minutes,\n  k36Minutes,\n  k1Hour,\n  k2Hours,\n  k5Hours,\n  k9Hours,\n  k18Hours,\n  k36Hours,\n  k3Days,\n  k6Days,\n  k12Days,\n  k24Days,\n  k49Days,\n  k97Days,\n  k194Days,\n  k388Days,\n  k777Days,\n  k2Years,\n  k4Years,\n  k9Years,\n  k17Years, /* = 29 */\n  kAllTime = 127\n};\n\n// ExponentialAggregator aggregates metrics over exponentially increasing time\n// periods.\n//\n// The template parameter `Metric` can be used in two ways:\n// 1. Function-based: Pass Clear and UpdateFrom functions to constructor.\n//    This is the new approach for protobuf-based metrics.\n// 2. 
Method-based: The metric type has Reset() and MergeFrom() methods.\n//    This maintains backward compatibility with existing C++ struct metrics.\n//\n// For method-based metrics, the type must satisfy:\n// * It must have a `Reset()` method that clears its state.\n// * It must have a `MergeFrom(const Metric& other)` method to merge another\n//   metric into itself (used for bucket-to-bucket merging and live ingestion).\n// * It must behave like a monoid (actually, unital magma is sufficient):\n//   * Merging with a default-constructed (empty) metric is a no-op.\n//   * The `MergeFrom` operation must be associative (actually, not really;\n//     currently we always merge newer into older). It does not need to be\n//     commutative.\n//   * Having a `std::swap(Metric&, Metric&)` overload is also beneficial.\ntemplate <typename Metric, TimePeriod Resolution = TimePeriod::k16Milliseconds>\nclass ExponentialAggregator {\n public:\n  using Duration = std::chrono::nanoseconds;\n  using Clock = std::chrono::steady_clock;\n  using ClearFn = std::function<void(Metric&)>;\n  using UpdateFromFn = std::function<void(Metric&, const Metric&)>;\n  static constexpr TimePeriod kResolution = Resolution;\n\n  // Resets the aggregator, clearing all buckets and pending metrics.\n  void Reset(Clock::time_point now = Clock::now());\n\n  // Ingests the passed metric into the pending bucket using MergeFrom + Reset.\n  void RecordMetrics(Metric&& metric);\n\n  // Returns the latest completed metrics bucket for the given time period and\n  // the duration since that period last completed. If `now` is nullopt, it\n  // excludes the time since the last metrics flush.\n  std::pair<Metric, Duration> GetBucketMetrics(\n      TimePeriod period,\n      std::optional<Clock::time_point> now = std::nullopt) const;\n\n  // Returns the current metrics that have been collected for at least the\n  // specified duration. Returns the metrics and the duration since the\n  // beginning of the covered period. 
If `now` is nullopt, it excludes metrics\n  // and the time since the last metrics flush.\n  std::pair<Metric, Duration> GetAggregateEndingNow(\n      Duration duration,\n      std::optional<Clock::time_point> now = Clock::now()) const;\n\n  // Flushes the current pending bucket into the exponential metrics and\n  // advances time by the elapsed duration, potentially processing multiple\n  // ticks. Returns the largest time period that was updated by this advance\n  // (all smaller periods are also updated).\n  //\n  // Note: Advance() should typically be called more frequently than the tick\n  // frequency. When called less frequently (advancing multiple ticks at once),\n  // live statistics go into the first bucket and subsequent buckets are padded\n  // with empty metrics.\n  TimePeriod Advance(Clock::time_point now = Clock::now());\n\n  constexpr Duration GetResolution() const { return kPeriodDuration; }\n\n  // Primary constructor for new protobuf-based metrics.\n  // Takes two free functions that define the metric's behavior.\n  ExponentialAggregator(ClearFn clear_fn, UpdateFromFn update_from_fn);\n\n  // Default constructor for backward compatibility with old C++ metrics.\n  // This will only compile if `Metric` has Reset() and MergeFrom() methods.\n  ExponentialAggregator();\n\n private:\n  // Constexpr power of 2 using bit shifts\n  static constexpr double constexpr_pow2(int exp) {\n    if (exp >= 0) {\n      return static_cast<double>(1LL << exp);\n    } else {\n      return 1.0 / static_cast<double>(1LL << (-exp));\n    }\n  }\n\n  static constexpr Duration kPeriodDuration =\n      std::chrono::duration_cast<Duration>(std::chrono::duration<double>(\n          constexpr_pow2(static_cast<int>(Resolution))));\n\n  static size_t GetBucketIndex(TimePeriod period) {\n    return static_cast<size_t>(period) - static_cast<size_t>(Resolution);\n  }\n\n  // The aggregation strategy is analogous to a binary counter. 
`tick_count_`\n  // represents the counter's value, and the `buckets_` array corresponds to its\n  // bits, each covering an exponentially larger time period.\n  //\n  // Advancing time increments `tick_count_`. When a bit flips from 1 to 0,\n  // its bucket's metric is merged (the \"carry\") into the next higher bucket.\n  //\n  // Note that buckets are never empty. A bucket whose corresponding bit in\n  // `tick_count_` is '0' simply holds the last complete metric for its time\n  // period. This ensures that a valid, historical metric is always available\n  // for the bucket query.\n\n  ClearFn clear_fn_;\n  UpdateFromFn update_from_fn_;\n\n  mutable absl::Mutex mutex_;\n  size_t tick_count_ ABSL_GUARDED_BY(mutex_);\n  // Buckets for each time period, starting from Resolution.\n  std::vector<Metric> buckets_ ABSL_GUARDED_BY(mutex_);\n  Clock::time_point last_tick_time_ ABSL_GUARDED_BY(mutex_);\n\n  mutable absl::Mutex pending_bucket_mutex_ ABSL_ACQUIRED_AFTER(mutex_);\n  Metric pending_bucket_ ABSL_GUARDED_BY(pending_bucket_mutex_);\n};\n\ntemplate <typename Metric, TimePeriod Resolution>\nExponentialAggregator<Metric, Resolution>::ExponentialAggregator(\n    ClearFn clear_fn, UpdateFromFn update_from_fn)\n    : clear_fn_(std::move(clear_fn)),\n      update_from_fn_(std::move(update_from_fn)),\n      tick_count_(0),\n      last_tick_time_(Clock::now()) {}\n\ntemplate <typename Metric, TimePeriod Resolution>\nExponentialAggregator<Metric, Resolution>::ExponentialAggregator()\n    : tick_count_(0), last_tick_time_(Clock::now()) {\n  // For backward compatibility, use metric methods if available\n  if constexpr (requires(Metric& m, const Metric& other) {\n                  m.Reset();\n                  m.MergeFrom(other);\n                }) {\n    clear_fn_ = [](Metric& m) { m.Reset(); };\n    update_from_fn_ = [](Metric& dest, const Metric& src) {\n      dest.MergeFrom(src);\n    };\n  } else {\n    static_assert(sizeof(Metric) == 0,\n                  \"Metric 
type must have Reset() and MergeFrom() methods or \"\n                  \"use function-based constructor\");\n  }\n}\n\ntemplate <typename Metric, TimePeriod Resolution>\nvoid ExponentialAggregator<Metric, Resolution>::Reset(\n    std::chrono::steady_clock::time_point now) {\n  absl::MutexLock lock(&mutex_);\n  tick_count_ = 0;\n  buckets_.clear();\n  last_tick_time_ = now;\n\n  absl::MutexLock pending_bucket_lock(&pending_bucket_mutex_);\n  clear_fn_(pending_bucket_);\n}\n\ntemplate <typename Metric, TimePeriod Resolution>\nvoid ExponentialAggregator<Metric, Resolution>::RecordMetrics(Metric&& metric) {\n  absl::MutexLock lock(&pending_bucket_mutex_);\n  update_from_fn_(pending_bucket_, metric);\n  clear_fn_(metric);\n}\n\ntemplate <typename Metric, TimePeriod Resolution>\nauto ExponentialAggregator<Metric, Resolution>::GetBucketMetrics(\n    TimePeriod period, std::optional<Clock::time_point> now) const\n    -> std::pair<Metric, Duration> {\n  absl::MutexLock lock(&mutex_);\n  const size_t index = GetBucketIndex(period);\n  const Duration duration_since_update =\n      kPeriodDuration * (tick_count_ % (1ULL << index)) +\n      (now.has_value() ? 
Duration(*now - last_tick_time_) : Duration::zero());\n  if (index >= buckets_.size()) return {Metric(), duration_since_update};\n  return {buckets_[index], duration_since_update};\n}\n\ntemplate <typename Metric, TimePeriod Resolution>\nauto ExponentialAggregator<Metric, Resolution>::GetAggregateEndingNow(\n    Duration duration, std::optional<Clock::time_point> now) const\n    -> std::pair<Metric, Duration> {\n  Duration result_duration = Duration::zero();\n  Metric result;\n\n  {\n    absl::MutexLock lock(&mutex_);\n    if (now.has_value()) {\n      // If we'll use the pending bucket, remove its duration from the request.\n      // The bucket itself is merged at the end, as we have to merge newer\n      // into older buckets.\n      Duration duration_since_update = *now - last_tick_time_;\n      duration -= duration_since_update;\n      result_duration += duration_since_update;\n    }\n\n    if (duration > Duration::zero()) {\n      // Convert the input `duration` into `num_ticks` (the number of base\n      // time periods), rounding up.\n      const auto div = std::div(duration.count(), kPeriodDuration.count());\n      const size_t num_ticks = div.quot + bool(div.rem);\n\n      // To cover the remaining `duration`, we select the minimal set of active\n      // historical buckets (where the corresponding bit in `tick_count_` is 1)\n      // that, when combined, meet or exceed the target duration.\n      //\n      // 1. First we determine the bit width of `num_ticks` (e.g., for 13\n      //    (1101b), the width is 4). Create a candidate set of ticks by\n      //    masking `tick_count_` to this width. If this masked value is >=\n      //    `num_ticks`, the corresponding set of buckets is sufficient.\n      // 2. If the masked value from step 1 is insufficient, we include one\n      //    additional bucket. 
It doesn't matter if it's active or not.\n      size_t masked_ticks = [&]() {\n        const auto bit_width = std::bit_width(num_ticks);\n        const size_t mask = ~((~size_t{0}) << bit_width);\n        const size_t candidate_ticks = tick_count_ & mask;\n        if (candidate_ticks >= num_ticks) return candidate_ticks;\n        // One additional tick is needed.\n        return candidate_ticks | (size_t{1} << bit_width);\n      }();\n\n      while (masked_ticks) {\n        // Start merging from the highest bit (older bucket) to the lowest.\n        size_t idx = std::bit_width(masked_ticks) - 1;\n        masked_ticks &= ~(1ULL << idx);\n        if (idx < buckets_.size()) {\n          update_from_fn_(result, buckets_[idx]);\n          result_duration += kPeriodDuration * (1ULL << idx);\n        }\n      }\n    }\n  }\n\n  if (now.has_value()) {\n    // The pending bucket is merged last, as we have to merge newer into older\n    // buckets.\n    absl::MutexLock lock(&pending_bucket_mutex_);\n    update_from_fn_(result, pending_bucket_);\n  }\n\n  return {result, result_duration};\n}\n\ntemplate <typename Metric, TimePeriod Resolution>\nauto ExponentialAggregator<Metric, Resolution>::Advance(Clock::time_point now)\n    -> TimePeriod {\n  absl::MutexLock lock(&mutex_);\n  const int num_ticks_to_advance = (now - last_tick_time_) / kPeriodDuration;\n  if (num_ticks_to_advance <= 0) return TimePeriod::kEmpty;\n\n  last_tick_time_ += num_ticks_to_advance * kPeriodDuration;\n  Metric live_carry;\n  {\n    // What was pending, now becomes carry. 
Pending bucket is cleared.\n    absl::MutexLock pending_bucket_lock(&pending_bucket_mutex_);\n    live_carry = std::move(pending_bucket_);\n    clear_fn_(pending_bucket_);\n  }\n\n  const size_t initial_tick_count = tick_count_;\n\n  auto one_tick = [&](Metric& carry) {\n    ++tick_count_;\n\n    for (size_t i = 0;; ++i) {\n      const uint64_t interval_size = 1ULL << i;\n      if ((tick_count_ % interval_size) != 0) break;\n      while (i >= buckets_.size()) buckets_.emplace_back();\n      // We always merge new into the old, so we swap the carry first, and then\n      // merge into it.\n      std::swap(carry, buckets_[i]);\n      update_from_fn_(carry, buckets_[i]);\n    }\n  };\n\n  // Carry the pending bucket into the first tick.\n  one_tick(live_carry);\n  // Then, if more than one tick is requested, we carry the empty bucket\n  // through the remaining ticks.\n  for (int i = 1; i < num_ticks_to_advance; ++i) {\n    Metric empty_carry;\n    one_tick(empty_carry);\n  }\n\n  // Largest time period is the highest bit that was flipped in the process.\n  // To find it, we XOR the initial tick count with the current one, and\n  // find the highest bit.\n  return static_cast<TimePeriod>(\n      std::bit_width(initial_tick_count ^ tick_count_) - 1 +\n      static_cast<int>(Resolution));\n}\n\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/metrics/group.h",
    "content": "#pragma once\n\n#include <absl/strings/str_cat.h>\n\n#include \"utils/metrics/printer.h\"\n\nnamespace lczero {\n\n// Metric is a struct that implements the following interface:\n// - void Reset();  // Resets the metric to its initial state.\n// - void MergeFrom(const Metric& other);  // Merges another metric into this\n// one. Note that the incoming always happens later in time, so if e.g. merge\n// keeps the latest value, it should update the current value with the incoming\n// one. Used for bucket-to-bucket merging and live data ingestion.\n// - (optional) std::string_view name() const;\n// - (optional) std::string ToString() const; // If provided, returns a string\n// representation of the metric.\n\n// Group several metric types together. This allows us to have a single\n// `MetricGroup` that contains multiple different metrics.\n\ntemplate <typename... StatRecords>\nclass MetricGroup {\n public:\n  MetricGroup() = default;\n\n  // Calls reset on all stats.\n  void Reset();\n\n  // Merges each individual stat from `other` into this group.\n  void MergeFrom(const MetricGroup<StatRecords...>& other);\n\n  // Merges a single stat from `other` into this group.\n  template <typename T>\n  void MergeFrom(const T& other);\n\n  // Gets a const reference to a specific stat record.\n  template <typename T>\n  const T& Get() const;\n\n  // Gets a mutable pointer to a specific stat record.\n  template <typename T>\n  T* GetMutable();\n\n  // Calls MetricPrinter for each stat in the group.\n  void Print(MetricPrinter& printer) const;\n\n private:\n  std::tuple<StatRecords...> stats_;\n};\n\ntemplate <typename... StatRecords>\nvoid MetricGroup<StatRecords...>::Reset() {\n  (std::get<StatRecords>(stats_).Reset(), ...);\n}\n\ntemplate <typename... 
StatRecords>\nvoid MetricGroup<StatRecords...>::MergeFrom(\n    const MetricGroup<StatRecords...>& other) {\n  (std::get<StatRecords>(stats_).MergeFrom(std::get<StatRecords>(other.stats_)),\n   ...);\n}\n\ntemplate <typename... StatRecords>\ntemplate <typename T>\nvoid MetricGroup<StatRecords...>::MergeFrom(const T& other) {\n  static_assert((std::is_same_v<T, StatRecords> || ...),\n                \"Type T must be one of the Stats types\");\n  std::get<T>(stats_).MergeFrom(other);\n}\n\ntemplate <typename... StatRecords>\ntemplate <typename T>\nconst T& MetricGroup<StatRecords...>::Get() const {\n  static_assert((std::is_same_v<T, StatRecords> || ...),\n                \"Type T must be one of the Stats types\");\n  return std::get<T>(stats_);\n}\n\ntemplate <typename... StatRecords>\ntemplate <typename T>\nT* MetricGroup<StatRecords...>::GetMutable() {\n  static_assert((std::is_same_v<T, StatRecords> || ...),\n                \"Type T must be one of the Stats types\");\n  return &std::get<T>(stats_);\n}\n\ntemplate <typename... StatRecords>\nvoid MetricGroup<StatRecords...>::Print(MetricPrinter& printer) const {\n  (\n      [&](const auto& stat) {\n        if constexpr (requires { stat.Print(printer); }) {\n          stat.Print(printer);\n        }\n      }(std::get<StatRecords>(stats_)),\n      ...);\n}\n\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/metrics/load_metric.h",
    "content": "#pragma once\n\n#include <absl/base/thread_annotations.h>\n#include <absl/synchronization/mutex.h>\n\n#include <chrono>\n#include <optional>\n\n#include \"proto/training_metrics.pb.h\"\n\nnamespace lczero {\n\n// ABOUTME: Helper class to manage timing logic for LoadMetricProto.\n// ABOUTME: Tracks active periods and flushes accumulated time to the protobuf\n// metric.\nclass LoadMetricUpdater {\n public:\n  using Clock = std::chrono::steady_clock;\n  using Duration = std::chrono::duration<double>;\n\n  explicit LoadMetricUpdater(Clock::time_point initial_time = Clock::now())\n      : last_flush_time_(initial_time), is_load_active_(true) {}\n\n  // Starts tracking load from the given time point.\n  // Returns true if load was previously stopped (successful start).\n  bool LoadStart(Clock::time_point now = Clock::now()) {\n    absl::MutexLock lock(&mutex_);\n    FlushInternal(now);\n    bool was_stopped = !is_load_active_;\n    is_load_active_ = true;\n    return was_stopped;\n  }\n\n  // Stops tracking load at the given time point.\n  // Returns true if load was previously active (successful stop).\n  bool LoadStop(Clock::time_point now = Clock::now()) {\n    absl::MutexLock lock(&mutex_);\n    FlushInternal(now);\n    bool was_active = is_load_active_;\n    is_load_active_ = false;\n    return was_active;\n  }\n\n  // Flushes any uncounted load time into the metric.\n  void Flush(Clock::time_point now = Clock::now()) {\n    absl::MutexLock lock(&mutex_);\n    FlushInternal(now);\n  }\n\n  // Flushes metrics and returns a copy, resetting the internal metric.\n  LoadMetricProto FlushMetrics(Clock::time_point now = Clock::now()) {\n    absl::MutexLock lock(&mutex_);\n    FlushInternal(now);\n    LoadMetricProto result = metric_;\n    metric_.Clear();\n    return result;\n  }\n\n private:\n  // Flushes any uncounted load time into the metric (assumes mutex held).\n  void FlushInternal(Clock::time_point now)\n      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) 
{\n    Duration elapsed = now - last_flush_time_;\n    double elapsed_seconds = elapsed.count();\n\n    metric_.set_total_seconds(metric_.total_seconds() + elapsed_seconds);\n    if (is_load_active_) {\n      metric_.set_load_seconds(metric_.load_seconds() + elapsed_seconds);\n    }\n\n    last_flush_time_ = now;\n  }\n\n  mutable absl::Mutex mutex_;\n  LoadMetricProto metric_ ABSL_GUARDED_BY(mutex_);\n  Clock::time_point last_flush_time_ ABSL_GUARDED_BY(mutex_);\n  bool is_load_active_ ABSL_GUARDED_BY(mutex_);\n};\n\n// UpdateFrom function for LoadMetricProto - simple additive behavior\ninline void UpdateFrom(LoadMetricProto& dest, const LoadMetricProto& src) {\n  if (src.has_name()) dest.set_name(src.name());\n  dest.set_load_seconds(dest.load_seconds() + src.load_seconds());\n  dest.set_total_seconds(dest.total_seconds() + src.total_seconds());\n}\n\n// RAII class to temporarily pause load tracking.\nclass LoadMetricPauser {\n public:\n  explicit LoadMetricPauser(LoadMetricUpdater& updater) : updater_(updater) {\n    successfully_paused_ = updater_.LoadStop();\n  }\n\n  ~LoadMetricPauser() {\n    if (successfully_paused_ && should_resume_) {\n      updater_.LoadStart();\n    }\n  }\n\n  // Prevents the pauser from resuming load tracking in destructor.\n  void DoNotResume() { should_resume_ = false; }\n\n  LoadMetricPauser(const LoadMetricPauser&) = delete;\n  LoadMetricPauser& operator=(const LoadMetricPauser&) = delete;\n  LoadMetricPauser(LoadMetricPauser&&) = delete;\n  LoadMetricPauser& operator=(LoadMetricPauser&&) = delete;\n\n private:\n  LoadMetricUpdater& updater_;\n  bool successfully_paused_;\n  bool should_resume_ = true;\n};\n\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/metrics/load_metric_test.cc",
    "content": "#include \"utils/metrics/load_metric.h\"\n\n#include <gtest/gtest.h>\n\n#include <chrono>\n#include <memory>\n\n#include \"proto/training_metrics.pb.h\"\n#include \"utils/metrics/exponential_aggregator.h\"\n\nnamespace lczero {\nnamespace training {\n\nclass LoadMetricTest : public ::testing::Test {\n protected:\n  using Clock = LoadMetricUpdater::Clock;\n\n  void SetUp() override { start_time_ = Clock::now(); }\n\n  Clock::time_point start_time_;\n};\n\nTEST_F(LoadMetricTest, BasicLoadMetricProto) {\n  LoadMetricProto metric;\n  EXPECT_EQ(metric.load_seconds(), 0.0);\n  EXPECT_EQ(metric.total_seconds(), 0.0);\n\n  // Test UpdateFrom with LoadMetricUpdater\n  LoadMetricUpdater other_updater(start_time_);\n  other_updater.LoadStart(start_time_);\n  other_updater.LoadStop(start_time_ + std::chrono::milliseconds(500));\n  LoadMetricProto other =\n      other_updater.FlushMetrics(start_time_ + std::chrono::milliseconds(500));\n  UpdateFrom(metric, other);\n  EXPECT_NEAR(metric.load_seconds(), 0.5, 1e-6);\n  EXPECT_NEAR(metric.total_seconds(), 0.5, 1e-6);\n\n  // Test Clear (used to be Reset)\n  metric.Clear();\n  EXPECT_EQ(metric.load_seconds(), 0.0);\n  EXPECT_EQ(metric.total_seconds(), 0.0);\n}\n\nTEST_F(LoadMetricTest, LoadMetricUpdaterBasic) {\n  LoadMetricUpdater updater(start_time_);\n  auto now = start_time_;\n\n  // Start load and verify initial state\n  updater.LoadStart(now);\n  LoadMetricProto metric = updater.FlushMetrics(now);\n  EXPECT_EQ(metric.load_seconds(), 0.0);\n  EXPECT_EQ(metric.total_seconds(), 0.0);\n\n  // Advance time and stop load\n  now += std::chrono::milliseconds(100);\n  updater.LoadStop(now);\n  metric = updater.FlushMetrics(now);\n  EXPECT_NEAR(metric.load_seconds(), 0.1, 1e-6);\n  EXPECT_NEAR(metric.total_seconds(), 0.1, 1e-6);\n\n  // Wait idle time, then start again\n  now += std::chrono::milliseconds(50);\n  updater.LoadStart(now);\n  metric = updater.FlushMetrics(now);\n  EXPECT_NEAR(metric.load_seconds(), 0.0, 
1e-6);    // Reset after flush\n  EXPECT_NEAR(metric.total_seconds(), 0.05, 1e-6);  // Only idle time\n\n  now += std::chrono::milliseconds(200);\n  updater.LoadStop(now);\n  metric = updater.FlushMetrics(now);\n  EXPECT_NEAR(metric.load_seconds(), 0.2, 1e-6);   // 0.2 load\n  EXPECT_NEAR(metric.total_seconds(), 0.2, 1e-6);  // 0.2 total\n}\n\nTEST_F(LoadMetricTest, LoadMetricUpdaterFlush) {\n  LoadMetricUpdater updater(start_time_);\n  auto now = start_time_;\n\n  // Start load\n  updater.LoadStart(now);\n  now += std::chrono::milliseconds(100);\n\n  // Flush should update the internal metric\n  updater.Flush(now);\n  LoadMetricProto metric = updater.FlushMetrics(now);\n  EXPECT_NEAR(metric.load_seconds(), 0.1, 1e-6);\n  EXPECT_NEAR(metric.total_seconds(), 0.1, 1e-6);\n\n  // Continue loading\n  now += std::chrono::milliseconds(50);\n  updater.LoadStop(now);\n  metric = updater.FlushMetrics(now);\n  EXPECT_NEAR(metric.load_seconds(), 0.05, 1e-6);   // Only new load time\n  EXPECT_NEAR(metric.total_seconds(), 0.05, 1e-6);  // Only new total time\n}\n\nTEST_F(LoadMetricTest, LoadMetricProtoMerging) {\n  LoadMetricUpdater updater1(start_time_);\n  LoadMetricUpdater updater2(start_time_);\n  auto now = start_time_;\n\n  // Create load in updater1\n  updater1.LoadStart(now);\n  updater1.LoadStop(now + std::chrono::milliseconds(100));\n  LoadMetricProto metric1 =\n      updater1.FlushMetrics(now + std::chrono::milliseconds(100));\n  EXPECT_NEAR(metric1.load_seconds(), 0.1, 1e-6);\n  EXPECT_NEAR(metric1.total_seconds(), 0.1, 1e-6);\n\n  // Create load in updater2\n  updater2.LoadStart(now);\n  updater2.LoadStop(now + std::chrono::milliseconds(100));\n  LoadMetricProto metric2 =\n      updater2.FlushMetrics(now + std::chrono::milliseconds(100));\n  EXPECT_NEAR(metric2.load_seconds(), 0.1, 1e-6);\n  EXPECT_NEAR(metric2.total_seconds(), 0.1, 1e-6);\n\n  // Merge\n  UpdateFrom(metric1, metric2);\n  EXPECT_NEAR(metric1.load_seconds(), 0.2, 1e-6);\n  
EXPECT_NEAR(metric1.total_seconds(), 0.2, 1e-6);\n  EXPECT_NEAR(metric2.load_seconds(), 0.1, 1e-6);   // Source unchanged\n  EXPECT_NEAR(metric2.total_seconds(), 0.1, 1e-6);  // Source unchanged\n}\n\nTEST_F(LoadMetricTest, LoadMetricProtoMoveSemantics) {\n  // Test that LoadMetricProto move semantics work correctly\n  LoadMetricUpdater source_updater(start_time_);\n  source_updater.LoadStart(start_time_);\n  source_updater.LoadStop(start_time_ + std::chrono::milliseconds(100));\n  LoadMetricProto source =\n      source_updater.FlushMetrics(start_time_ + std::chrono::milliseconds(100));\n  EXPECT_NEAR(source.load_seconds(), 0.1, 1e-6);\n  EXPECT_NEAR(source.total_seconds(), 0.1, 1e-6);\n\n  // Test move construction\n  LoadMetricProto moved_constructed(std::move(source));\n  EXPECT_NEAR(moved_constructed.load_seconds(), 0.1, 1e-6);\n  EXPECT_NEAR(moved_constructed.total_seconds(), 0.1, 1e-6);\n\n  // Test move assignment\n  LoadMetricProto move_assigned;\n  LoadMetricUpdater another_updater(start_time_);\n  another_updater.LoadStart(start_time_);\n  another_updater.LoadStop(start_time_ + std::chrono::milliseconds(50));\n  LoadMetricProto another_source =\n      another_updater.FlushMetrics(start_time_ + std::chrono::milliseconds(50));\n  EXPECT_NEAR(another_source.load_seconds(), 0.05, 1e-6);\n  EXPECT_NEAR(another_source.total_seconds(), 0.05, 1e-6);\n\n  move_assigned = std::move(another_source);\n  EXPECT_NEAR(move_assigned.load_seconds(), 0.05, 1e-6);\n  EXPECT_NEAR(move_assigned.total_seconds(), 0.05, 1e-6);\n\n  // Test UpdateFrom\n  LoadMetricProto dest;\n  UpdateFrom(dest, moved_constructed);\n  EXPECT_NEAR(dest.load_seconds(), 0.1, 1e-6);\n  EXPECT_NEAR(dest.total_seconds(), 0.1, 1e-6);\n\n  UpdateFrom(dest, move_assigned);\n  EXPECT_NEAR(dest.load_seconds(), 0.15, 1e-6);\n  EXPECT_NEAR(dest.total_seconds(), 0.15, 1e-6);\n}\n\nTEST_F(LoadMetricTest, LoadUtilizationTracking) {\n  LoadMetricUpdater updater(start_time_);\n  auto now = start_time_;\n\n  // 
Start with some load time (load is active by default now)\n  now += std::chrono::milliseconds(100);\n  updater.LoadStop(now);  // Stop load to create idle time\n  LoadMetricProto metric = updater.FlushMetrics(now);\n  EXPECT_NEAR(metric.load_seconds(), 0.1, 1e-6);   // 100ms load\n  EXPECT_NEAR(metric.total_seconds(), 0.1, 1e-6);  // 100ms total\n\n  // Add some idle time (load is stopped)\n  now += std::chrono::milliseconds(100);\n  updater.LoadStart(now);  // Start load again\n  metric = updater.FlushMetrics(now);\n  EXPECT_NEAR(metric.load_seconds(), 0.0, 1e-6);   // No load time\n  EXPECT_NEAR(metric.total_seconds(), 0.1, 1e-6);  // 100ms idle\n\n  // Add more load time\n  now += std::chrono::milliseconds(200);\n  updater.LoadStop(now);\n  metric = updater.FlushMetrics(now);\n  EXPECT_NEAR(metric.load_seconds(), 0.2, 1e-6);   // 200ms load\n  EXPECT_NEAR(metric.total_seconds(), 0.2, 1e-6);  // 200ms total\n\n  // Test complete utilization tracking with one updater\n  LoadMetricUpdater total_updater(start_time_);\n  auto total_now = start_time_;\n\n  // 100ms load (active by default)\n  total_now += std::chrono::milliseconds(100);\n  total_updater.LoadStop(total_now);\n\n  // 100ms idle\n  total_now += std::chrono::milliseconds(100);\n  total_updater.LoadStart(total_now);\n\n  // 200ms load\n  total_now += std::chrono::milliseconds(200);\n  LoadMetricProto total_metric = total_updater.FlushMetrics(total_now);\n\n  // Calculate utilization\n  double utilization =\n      total_metric.load_seconds() / total_metric.total_seconds();\n  EXPECT_NEAR(utilization, 0.75,\n              1e-6);  // 75% utilization (300ms load / 400ms total)\n}\n\nclass LoadMetricProtoIntegrationTest : public ::testing::Test {\n protected:\n  using TestAggregator =\n      ExponentialAggregator<LoadMetricProto, TimePeriod::k16Milliseconds>;\n  using Clock = TestAggregator::Clock;\n\n  void SetUp() override {\n    aggregator_ = std::make_unique<TestAggregator>(\n        [](LoadMetricProto& m) 
{ m.Clear(); },\n        [](LoadMetricProto& dest, const LoadMetricProto& src) {\n          UpdateFrom(dest, src);\n        });\n    start_time_ = Clock::now();\n  }\n\n  std::unique_ptr<TestAggregator> aggregator_;\n  Clock::time_point start_time_;\n};\n\nTEST_F(LoadMetricProtoIntegrationTest, RecordMetricsWithUpdater) {\n  auto current_time = start_time_;\n\n  // Create metric with updater, simulate some load\n  LoadMetricUpdater updater(current_time);\n  updater.LoadStart(current_time);\n  current_time += std::chrono::milliseconds(150);\n\n  // Flush and get metric\n  LoadMetricProto metric = updater.FlushMetrics(current_time);\n\n  // Record the metric (this should use UpdateFrom + Reset)\n  aggregator_->RecordMetrics(std::move(metric));\n\n  // Get live metrics\n  auto [live_metrics, age] = aggregator_->GetAggregateEndingNow(\n      TestAggregator::Duration::zero(), current_time);\n  EXPECT_NEAR(live_metrics.load_seconds(), 0.15, 1e-6);\n}\n\nTEST_F(LoadMetricProtoIntegrationTest, MultipleRecordMetrics) {\n  auto current_time = start_time_;\n\n  // First metric\n  LoadMetricUpdater updater1(current_time);\n  updater1.LoadStart(current_time);\n  current_time += std::chrono::milliseconds(100);\n  LoadMetricProto metric1 = updater1.FlushMetrics(current_time);\n  aggregator_->RecordMetrics(std::move(metric1));\n\n  // Second metric\n  LoadMetricUpdater updater2(current_time);\n  updater2.LoadStart(current_time);\n  current_time += std::chrono::milliseconds(75);\n  LoadMetricProto metric2 = updater2.FlushMetrics(current_time);\n  aggregator_->RecordMetrics(std::move(metric2));\n\n  // Get live metrics\n  auto [live_metrics, age] = aggregator_->GetAggregateEndingNow(\n      TestAggregator::Duration::zero(), current_time);\n  EXPECT_NEAR(live_metrics.load_seconds(), 0.175, 1e-6);  // 0.1 + 0.075\n}\n\nTEST_F(LoadMetricProtoIntegrationTest, AdvanceTest) {\n  auto current_time = start_time_;\n\n  // Add some metrics\n  LoadMetricUpdater updater(current_time);\n  
updater.LoadStart(current_time);\n  updater.LoadStop(current_time + std::chrono::milliseconds(100));\n  LoadMetricProto metric =\n      updater.FlushMetrics(current_time + std::chrono::milliseconds(100));\n  aggregator_->RecordMetrics(std::move(metric));\n\n  // Advance to move live metrics to buckets\n  auto tick_time = start_time_ + aggregator_->GetResolution();\n  auto period = aggregator_->Advance(tick_time);\n\n  // Should return the base time period\n  EXPECT_EQ(period, TimePeriod::k16Milliseconds);\n\n  // Live metrics should be empty after advance\n  auto [live_metrics, age] = aggregator_->GetAggregateEndingNow(\n      TestAggregator::Duration::zero(), tick_time);\n  EXPECT_EQ(live_metrics.load_seconds(), 0.0);\n}\n\nTEST_F(LoadMetricTest, LoadStartStopReturnValues) {\n  LoadMetricUpdater updater(start_time_);\n\n  // Load starts active, so LoadStart should return false (already active)\n  EXPECT_FALSE(updater.LoadStart());\n\n  // LoadStop should return true (was active)\n  EXPECT_TRUE(updater.LoadStop());\n\n  // LoadStop again should return false (already stopped)\n  EXPECT_FALSE(updater.LoadStop());\n\n  // LoadStart should return true (was stopped)\n  EXPECT_TRUE(updater.LoadStart());\n\n  // LoadStart again should return false (already active)\n  EXPECT_FALSE(updater.LoadStart());\n}\n\n}  // namespace training\n}  // namespace lczero\n\nint main(int argc, char** argv) {\n  ::testing::InitGoogleTest(&argc, argv);\n  return RUN_ALL_TESTS();\n}"
  },
  {
    "path": "csrc/utils/metrics/printer.h",
    "content": "#pragma once\n\n#include <absl/strings/str_cat.h>\n\n#include <string>\n#include <string_view>\n\nnamespace lczero {\n\nclass MetricPrinter {\n public:\n  virtual ~MetricPrinter() = default;\n  virtual void StartGroup(std::string_view group_name) = 0;\n  virtual void Print(std::string_view metric_name,\n                     const absl::AlphaNum& value) = 0;\n  virtual void EndGroup() = 0;\n};\n\nclass StringMetricPrinter : public MetricPrinter {\n public:\n  StringMetricPrinter(std::string* output) : output_(output) {}\n  void StartGroup(std::string_view group_name) override {\n    if (!first_group_) absl::StrAppend(output_, \", \");\n    absl::StrAppend(output_, group_name, \"={\");\n    first_group_ = false;\n    first_metric_ = true;\n  }\n\n  void Print(std::string_view metric_name,\n             const absl::AlphaNum& value) override {\n    if (!first_metric_) absl::StrAppend(output_, \", \");\n    absl::StrAppend(output_, metric_name, \"=\", value);\n    first_metric_ = false;\n    first_group_ = false;\n  }\n\n  void EndGroup() override { absl::StrAppend(output_, \"}\"); }\n\n private:\n  std::string* output_;\n  bool first_metric_ = true;\n  bool first_group_ = true;\n};\n\ntemplate <typename T>\nstd::string MetricToString(const T& metric) {\n  std::string result;\n  StringMetricPrinter printer(&result);\n  metric.Print(printer);\n  return result;\n}\n\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/metrics/statistics_metric.h",
    "content": "#pragma once\n\n#include <algorithm>\n\n#include \"proto/training_metrics.pb.h\"\n\nnamespace lczero {\n\n// Helper function to add a sample to StatisticsProtoInt64\ninline void AddSample(StatisticsProtoInt64& stats, int64_t value) {\n  stats.set_min(std::min(stats.min(), value));\n  stats.set_max(std::max(stats.max(), value));\n  stats.set_sum(stats.sum() + value);\n  stats.set_count(stats.count() + 1);\n  stats.set_latest(value);\n}\n\n// Helper function to add a sample to StatisticsProtoDouble\ninline void AddSample(StatisticsProtoDouble& stats, double value) {\n  stats.set_min(std::min(stats.min(), value));\n  stats.set_max(std::max(stats.max(), value));\n  stats.set_sum(stats.sum() + value);\n  stats.set_count(stats.count() + 1);\n  stats.set_latest(value);\n}\n\n// UpdateFrom function for StatisticsProtoInt64 - merges statistics\ninline void UpdateFrom(StatisticsProtoInt64& dest,\n                       const StatisticsProtoInt64& src) {\n  if (src.count() == 0) return;  // Nothing to merge from empty source\n\n  dest.set_min(std::min(dest.min(), src.min()));\n  dest.set_max(std::max(dest.max(), src.max()));\n  dest.set_sum(dest.sum() + src.sum());\n  dest.set_count(dest.count() + src.count());\n  dest.set_latest(src.latest());  // Source is newer, use its latest value\n}\n\n// UpdateFrom function for StatisticsProtoDouble - merges statistics\ninline void UpdateFrom(StatisticsProtoDouble& dest,\n                       const StatisticsProtoDouble& src) {\n  if (src.count() == 0) return;  // Nothing to merge from empty source\n\n  if (src.has_name()) dest.set_name(src.name());\n  dest.set_min(std::min(dest.min(), src.min()));\n  dest.set_max(std::max(dest.max(), src.max()));\n  dest.set_sum(dest.sum() + src.sum());\n  dest.set_count(dest.count() + src.count());\n  dest.set_latest(src.latest());  // Source is newer, use its latest value\n}\n\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/metrics/stats_test.cc",
    "content": "#include <gtest/gtest.h>\n\n#include <chrono>\n#include <memory>\n#include <optional>\n#include <thread>\n\n#include \"utils/metrics/exponential_aggregator.h\"\n#include \"utils/metrics/group.h\"\n#include \"utils/metrics/printer.h\"\n\nnamespace lczero {\n\nclass CounterMetric {\n public:\n  CounterMetric() : count_(0) {}\n  CounterMetric(int count) : count_(count) {}\n\n  void Reset() { count_ = 0; }\n\n  void MergeFrom(const CounterMetric& other) { count_ += other.count_; }\n\n  void Print(MetricPrinter& printer) const {\n    printer.StartGroup(\"CounterMetric\");\n    printer.Print(\"count\", static_cast<size_t>(count_));\n    printer.EndGroup();\n  }\n\n  int count() const { return count_; }\n  void set_count(int count) { count_ = count; }\n\n private:\n  int count_;\n};\n\nclass AverageMetric {\n public:\n  AverageMetric() : sum_(0), count_(0) {}\n  AverageMetric(double sum, int count) : sum_(sum), count_(count) {}\n\n  void Reset() {\n    sum_ = 0;\n    count_ = 0;\n  }\n\n  void MergeFrom(const AverageMetric& other) {\n    sum_ += other.sum_;\n    count_ += other.count_;\n  }\n\n  void Print(MetricPrinter& printer) const {\n    printer.StartGroup(\"AverageMetric\");\n    printer.Print(\"sum\", sum_);\n    printer.Print(\"count\", static_cast<size_t>(count_));\n    if (count_ > 0) {\n      printer.Print(\"average\", sum_ / count_);\n    }\n    printer.EndGroup();\n  }\n\n  double average() const { return count_ > 0 ? 
sum_ / count_ : 0.0; }\n  void add_sample(double value) {\n    sum_ += value;\n    count_++;\n  }\n\n  double sum() const { return sum_; }\n  int count() const { return count_; }\n\n private:\n  double sum_;\n  int count_;\n};\n\nclass MaxMetric {\n public:\n  MaxMetric() : max_value_(0), has_value_(false) {}\n  MaxMetric(double max_value) : max_value_(max_value), has_value_(true) {}\n\n  void Reset() {\n    max_value_ = 0;\n    has_value_ = false;\n  }\n\n  void MergeFrom(const MaxMetric& other) {\n    if (other.has_value_) {\n      if (!has_value_ || other.max_value_ > max_value_) {\n        max_value_ = other.max_value_;\n        has_value_ = true;\n      }\n    }\n  }\n\n  void Print(MetricPrinter& printer) const {\n    printer.StartGroup(\"MaxMetric\");\n    if (has_value_) {\n      printer.Print(\"max_value\", max_value_);\n      printer.Print(\"has_value\", static_cast<size_t>(1));\n    } else {\n      printer.Print(\"has_value\", static_cast<size_t>(0));\n    }\n    printer.EndGroup();\n  }\n\n  double max_value() const { return max_value_; }\n  bool has_value() const { return has_value_; }\n  void set_value(double value) {\n    if (!has_value_ || value > max_value_) {\n      max_value_ = value;\n      has_value_ = true;\n    }\n  }\n\n private:\n  double max_value_;\n  bool has_value_;\n};\n\n// Optional value metric that demonstrates overshadowing behavior\nclass OptionalValueMetric {\n public:\n  OptionalValueMetric() : value_(std::nullopt) {}\n  OptionalValueMetric(int value) : value_(value) {}\n\n  void Reset() { value_ = std::nullopt; }\n\n  void MergeFrom(const OptionalValueMetric& other) {\n    // Only copy the value if the other metric has one (overshadowing behavior)\n    if (other.value_.has_value()) {\n      value_ = other.value_;\n    }\n  }\n\n  void Print(MetricPrinter& printer) const {\n    printer.StartGroup(\"OptionalValueMetric\");\n    if (value_.has_value()) {\n      printer.Print(\"value\", value_.value());\n      
printer.Print(\"has_value\", static_cast<size_t>(1));\n    } else {\n      printer.Print(\"has_value\", static_cast<size_t>(0));\n    }\n    printer.EndGroup();\n  }\n\n  std::optional<int> value() const { return value_; }\n  bool has_value() const { return value_.has_value(); }\n  void set_value(int value) { value_ = value; }\n\n private:\n  std::optional<int> value_;\n};\n\n// Test MetricGroup functionality\nclass MetricGroupTest : public ::testing::Test {\n protected:\n  using TestGroup = MetricGroup<CounterMetric, AverageMetric, MaxMetric>;\n  TestGroup group_;\n};\n\nTEST_F(MetricGroupTest, InitialState) {\n  // Test that metrics are initialized in their default state\n  EXPECT_EQ(group_.Get<CounterMetric>().count(), 0);\n  EXPECT_EQ(group_.Get<AverageMetric>().count(), 0);\n  EXPECT_FALSE(group_.Get<MaxMetric>().has_value());\n}\n\nTEST_F(MetricGroupTest, GetMutable) {\n  // Test getting mutable references and modifying them\n  auto* counter = group_.GetMutable<CounterMetric>();\n  counter->set_count(42);\n  EXPECT_EQ(group_.Get<CounterMetric>().count(), 42);\n\n  auto* average = group_.GetMutable<AverageMetric>();\n  average->add_sample(10.0);\n  average->add_sample(20.0);\n  EXPECT_EQ(group_.Get<AverageMetric>().average(), 15.0);\n\n  auto* max_metric = group_.GetMutable<MaxMetric>();\n  max_metric->set_value(100.0);\n  EXPECT_EQ(group_.Get<MaxMetric>().max_value(), 100.0);\n}\n\nTEST_F(MetricGroupTest, Reset) {\n  // Set up some data\n  group_.GetMutable<CounterMetric>()->set_count(42);\n  group_.GetMutable<AverageMetric>()->add_sample(10.0);\n  group_.GetMutable<MaxMetric>()->set_value(100.0);\n\n  // Reset and verify everything is back to initial state\n  group_.Reset();\n\n  EXPECT_EQ(group_.Get<CounterMetric>().count(), 0);\n  EXPECT_EQ(group_.Get<AverageMetric>().count(), 0);\n  EXPECT_FALSE(group_.Get<MaxMetric>().has_value());\n}\n\nTEST_F(MetricGroupTest, MergeFromGroup) {\n  // Set up source group\n  TestGroup other;\n  
other.GetMutable<CounterMetric>()->set_count(10);\n  other.GetMutable<AverageMetric>()->add_sample(5.0);\n  other.GetMutable<MaxMetric>()->set_value(50.0);\n\n  // Set up destination group\n  group_.GetMutable<CounterMetric>()->set_count(20);\n  group_.GetMutable<AverageMetric>()->add_sample(15.0);\n  group_.GetMutable<MaxMetric>()->set_value(30.0);\n\n  // Merge\n  group_.MergeFrom(other);\n\n  // Verify results\n  EXPECT_EQ(group_.Get<CounterMetric>().count(), 30);      // 20 + 10\n  EXPECT_EQ(group_.Get<AverageMetric>().average(), 10.0);  // (15 + 5) / 2\n  EXPECT_EQ(group_.Get<MaxMetric>().max_value(), 50.0);    // max(30, 50)\n}\n\nTEST_F(MetricGroupTest, MergeFromSingleMetric) {\n  // Set up initial state\n  group_.GetMutable<CounterMetric>()->set_count(20);\n\n  // Create a single metric to merge\n  CounterMetric counter(15);\n\n  // Merge single metric\n  group_.MergeFrom(counter);\n\n  // Verify result\n  EXPECT_EQ(group_.Get<CounterMetric>().count(), 35);  // 20 + 15\n}\n\nTEST_F(MetricGroupTest, Print) {\n  // Set up data\n  group_.GetMutable<CounterMetric>()->set_count(42);\n  group_.GetMutable<AverageMetric>()->add_sample(10.0);\n  group_.GetMutable<AverageMetric>()->add_sample(20.0);\n  group_.GetMutable<MaxMetric>()->set_value(100.0);\n\n  std::string result = MetricToString(group_);\n\n  // Should contain all metric names and values\n  EXPECT_NE(result.find(\"CounterMetric\"), std::string::npos);\n  EXPECT_NE(result.find(\"count=42\"), std::string::npos);\n  EXPECT_NE(result.find(\"AverageMetric\"), std::string::npos);\n  EXPECT_NE(result.find(\"average=15\"), std::string::npos);  // (10+20)/2\n  EXPECT_NE(result.find(\"MaxMetric\"), std::string::npos);\n  EXPECT_NE(result.find(\"max_value=100\"), std::string::npos);\n}\n\n// Test MetricToString functionality\nclass MetricPrinterTest : public ::testing::Test {};\n\nTEST_F(MetricPrinterTest, StringMetricPrinter) {\n  std::string output;\n  StringMetricPrinter printer(&output);\n\n  
printer.StartGroup(\"test_group\");\n  printer.Print(\"metric1\", std::string(\"value1\"));\n  printer.Print(\"metric2\", std::string(\"42\"));\n  printer.EndGroup();\n\n  EXPECT_EQ(output, \"test_group={metric1=value1, metric2=42}\");\n}\n\nTEST_F(MetricPrinterTest, MultipleGroups) {\n  std::string output;\n  StringMetricPrinter printer(&output);\n\n  printer.StartGroup(\"group1\");\n  printer.Print(\"metric1\", std::string(\"value1\"));\n  printer.EndGroup();\n\n  printer.StartGroup(\"group2\");\n  printer.Print(\"metric2\", std::string(\"value2\"));\n  printer.EndGroup();\n\n  EXPECT_EQ(output, \"group1={metric1=value1}, group2={metric2=value2}\");\n}\n\nTEST_F(MetricPrinterTest, EmptyGroup) {\n  std::string output;\n  StringMetricPrinter printer(&output);\n\n  printer.StartGroup(\"empty_group\");\n  printer.EndGroup();\n\n  EXPECT_EQ(output, \"empty_group={}\");\n}\n\nTEST_F(MetricPrinterTest, SizeTOverload) {\n  std::string output;\n  StringMetricPrinter string_printer(&output);\n  MetricPrinter& printer = string_printer;  // Use base class interface\n\n  printer.StartGroup(\"test_group\");\n  printer.Print(\"count\", static_cast<size_t>(123));\n  printer.EndGroup();\n\n  EXPECT_EQ(output, \"test_group={count=123}\");\n}\n\nTEST_F(MetricPrinterTest, MetricToStringFunction) {\n  CounterMetric counter(123);\n  std::string result = MetricToString(counter);\n\n  EXPECT_NE(result.find(\"CounterMetric\"), std::string::npos);\n  EXPECT_NE(result.find(\"123\"), std::string::npos);\n}\n\nclass ExponentialAggregatorTest : public ::testing::Test {\n protected:\n  using TestMetric =\n      MetricGroup<CounterMetric, AverageMetric, OptionalValueMetric>;\n  using TestAggregator =\n      ExponentialAggregator<TestMetric, TimePeriod::k16Milliseconds>;\n\n  void SetUp() override {\n    // Create a fresh aggregator for each test to avoid state contamination\n    aggregator_ = std::make_unique<TestAggregator>();\n    start_time_ = TestAggregator::Clock::now();\n    
aggregator_->Reset(start_time_);\n  }\n\n  std::unique_ptr<TestAggregator> aggregator_;\n  TestAggregator::Clock::time_point start_time_;\n};\n\nTEST_F(ExponentialAggregatorTest, RecordMetrics) {\n  TestMetric metric;\n  metric.GetMutable<CounterMetric>()->set_count(10);\n  metric.GetMutable<AverageMetric>()->add_sample(5.0);\n\n  // Update live metrics\n  aggregator_->RecordMetrics(std::move(metric));\n\n  // The original metric should be reset after move\n  EXPECT_EQ(metric.Get<CounterMetric>().count(), 0);\n  EXPECT_EQ(metric.Get<AverageMetric>().count(), 0);\n\n  // Get live metrics to verify they were updated\n  auto [live_metrics, age] =\n      aggregator_->GetAggregateEndingNow(TestAggregator::Duration::zero());\n  EXPECT_EQ(live_metrics.Get<CounterMetric>().count(), 10);\n  EXPECT_EQ(live_metrics.Get<AverageMetric>().average(), 5.0);\n}\n\nTEST_F(ExponentialAggregatorTest, MultipleUpdatesLiveMetrics) {\n  // Update multiple times\n  for (int i = 1; i <= 5; ++i) {\n    TestMetric metric;\n    metric.GetMutable<CounterMetric>()->set_count(i);\n    metric.GetMutable<AverageMetric>()->add_sample(i * 2.0);\n    aggregator_->RecordMetrics(std::move(metric));\n  }\n\n  // Get live metrics\n  auto [live_metrics, age] =\n      aggregator_->GetAggregateEndingNow(TestAggregator::Duration::zero());\n  EXPECT_EQ(live_metrics.Get<CounterMetric>().count(), 15);  // 1+2+3+4+5\n  EXPECT_EQ(live_metrics.Get<AverageMetric>().average(),\n            6.0);  // (2+4+6+8+10)/5\n}\n\nTEST_F(ExponentialAggregatorTest, Advance) {\n  // Add some live metrics\n  TestMetric metric;\n  metric.GetMutable<CounterMetric>()->set_count(10);\n  aggregator_->RecordMetrics(std::move(metric));\n\n  // Advance to move live metrics to buckets\n  auto tick_time = start_time_ + aggregator_->GetResolution();\n  auto period = aggregator_->Advance(tick_time);\n\n  // Should return the base time period\n  EXPECT_EQ(period, TimePeriod::k16Milliseconds);\n\n  // Live metrics should be empty after tick\n  
auto [live_metrics, age] = aggregator_->GetAggregateEndingNow(\n      TestAggregator::Duration::zero(), tick_time);\n  EXPECT_EQ(live_metrics.Get<CounterMetric>().count(), 0);\n}\n\nTEST_F(ExponentialAggregatorTest, MultipleAdvances) {\n  // Add metrics and tick multiple times to test bucket management\n  auto current_time = start_time_;\n  for (const auto expected_period : {\n           TimePeriod::k16Milliseconds,\n           TimePeriod::k31Milliseconds,\n           TimePeriod::k16Milliseconds,\n           TimePeriod::k63Milliseconds,\n           TimePeriod::k16Milliseconds,\n           TimePeriod::k31Milliseconds,\n           TimePeriod::k16Milliseconds,\n           TimePeriod::k125Milliseconds,\n       }) {\n    TestMetric metric;\n    metric.GetMutable<CounterMetric>()->set_count(1);\n    aggregator_->RecordMetrics(std::move(metric));\n\n    current_time += aggregator_->GetResolution();\n    auto period = aggregator_->Advance(current_time);\n\n    EXPECT_EQ(period, expected_period);\n  }\n}\n\nTEST_F(ExponentialAggregatorTest, MultipleAdvancesThreeTicks) {\n  // Add metrics and tick multiple times to test bucket management\n  auto current_time = start_time_;\n  for (const auto expected_period : {\n           TimePeriod::k31Milliseconds,\n           TimePeriod::k63Milliseconds,\n           TimePeriod::k125Milliseconds,\n           TimePeriod::k63Milliseconds,\n       }) {\n    TestMetric metric;\n    metric.GetMutable<CounterMetric>()->set_count(1);\n    aggregator_->RecordMetrics(std::move(metric));\n\n    current_time += aggregator_->GetResolution() * 3;\n    auto period = aggregator_->Advance(current_time);\n\n    EXPECT_EQ(period, expected_period);\n  }\n}\n\nTEST_F(ExponentialAggregatorTest, AggregationTest) {\n  TestMetric metric;\n  // We do 37 (0b100101) updates.\n  for (int i = 0; i < 37; ++i) {\n    metric.GetMutable<CounterMetric>()->set_count(i + 200);\n    metric.GetMutable<OptionalValueMetric>()->set_value(i + 100);\n    
aggregator_->RecordMetrics(std::move(metric));\n    start_time_ += aggregator_->GetResolution();\n    aggregator_->Advance(start_time_);\n  }\n\n  // One more tick, but we don't advance this time.\n  metric.GetMutable<CounterMetric>()->set_count(1001);\n  metric.GetMutable<OptionalValueMetric>()->set_value(1002);\n  aggregator_->RecordMetrics(std::move(metric));\n\n  const auto kRes = aggregator_->GetResolution();\n  start_time_ += kRes / 3;  // Not a tick yet.\n\n  auto check_completed_bucket = [&](TimePeriod period, int expected_count,\n                                    std::optional<int> expected_value,\n                                    TestAggregator::Duration expected_duration,\n                                    bool include_pending = false) {\n    auto [live_metrics, age] = aggregator_->GetBucketMetrics(\n        period,\n        include_pending ? std::make_optional(start_time_) : std::nullopt);\n    EXPECT_EQ(live_metrics.Get<CounterMetric>().count(), expected_count);\n    EXPECT_EQ(live_metrics.Get<OptionalValueMetric>().has_value(),\n              expected_value.has_value());\n    if (expected_value.has_value()) {\n      EXPECT_EQ(live_metrics.Get<OptionalValueMetric>().value(),\n                expected_value.value());\n    }\n    EXPECT_EQ(age, expected_duration);\n  };\n\n  check_completed_bucket(TimePeriod::k16Milliseconds, 236, 136, kRes * 0);\n  check_completed_bucket(TimePeriod::k31Milliseconds, 234 + 235, 135, kRes);\n  check_completed_bucket(TimePeriod::k63Milliseconds, 232 + 233 + 234 + 235,\n                         135, kRes);\n  check_completed_bucket(TimePeriod::k125Milliseconds,\n                         224 + 225 + 226 + 227 + 228 + 229 + 230 + 231, 131,\n                         kRes * 5);\n  check_completed_bucket(TimePeriod::k125Milliseconds,\n                         224 + 225 + 226 + 227 + 228 + 229 + 230 + 231, 131,\n                         kRes * 5 + kRes / 3, true);\n  check_completed_bucket(TimePeriod::k250Milliseconds,\n   
                      216 + 217 + 218 + 219 + 220 + 221 + 222 + 223 + 224 +\n                             225 + 226 + 227 + 228 + 229 + 230 + 231,\n                         131, kRes * 5);\n  check_completed_bucket(TimePeriod::k500Milliseconds,\n                         200 + 201 + 202 + 203 + 204 + 205 + 206 + 207 + 208 +\n                             209 + 210 + 211 + 212 + 213 + 214 + 215 + 216 +\n                             217 + 218 + 219 + 220 + 221 + 222 + 223 + 224 +\n                             225 + 226 + 227 + 228 + 229 + 230 + 231,\n                         131, kRes * 5);\n  check_completed_bucket(TimePeriod::k1Second, 0, std::nullopt, kRes * 37);\n\n  auto check_aggregate = [&](TestAggregator::Duration duration,\n                             int expected_count,\n                             std::optional<int> expected_value,\n                             TestAggregator::Duration expected_duration,\n                             bool include_pending = false) {\n    auto [live_metrics, age] = aggregator_->GetAggregateEndingNow(\n        duration,\n        include_pending ? 
std::make_optional(start_time_) : std::nullopt);\n    EXPECT_EQ(live_metrics.Get<CounterMetric>().count(), expected_count);\n    EXPECT_EQ(live_metrics.Get<OptionalValueMetric>().has_value(),\n              expected_value.has_value());\n    if (expected_value.has_value()) {\n      EXPECT_EQ(live_metrics.Get<OptionalValueMetric>().value(),\n                expected_value.value());\n    }\n    EXPECT_EQ(age, expected_duration);\n  };\n\n  check_aggregate(TestAggregator::Duration::zero(), 0, std::nullopt, kRes * 0);\n  check_aggregate(TestAggregator::Duration::zero(), 1001, 1002, kRes / 3, true);\n  check_aggregate(kRes / 4, 1001, 1002, kRes / 3, true);\n  check_aggregate(kRes / 10, 236, 136, kRes);\n  check_aggregate(kRes * 45 / 10, 232 + 233 + 234 + 235 + 236, 136, kRes * 5);\n  check_aggregate(kRes * 55 / 10,\n                  224 + 225 + 226 + 227 + 228 + 229 + 230 + 231 + 232 + 233 +\n                      234 + 235 + 236,\n                  136, kRes * (5 + 8));\n}\n\nTEST_F(ExponentialAggregatorTest, ActualVsRequestedTimeCoverage) {\n  // Test that GetAggregateEndingNow returns actual time covered by statistics\n  // rather than requested duration when insufficient historical data exists.\n  // This test recreates the scenario from the existing AggregationTest but\n  // specifically tests the requested vs actual duration behavior.\n\n  const auto kRes = aggregator_->GetResolution();\n\n  // Set up aggregator with several data points like in AggregationTest\n  TestMetric metric;\n  for (int i = 0; i < 10; ++i) {\n    metric.GetMutable<CounterMetric>()->set_count(i + 200);\n    aggregator_->RecordMetrics(std::move(metric));\n    start_time_ += kRes;\n    aggregator_->Advance(start_time_);\n  }\n\n  // Based on AggregationTest line 507: request 4.5 * kRes, get back 5 * kRes\n  // This demonstrates that actual coverage (5 * kRes) can be MORE than\n  // requested (4.5 * kRes) because the aggregator only has specific bucket\n  // sizes available\n  const auto 
requested_duration = kRes * 45 / 10;  // 4.5 * kRes\n  auto [result_metrics, actual_duration] =\n      aggregator_->GetAggregateEndingNow(requested_duration, std::nullopt);\n\n  // The key test: when requesting 4.5 * kRes, we should get actual time covered\n  // which may be different than the requested amount due to bucket granularity\n  EXPECT_GT(actual_duration, requested_duration);  // Actual > requested\n  EXPECT_GT(actual_duration, std::chrono::nanoseconds::zero());\n\n  // Verify we got some metrics (non-zero count)\n  EXPECT_GT(result_metrics.Get<CounterMetric>().count(), 0);\n\n  // Test that shows the key behavior: when we request more time than available,\n  // we get back only the time that's actually covered by data\n  auto [result_zero, duration_zero] = aggregator_->GetAggregateEndingNow(\n      kRes * 100, std::nullopt);  // Request way more\n\n  // The returned duration should be much less than requested (showing actual vs\n  // requested)\n  const auto huge_request = kRes * 100;\n  EXPECT_LT(duration_zero, huge_request);\n  EXPECT_GT(duration_zero, std::chrono::nanoseconds::zero());\n}\n\nTEST_F(ExponentialAggregatorTest, ExactDurationTest) {\n  // Simple test: add buckets for exactly 5 seconds, request kAllTime,\n  // ensure we get back exactly 5.0 seconds duration (not more)\n\n  auto current_time = start_time_;\n\n  // Add buckets for exactly 5 seconds\n  for (int i = 0; i < 5; ++i) {\n    TestMetric metric;\n    metric.GetMutable<CounterMetric>()->set_count(100 + i);\n    aggregator_->RecordMetrics(std::move(metric));\n    current_time += std::chrono::seconds(1);\n    aggregator_->Advance(current_time);\n  }\n\n  // Request statistics for all time\n  auto [result_metrics, actual_duration] = aggregator_->GetAggregateEndingNow(\n      std::chrono::duration_cast<TestAggregator::Duration>(\n          std::chrono::hours(24 * 365)),  // Request way more than 5 seconds\n      std::nullopt);\n\n  // Should return exactly 5.0 seconds duration (actual 
time covered)\n  const auto expected_duration = std::chrono::seconds(5);\n  EXPECT_EQ(actual_duration, expected_duration);\n\n  // Should have all our data.\n  const int expected_total = 100 + 101 + 102 + 103 + 104;\n  EXPECT_EQ(result_metrics.Get<CounterMetric>().count(), expected_total);\n}\n\n}  // namespace lczero\n\nint main(int argc, char** argv) {\n  ::testing::InitGoogleTest(&argc, argv);\n  return RUN_ALL_TESTS();\n}"
  },
  {
    "path": "csrc/utils/queue.h",
    "content": "#pragma once\n\n#include <algorithm>\n#include <optional>\n#include <stdexcept>\n#include <stop_token>\n#include <utility>\n\n#include \"absl/base/thread_annotations.h\"\n#include \"absl/container/fixed_array.h\"\n#include \"absl/log/log.h\"\n#include \"absl/synchronization/mutex.h\"\n#include \"absl/types/span.h\"\n\nnamespace lczero {\n\n// Virtual base class for type-erased handling of queues.\nclass QueueBase {\n public:\n  virtual ~QueueBase() = default;\n  virtual size_t Size() const = 0;\n  virtual size_t Capacity() const = 0;\n  virtual bool IsClosed() const = 0;\n  virtual void Close() = 0;\n  virtual size_t GetTotalPutCount(bool reset = false) = 0;\n  virtual size_t GetTotalGetCount(bool reset = false) = 0;\n  virtual size_t GetTotalDropCount(bool reset = false) = 0;\n};\n\n// Exception thrown when queue operations are attempted on a closed queue.\nclass QueueClosedException : public std::runtime_error {\n public:\n  QueueClosedException() : std::runtime_error(\"Queue is closed\") {}\n};\n\n// Exception thrown when queue operation is cancelled via stop_token.\nclass QueueRequestCancelled : public std::runtime_error {\n public:\n  QueueRequestCancelled() : std::runtime_error(\"Queue request cancelled\") {}\n};\n\nenum class OverflowBehavior { BLOCK, DROP_NEW, KEEP_NEWEST };\n\n// Thread-safe fixed-size circular buffer queue with blocking operations.\n// Supports both single and batch put/get operations.\n// The queue automatically closes when all Producer tokens are destroyed.\n// When closed, Put operations throw immediately, but Get operations only throw\n// when the queue becomes empty - allowing consumption of remaining elements.\ntemplate <typename T>\nclass Queue : public QueueBase {\n public:\n  // Backwards-compatible alias to support code referring to\n  // Queue<T>::OverflowBehavior.\n  using OverflowBehavior = ::lczero::OverflowBehavior;\n\n  explicit Queue(size_t capacity,\n                 OverflowBehavior overflow_behavior = 
OverflowBehavior::BLOCK);\n\n  // RAII token for producers. Queue automatically closes when all producers\n  // are destroyed. All Put operations must go through this class.\n  class Producer {\n   public:\n    explicit Producer(Queue<T>& queue);\n    ~Producer();\n\n    // Move constructor and assignment\n    Producer(Producer&& other) noexcept;\n    Producer& operator=(Producer&& other) noexcept;\n\n    // Disable copy to maintain RAII semantics\n    Producer(const Producer&) = delete;\n    Producer& operator=(const Producer&) = delete;\n\n    // Puts a single element into the queue. Blocks if queue is full.\n    void Put(const T& item, std::stop_token stop_token = {});\n    void Put(T&& item, std::stop_token stop_token = {});\n\n    // Puts multiple elements into the queue. Blocks if not enough space.\n    void Put(absl::Span<const T> items, std::stop_token stop_token = {});\n    void Put(absl::Span<T> items, std::stop_token stop_token = {});\n\n    // Explicitly close this producer, decrementing the producer count\n    void Close();\n\n   private:\n    Queue<T>* queue_;\n  };\n\n  // Creates a new producer token for this queue.\n  Producer CreateProducer();\n\n  // Gets a single element from the queue. Blocks if queue is empty.\n  T Get(std::stop_token stop_token = {});\n\n  // Gets exactly count elements from the queue. 
Blocks until count elements\n  // are available.\n  absl::FixedArray<T> Get(size_t count, std::stop_token stop_token = {});\n\n  // Gets a single element from the queue if available; returns std::nullopt\n  // if empty.\n  std::optional<T> MaybeGet();\n\n  // Returns the current size of the queue.\n  size_t Size() const override;\n\n  // Returns the capacity of the queue.\n  size_t Capacity() const override;\n\n  // Explicitly close the queue, preventing further Put operations.\n  void Close() override;\n\n  // Returns true if the queue is closed.\n  bool IsClosed() const override;\n\n  // Wait until queue has at least the specified amount of free space.\n  void WaitForRoomAtLeast(size_t room, std::stop_token stop_token = {});\n\n  // Wait until queue has at most the specified amount of free space.\n  void WaitForRoomAtMost(size_t room, std::stop_token stop_token = {});\n\n  // Wait until queue has at least the specified number of elements.\n  void WaitForSizeAtLeast(size_t size, std::stop_token stop_token = {});\n\n  // Wait until queue has at most the specified number of elements.\n  void WaitForSizeAtMost(size_t size, std::stop_token stop_token = {});\n\n  // Returns the total number of elements that have been put into the queue.\n  // If reset is true, resets the counter to 0 after returning the value.\n  size_t GetTotalPutCount(bool reset = false) override;\n\n  // Returns the total number of elements that have been retrieved from the\n  // queue. 
If reset is true, resets the counter to 0 after returning the value.\n  size_t GetTotalGetCount(bool reset = false) override;\n\n  // Returns the total number of elements that have been dropped from the queue.\n  // If reset is true, resets the counter to 0 after returning the value.\n  size_t GetTotalDropCount(bool reset = false) override;\n\n private:\n  friend class Producer;\n\n  const size_t capacity_;\n  const OverflowBehavior overflow_behavior_;\n  absl::FixedArray<T> buffer_ ABSL_GUARDED_BY(mutex_);\n  size_t head_ ABSL_GUARDED_BY(mutex_) = 0;\n  size_t tail_ ABSL_GUARDED_BY(mutex_) = 0;\n  size_t size_ ABSL_GUARDED_BY(mutex_) = 0;\n  size_t producer_count_ ABSL_GUARDED_BY(mutex_) = 0;\n  bool closed_ ABSL_GUARDED_BY(mutex_) = false;\n  size_t total_put_count_ ABSL_GUARDED_BY(mutex_) = 0;\n  size_t total_get_count_ ABSL_GUARDED_BY(mutex_) = 0;\n  size_t total_drop_count_ ABSL_GUARDED_BY(mutex_) = 0;\n\n  mutable absl::Mutex mutex_;\n  absl::CondVar cond_var_;\n\n  // Internal methods for producer management\n  void RemoveProducer();\n\n  // Internal Put methods (called by Producer)\n  void PutInternal(const T& item, std::stop_token stop_token = {});\n  void PutInternal(T&& item, std::stop_token stop_token = {});\n  void PutInternal(absl::Span<const T> items, std::stop_token stop_token = {});\n  void PutInternal(absl::Span<T> items, std::stop_token stop_token = {});\n\n  // Condition predicates for blocking operations\n  bool CanPutOne() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);\n  bool CanGet() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);\n\n  // Additional condition predicates for wait functions\n  bool HasRoomAtLeast(size_t room) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);\n  bool HasRoomAtMost(size_t room) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);\n  bool HasSizeAtLeast(size_t size) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);\n  bool HasSizeAtMost(size_t size) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);\n};\n\n// Implementation\n\ntemplate <typename T>\nQueue<T>::Queue(size_t 
capacity, OverflowBehavior overflow_behavior)\n    : capacity_(capacity),\n      overflow_behavior_(overflow_behavior),\n      buffer_(capacity) {}\n\n// Producer implementation\ntemplate <typename T>\nQueue<T>::Producer::Producer(Queue<T>& queue) : queue_(&queue) {\n  // Producer count is incremented in CreateProducer()\n  VLOG(1) << \"Queue@\" << static_cast<const void*>(queue_) << \" producer@\"\n          << static_cast<const void*>(this) << \" constructed.\";\n}\n\ntemplate <typename T>\nQueue<T>::Producer::~Producer() {\n  if (queue_) {\n    VLOG(1) << \"Queue@\" << static_cast<const void*>(queue_) << \" producer@\"\n            << static_cast<const void*>(this) << \" destructing.\";\n    queue_->RemoveProducer();\n  }\n}\n\ntemplate <typename T>\nQueue<T>::Producer::Producer(Producer&& other) noexcept : queue_(other.queue_) {\n  other.queue_ = nullptr;\n}\n\ntemplate <typename T>\ntypename Queue<T>::Producer& Queue<T>::Producer::operator=(\n    Producer&& other) noexcept {\n  if (this != &other) {\n    if (queue_) {\n      queue_->RemoveProducer();\n    }\n    queue_ = other.queue_;\n    other.queue_ = nullptr;\n  }\n  return *this;\n}\n\ntemplate <typename T>\nvoid Queue<T>::Producer::Put(const T& item, std::stop_token stop_token) {\n  queue_->PutInternal(item, stop_token);\n}\n\ntemplate <typename T>\nvoid Queue<T>::Producer::Put(T&& item, std::stop_token stop_token) {\n  queue_->PutInternal(std::move(item), stop_token);\n}\n\ntemplate <typename T>\nvoid Queue<T>::Producer::Put(absl::Span<const T> items,\n                             std::stop_token stop_token) {\n  queue_->PutInternal(items, stop_token);\n}\n\ntemplate <typename T>\nvoid Queue<T>::Producer::Put(absl::Span<T> items, std::stop_token stop_token) {\n  queue_->PutInternal(items, stop_token);\n}\n\ntemplate <typename T>\nvoid Queue<T>::Producer::Close() {\n  if (queue_) {\n    VLOG(1) << \"Queue@\" << static_cast<const void*>(queue_) << \" producer@\"\n            << static_cast<const 
void*>(this) << \" close invoked.\";\n    queue_->RemoveProducer();\n    queue_ = nullptr;\n  }\n}\n\n// Queue implementation\ntemplate <typename T>\ntypename Queue<T>::Producer Queue<T>::CreateProducer() {\n  absl::MutexLock lock(&mutex_);\n  if (closed_) throw QueueClosedException();\n  ++producer_count_;\n  return Producer(*this);\n}\n\ntemplate <typename T>\nvoid Queue<T>::RemoveProducer() {\n  absl::MutexLock lock(&mutex_);\n  --producer_count_;\n  if (producer_count_ == 0 && !closed_) {\n    closed_ = true;\n    VLOG(1) << \"Queue@\" << static_cast<const void*>(this)\n            << \" closed after last producer removed.\";\n    cond_var_.SignalAll();\n  }\n}\n\ntemplate <typename T>\nvoid Queue<T>::PutInternal(const T& item, std::stop_token stop_token) {\n  absl::MutexLock lock(&mutex_);\n  if (closed_) {\n    VLOG(1) << \"Queue@\" << static_cast<const void*>(this)\n            << \" PutInternal(const&) throwing QueueClosedException;\"\n            << \" producers=\" << producer_count_;\n    throw QueueClosedException();\n  }\n  ++total_put_count_;\n\n  switch (overflow_behavior_) {\n    case OverflowBehavior::BLOCK: {\n      std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); });\n      while (!CanPutOne()) {\n        if (closed_) throw QueueClosedException();\n        if (stop_token.stop_requested()) throw QueueRequestCancelled();\n        cond_var_.Wait(&mutex_);\n      }\n      if (closed_) throw QueueClosedException();\n      break;\n    }\n    case OverflowBehavior::DROP_NEW:\n      if (size_ >= capacity_) {\n        ++total_drop_count_;\n        return;\n      }\n      break;\n    case OverflowBehavior::KEEP_NEWEST:\n      if (size_ >= capacity_) {\n        head_ = (head_ + 1) % capacity_;\n        --size_;\n        ++total_drop_count_;\n      }\n      break;\n  }\n\n  buffer_[tail_] = item;\n  tail_ = (tail_ + 1) % capacity_;\n  ++size_;\n  cond_var_.SignalAll();\n}\n\ntemplate <typename T>\nvoid Queue<T>::PutInternal(T&& item, 
std::stop_token stop_token) {\n  absl::MutexLock lock(&mutex_);\n  if (closed_) {\n    VLOG(1) << \"Queue@\" << static_cast<const void*>(this)\n            << \" PutInternal(T&&) throwing QueueClosedException;\"\n            << \" producers=\" << producer_count_;\n    throw QueueClosedException();\n  }\n  ++total_put_count_;\n\n  switch (overflow_behavior_) {\n    case OverflowBehavior::BLOCK: {\n      std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); });\n      while (!CanPutOne()) {\n        if (closed_) throw QueueClosedException();\n        if (stop_token.stop_requested()) throw QueueRequestCancelled();\n        cond_var_.Wait(&mutex_);\n      }\n      if (closed_) throw QueueClosedException();\n      break;\n    }\n    case OverflowBehavior::DROP_NEW:\n      if (size_ >= capacity_) {\n        ++total_drop_count_;\n        return;\n      }\n      break;\n    case OverflowBehavior::KEEP_NEWEST:\n      if (size_ >= capacity_) {\n        head_ = (head_ + 1) % capacity_;\n        --size_;\n        ++total_drop_count_;\n      }\n      break;\n  }\n\n  buffer_[tail_] = std::move(item);\n  tail_ = (tail_ + 1) % capacity_;\n  ++size_;\n  cond_var_.SignalAll();\n}\n\ntemplate <typename T>\nvoid Queue<T>::PutInternal(absl::Span<const T> items,\n                           std::stop_token stop_token) {\n  if (items.empty()) return;\n\n  size_t remaining = items.size();\n  size_t offset = 0;\n\n  while (remaining > 0) {\n    absl::MutexLock lock(&mutex_);\n    if (closed_) {\n      VLOG(1) << \"Queue@\" << static_cast<const void*>(this)\n              << \" PutInternal(span const) throwing QueueClosedException;\"\n              << \" producers=\" << producer_count_;\n      throw QueueClosedException();\n    }\n\n    size_t batch_size;\n    switch (overflow_behavior_) {\n      case OverflowBehavior::BLOCK: {\n        std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); });\n        while (!CanPutOne()) {\n          if (closed_) throw 
QueueClosedException();\n          if (stop_token.stop_requested()) throw QueueRequestCancelled();\n          cond_var_.Wait(&mutex_);\n        }\n        if (closed_) throw QueueClosedException();\n        batch_size = std::min(remaining, capacity_ - size_);\n        break;\n      }\n      case OverflowBehavior::DROP_NEW:\n        batch_size = std::min(remaining, capacity_ - size_);\n        if (batch_size == 0) {\n          total_put_count_ += remaining;\n          total_drop_count_ += remaining;\n          return;\n        }\n        break;\n      case OverflowBehavior::KEEP_NEWEST:\n        batch_size = std::min(remaining, capacity_);\n        while (size_ + batch_size > capacity_) {\n          head_ = (head_ + 1) % capacity_;\n          --size_;\n          ++total_drop_count_;\n        }\n        break;\n    }\n\n    for (size_t i = 0; i < batch_size; ++i) {\n      buffer_[tail_] = items[offset + i];\n      tail_ = (tail_ + 1) % capacity_;\n      ++size_;\n    }\n    total_put_count_ += batch_size;\n    cond_var_.SignalAll();\n\n    offset += batch_size;\n    remaining -= batch_size;\n  }\n}\n\ntemplate <typename T>\nvoid Queue<T>::PutInternal(absl::Span<T> items, std::stop_token stop_token) {\n  if (items.empty()) return;\n\n  size_t remaining = items.size();\n  size_t offset = 0;\n\n  while (remaining > 0) {\n    absl::MutexLock lock(&mutex_);\n    if (closed_) {\n      VLOG(1) << \"Queue@\" << static_cast<const void*>(this)\n              << \" PutInternal(span) throwing QueueClosedException;\"\n              << \" producers=\" << producer_count_;\n      throw QueueClosedException();\n    }\n\n    size_t batch_size;\n    switch (overflow_behavior_) {\n      case OverflowBehavior::BLOCK: {\n        std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); });\n        while (!CanPutOne()) {\n          if (closed_) throw QueueClosedException();\n          if (stop_token.stop_requested()) throw QueueRequestCancelled();\n          
cond_var_.Wait(&mutex_);\n        }\n        if (closed_) throw QueueClosedException();\n        batch_size = std::min(remaining, capacity_ - size_);\n        break;\n      }\n      case OverflowBehavior::DROP_NEW:\n        batch_size = std::min(remaining, capacity_ - size_);\n        if (batch_size == 0) {\n          total_put_count_ += remaining;\n          total_drop_count_ += remaining;\n          return;\n        }\n        break;\n      case OverflowBehavior::KEEP_NEWEST:\n        batch_size = std::min(remaining, capacity_);\n        while (size_ + batch_size > capacity_) {\n          head_ = (head_ + 1) % capacity_;\n          --size_;\n          ++total_drop_count_;\n        }\n        break;\n    }\n\n    for (size_t i = 0; i < batch_size; ++i) {\n      buffer_[tail_] = std::move(items[offset + i]);\n      tail_ = (tail_ + 1) % capacity_;\n      ++size_;\n    }\n    total_put_count_ += batch_size;\n    cond_var_.SignalAll();\n\n    offset += batch_size;\n    remaining -= batch_size;\n  }\n}\n\ntemplate <typename T>\nT Queue<T>::Get(std::stop_token stop_token) {\n  absl::MutexLock lock(&mutex_);\n  std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); });\n  while (!CanGet()) {\n    if (closed_ && size_ == 0) {\n      VLOG(1) << \"Queue@\" << static_cast<const void*>(this)\n              << \" Get() throwing QueueClosedException; producers=\"\n              << producer_count_;\n      throw QueueClosedException();\n    }\n    if (stop_token.stop_requested()) throw QueueRequestCancelled();\n    cond_var_.Wait(&mutex_);\n  }\n  if (closed_ && size_ == 0) {\n    VLOG(1) << \"Queue@\" << static_cast<const void*>(this)\n            << \" Get() throwing QueueClosedException; producers=\"\n            << producer_count_;\n    throw QueueClosedException();\n  }\n\n  T item = std::move(buffer_[head_]);\n  head_ = (head_ + 1) % capacity_;\n  --size_;\n  ++total_get_count_;\n  cond_var_.SignalAll();\n\n  return item;\n}\n\ntemplate <typename 
T>\nabsl::FixedArray<T> Queue<T>::Get(size_t count, std::stop_token stop_token) {\n  if (count == 0) return absl::FixedArray<T>(0);\n\n  absl::FixedArray<T> result(count);\n  size_t remaining = count;\n  size_t offset = 0;\n\n  while (remaining > 0) {\n    absl::MutexLock lock(&mutex_);\n    std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); });\n    while (!CanGet()) {\n      if (closed_ && size_ == 0) {\n        VLOG(1) << \"Queue@\" << static_cast<const void*>(this) << \" Get(\"\n                << count << \") throwing QueueClosedException; producers=\"\n                << producer_count_;\n        throw QueueClosedException();\n      }\n      if (stop_token.stop_requested()) throw QueueRequestCancelled();\n      cond_var_.Wait(&mutex_);\n    }\n    if (closed_ && size_ == 0) {\n      VLOG(1) << \"Queue@\" << static_cast<const void*>(this) << \" Get(\" << count\n              << \") throwing QueueClosedException; producers=\"\n              << producer_count_;\n      throw QueueClosedException();\n    }\n\n    size_t batch_size = std::min(remaining, size_);\n\n    for (size_t i = 0; i < batch_size; ++i) {\n      result[offset + i] = std::move(buffer_[head_]);\n      head_ = (head_ + 1) % capacity_;\n      --size_;\n      ++total_get_count_;\n    }\n    cond_var_.SignalAll();\n\n    offset += batch_size;\n    remaining -= batch_size;\n  }\n\n  return result;\n}\n\ntemplate <typename T>\nstd::optional<T> Queue<T>::MaybeGet() {\n  absl::MutexLock lock(&mutex_);\n  if (size_ == 0) return std::nullopt;\n\n  T item = std::move(buffer_[head_]);\n  head_ = (head_ + 1) % capacity_;\n  --size_;\n  ++total_get_count_;\n  cond_var_.SignalAll();\n\n  return item;\n}\n\ntemplate <typename T>\nsize_t Queue<T>::Size() const {\n  absl::MutexLock lock(&mutex_);\n  return size_;\n}\n\ntemplate <typename T>\nsize_t Queue<T>::Capacity() const {\n  return capacity_;\n}\n\ntemplate <typename T>\nvoid Queue<T>::Close() {\n  absl::MutexLock lock(&mutex_);\n  if 
(!closed_) {\n    closed_ = true;\n    VLOG(1) << \"Queue@\" << static_cast<const void*>(this)\n            << \" closed explicitly; producers=\" << producer_count_;\n    cond_var_.SignalAll();\n  }\n}\n\ntemplate <typename T>\nbool Queue<T>::IsClosed() const {\n  absl::MutexLock lock(&mutex_);\n  return closed_;\n}\n\ntemplate <typename T>\nvoid Queue<T>::WaitForRoomAtLeast(size_t room, std::stop_token stop_token) {\n  absl::MutexLock lock(&mutex_);\n  std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); });\n  while (!HasRoomAtLeast(room)) {\n    if (closed_) throw QueueClosedException();\n    if (stop_token.stop_requested()) throw QueueRequestCancelled();\n    cond_var_.Wait(&mutex_);\n  }\n}\n\ntemplate <typename T>\nvoid Queue<T>::WaitForRoomAtMost(size_t room, std::stop_token stop_token) {\n  absl::MutexLock lock(&mutex_);\n  std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); });\n  while (!HasRoomAtMost(room)) {\n    if (closed_) throw QueueClosedException();\n    if (stop_token.stop_requested()) throw QueueRequestCancelled();\n    cond_var_.Wait(&mutex_);\n  }\n}\n\ntemplate <typename T>\nvoid Queue<T>::WaitForSizeAtLeast(size_t size, std::stop_token stop_token) {\n  absl::MutexLock lock(&mutex_);\n  std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); });\n  while (!HasSizeAtLeast(size)) {\n    if (closed_) throw QueueClosedException();\n    if (stop_token.stop_requested()) throw QueueRequestCancelled();\n    cond_var_.Wait(&mutex_);\n  }\n}\n\ntemplate <typename T>\nvoid Queue<T>::WaitForSizeAtMost(size_t size, std::stop_token stop_token) {\n  absl::MutexLock lock(&mutex_);\n  std::stop_callback cb(stop_token, [this]() { cond_var_.SignalAll(); });\n  while (!HasSizeAtMost(size)) {\n    if (closed_) throw QueueClosedException();\n    if (stop_token.stop_requested()) throw QueueRequestCancelled();\n    cond_var_.Wait(&mutex_);\n  }\n}\n\ntemplate <typename T>\nbool Queue<T>::CanPutOne() 
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {\n  return closed_ || size_ < capacity_;\n}\n\ntemplate <typename T>\nbool Queue<T>::CanGet() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {\n  return closed_ || size_ > 0;\n}\n\ntemplate <typename T>\nbool Queue<T>::HasRoomAtLeast(size_t room)\n    ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {\n  return capacity_ - size_ >= room;\n}\n\ntemplate <typename T>\nbool Queue<T>::HasRoomAtMost(size_t room)\n    ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {\n  return capacity_ - size_ <= room;\n}\n\ntemplate <typename T>\nbool Queue<T>::HasSizeAtLeast(size_t size)\n    ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {\n  return size_ >= size;\n}\n\ntemplate <typename T>\nbool Queue<T>::HasSizeAtMost(size_t size)\n    ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {\n  return size_ <= size;\n}\n\ntemplate <typename T>\nsize_t Queue<T>::GetTotalPutCount(bool reset) {\n  absl::MutexLock lock(&mutex_);\n  size_t count = total_put_count_;\n  if (reset) total_put_count_ = 0;\n  return count;\n}\n\ntemplate <typename T>\nsize_t Queue<T>::GetTotalGetCount(bool reset) {\n  absl::MutexLock lock(&mutex_);\n  size_t count = total_get_count_;\n  if (reset) total_get_count_ = 0;\n  return count;\n}\n\ntemplate <typename T>\nsize_t Queue<T>::GetTotalDropCount(bool reset) {\n  absl::MutexLock lock(&mutex_);\n  size_t count = total_drop_count_;\n  if (reset) total_drop_count_ = 0;\n  return count;\n}\n\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/utils/queue_test.cc",
    "content": "// Comprehensive unit tests for the Queue template class, covering\n// thread-safe operations, blocking behavior, and edge cases.\n\n#include \"utils/queue.h\"\n\n#include <gtest/gtest.h>\n\n#include <atomic>\n#include <chrono>\n#include <future>\n#include <thread>\n#include <vector>\n\nnamespace lczero {\n\nclass QueueTest : public ::testing::Test {\n protected:\n  void SetUp() override {}\n};\n\n// Basic functionality tests\n\nTEST_F(QueueTest, ConstructorCreatesEmptyQueue) {\n  Queue<int> queue(5);\n  EXPECT_EQ(queue.Size(), 0);\n  EXPECT_EQ(queue.Capacity(), 5);\n}\n\nTEST_F(QueueTest, SinglePutGet) {\n  Queue<int> queue(5);\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(42);\n    EXPECT_EQ(queue.Size(), 1);\n  }  // Producer destroyed here, queue closes\n\n  int value = queue.Get();\n  EXPECT_EQ(value, 42);\n  EXPECT_EQ(queue.Size(), 0);\n}\n\nTEST_F(QueueTest, MovePutGet) {\n  Queue<std::unique_ptr<int>> queue(5);\n  {\n    auto producer = queue.CreateProducer();\n    auto ptr = std::make_unique<int>(42);\n    producer.Put(std::move(ptr));\n    EXPECT_EQ(queue.Size(), 1);\n  }  // Producer destroyed here, queue closes\n\n  auto result = queue.Get();\n  EXPECT_EQ(*result, 42);\n  EXPECT_EQ(queue.Size(), 0);\n}\n\nTEST_F(QueueTest, MultiplePutGet) {\n  Queue<int> queue(5);\n\n  {\n    auto producer = queue.CreateProducer();\n    for (int i = 0; i < 5; ++i) {\n      producer.Put(i);\n    }\n    EXPECT_EQ(queue.Size(), 5);\n  }  // Producer destroyed here, queue closes\n\n  for (int i = 0; i < 5; ++i) {\n    int value = queue.Get();\n    EXPECT_EQ(value, i);\n  }\n  EXPECT_EQ(queue.Size(), 0);\n}\n\nTEST_F(QueueTest, CircularBufferBehavior) {\n  Queue<int> queue(3);\n  auto producer = queue.CreateProducer();\n\n  // Fill queue\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n\n  // Get one item, put another\n  EXPECT_EQ(queue.Get(), 1);\n  producer.Put(4);\n\n  // Verify remaining items\n  
EXPECT_EQ(queue.Get(), 2);\n  EXPECT_EQ(queue.Get(), 3);\n  EXPECT_EQ(queue.Get(), 4);\n}\n\n// Batch operations tests\n\nTEST_F(QueueTest, BatchPutConstSpan) {\n  Queue<int> queue(5);\n  std::vector<int> items = {1, 2, 3};\n\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(absl::Span<const int>(items));\n    EXPECT_EQ(queue.Size(), 3);\n  }  // Producer destroyed here, queue closes\n\n  for (int i = 0; i < 3; ++i) {\n    EXPECT_EQ(queue.Get(), i + 1);\n  }\n}\n\nTEST_F(QueueTest, BatchPutMoveSpan) {\n  Queue<std::unique_ptr<int>> queue(5);\n  std::vector<std::unique_ptr<int>> items;\n  items.push_back(std::make_unique<int>(1));\n  items.push_back(std::make_unique<int>(2));\n  items.push_back(std::make_unique<int>(3));\n\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(absl::Span<std::unique_ptr<int>>(items));\n    EXPECT_EQ(queue.Size(), 3);\n  }  // Producer destroyed here, queue closes\n\n  for (int i = 0; i < 3; ++i) {\n    auto result = queue.Get();\n    EXPECT_EQ(*result, i + 1);\n  }\n}\n\nTEST_F(QueueTest, BatchPutEmptySpan) {\n  Queue<int> queue(5);\n  std::vector<int> empty_items;\n\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(absl::Span<const int>(empty_items));\n    EXPECT_EQ(queue.Size(), 0);\n  }  // Producer destroyed here, queue closes\n}\n\nTEST_F(QueueTest, BatchGet) {\n  Queue<int> queue(5);\n\n  {\n    auto producer = queue.CreateProducer();\n    for (int i = 0; i < 5; ++i) {\n      producer.Put(i);\n    }\n  }  // Producer destroyed here, queue closes\n\n  auto result = queue.Get(3);\n  EXPECT_EQ(result.size(), 3);\n  for (int i = 0; i < 3; ++i) {\n    EXPECT_EQ(result[i], i);\n  }\n  EXPECT_EQ(queue.Size(), 2);\n}\n\nTEST_F(QueueTest, BatchGetZeroCount) {\n  Queue<int> queue(5);\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(42);\n  }  // Producer destroyed here, queue closes\n\n  auto result = queue.Get(0);\n  EXPECT_EQ(result.size(), 0);\n  
EXPECT_EQ(queue.Size(), 1);\n}\n\n// Edge cases and error conditions\n\nTEST_F(QueueTest, CapacityOne) {\n  Queue<int> queue(1);\n\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(42);\n    EXPECT_EQ(queue.Size(), 1);\n  }  // Producer destroyed here, queue closes\n\n  EXPECT_EQ(queue.Get(), 42);\n  EXPECT_EQ(queue.Size(), 0);\n}\n\n// Tests for operations when all producer tokens are destroyed\n\nTEST_F(QueueTest, CreateProducerOnClosedQueue) {\n  Queue<int> queue(5);\n  // Create and immediately destroy producer to close queue\n  {\n    auto producer = queue.CreateProducer();\n  }\n\n  // Trying to create a new producer after queue is closed results in an\n  // exception.\n  EXPECT_THROW(queue.CreateProducer(), QueueClosedException);\n}\n\nTEST_F(QueueTest, GetOnClosedQueue) {\n  Queue<int> queue(5);\n  // Create and immediately destroy producer to close queue\n  {\n    auto producer = queue.CreateProducer();\n  }\n\n  EXPECT_THROW(queue.Get(), QueueClosedException);\n}\n\nTEST_F(QueueTest, BatchGetOnClosedQueue) {\n  Queue<int> queue(5);\n  // Create and immediately destroy producer to close queue\n  {\n    auto producer = queue.CreateProducer();\n  }\n\n  EXPECT_THROW(queue.Get(3), QueueClosedException);\n}\n\n// Thread safety tests\n\nTEST_F(QueueTest, SingleProducerSingleConsumer) {\n  Queue<int> queue(10);\n  std::atomic<bool> producer_done{false};\n  std::vector<int> consumed;\n\n  std::thread producer([&queue, &producer_done]() {\n    auto prod = queue.CreateProducer();\n    for (int i = 0; i < 100; ++i) {\n      prod.Put(i);\n    }\n    producer_done = true;\n    // Producer destroyed here, closing the queue\n  });\n\n  std::thread consumer([&queue, &consumed, &producer_done]() {\n    int value;\n    while (!producer_done || queue.Size() > 0) {\n      try {\n        value = queue.Get();\n        consumed.push_back(value);\n      } catch (const QueueClosedException&) {\n        break;\n      }\n    }\n  });\n\n  producer.join();\n  
consumer.join();\n\n  EXPECT_EQ(consumed.size(), 100);\n  for (int i = 0; i < 100; ++i) {\n    EXPECT_EQ(consumed[i], i);\n  }\n}\n\nTEST_F(QueueTest, MultipleProducersMultipleConsumers) {\n  Queue<int> queue(10);\n  constexpr int num_producers = 2;\n  constexpr int items_per_producer = 5;\n  constexpr int total_items = num_producers * items_per_producer;\n\n  std::vector<int> all_consumed;\n  std::vector<std::thread> producers;\n\n  // Use a single producer token that we control explicitly\n  auto producer_token = queue.CreateProducer();\n\n  // Start producers - they all share the same producer token via reference\n  for (int p = 0; p < num_producers; ++p) {\n    producers.emplace_back([&producer_token, p]() {\n      for (int i = 0; i < items_per_producer; ++i) {\n        int value = p * items_per_producer + i;\n        producer_token.Put(value);\n      }\n    });\n  }\n\n  // Wait for all producers to finish\n  for (auto& producer : producers) {\n    producer.join();\n  }\n\n  // Now explicitly close the queue by destroying the producer token\n  {\n    auto temp = std::move(producer_token);\n  }  // Queue is now closed\n\n  // Now consume all items from the closed queue\n  for (int i = 0; i < total_items; ++i) {\n    all_consumed.push_back(queue.Get());\n  }\n\n  // Verify all items were consumed\n  EXPECT_EQ(all_consumed.size(), total_items);\n  EXPECT_EQ(queue.Size(), 0);\n\n  // Trying to get one more should throw\n  EXPECT_THROW(queue.Get(), QueueClosedException);\n}\n\nTEST_F(QueueTest, BlockingBehaviorOnFullQueue) {\n  Queue<int> queue(2);\n  std::promise<void> about_to_block;\n  std::future<void> about_to_block_future = about_to_block.get_future();\n  std::atomic<bool> put_completed{false};\n  auto producer = queue.CreateProducer();\n\n  // Fill the queue\n  producer.Put(1);\n  producer.Put(2);\n\n  std::thread blocker([&producer, &about_to_block, &put_completed]() {\n    about_to_block.set_value();  // Signal we're about to block\n    producer.Put(3);    
         // This should block\n    put_completed = true;\n  });\n\n  // Wait for thread to signal it's about to block\n  about_to_block_future.wait();\n  EXPECT_FALSE(put_completed);\n\n  // Make space in the queue\n  EXPECT_EQ(queue.Get(), 1);\n\n  blocker.join();\n  EXPECT_TRUE(put_completed);\n  EXPECT_EQ(queue.Size(), 2);\n}\n\nTEST_F(QueueTest, BlockingBehaviorOnEmptyQueue) {\n  Queue<int> queue(5);\n  std::promise<void> about_to_block;\n  std::future<void> about_to_block_future = about_to_block.get_future();\n  std::atomic<bool> get_completed{false};\n  std::atomic<int> result{-1};\n  auto producer = queue.CreateProducer();\n\n  std::thread blocker([&queue, &about_to_block, &get_completed, &result]() {\n    about_to_block.set_value();  // Signal we're about to block\n    result = queue.Get();        // This should block\n    get_completed = true;\n  });\n\n  // Wait for thread to signal it's about to block\n  about_to_block_future.wait();\n  EXPECT_FALSE(get_completed);\n\n  // Put an item in the queue\n  producer.Put(42);\n\n  blocker.join();\n  EXPECT_TRUE(get_completed);\n  EXPECT_EQ(result, 42);\n}\n\nTEST_F(QueueTest, ProducerDestructionUnblocksWaitingGet) {\n  Queue<int> queue(5);  // Empty queue\n\n  std::promise<void> about_to_block;\n  std::future<void> about_to_block_future = about_to_block.get_future();\n  std::atomic<bool> exception_thrown{false};\n\n  // Create a producer to keep queue open initially\n  std::unique_ptr<Queue<int>::Producer> producer =\n      std::make_unique<Queue<int>::Producer>(queue.CreateProducer());\n\n  std::thread blocker([&queue, &about_to_block, &exception_thrown]() {\n    about_to_block.set_value();  // Signal we're about to block\n    try {\n      queue.Get();  // This should block\n    } catch (const QueueClosedException&) {\n      exception_thrown = true;\n    }\n  });\n\n  // Wait for thread to signal it's about to block\n  about_to_block_future.wait();\n  EXPECT_FALSE(exception_thrown);\n\n  // Destroy the producer 
- this should close queue and unblock the waiting\n  // Get()\n  producer.reset();\n\n  blocker.join();\n  EXPECT_TRUE(exception_thrown);\n}\n\n// Test: Get() should not throw when queue is closed but has elements\nTEST_F(QueueTest, GetFromClosedQueueWithElements) {\n  Queue<int> queue(5);\n\n  // Put some elements in the queue, then destroy producer to close it\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(1);\n    producer.Put(2);\n    producer.Put(3);\n    EXPECT_EQ(queue.Size(), 3);\n  }  // Producer destroyed here, queue closes\n\n  // Should be able to get elements that were already in the queue\n  EXPECT_EQ(queue.Get(), 1);\n  EXPECT_EQ(queue.Get(), 2);\n  EXPECT_EQ(queue.Get(), 3);\n  EXPECT_EQ(queue.Size(), 0);\n\n  // Only now should Get() throw when queue is empty and closed\n  EXPECT_THROW(queue.Get(), QueueClosedException);\n}\n\nTEST_F(QueueTest, BatchGetFromClosedQueueWithElements) {\n  Queue<int> queue(5);\n\n  // Put some elements in the queue, then destroy producer to close it\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(1);\n    producer.Put(2);\n    producer.Put(3);\n    EXPECT_EQ(queue.Size(), 3);\n  }  // Producer destroyed here, queue closes\n\n  // Should be able to get elements that were already in the queue\n  auto result = queue.Get(2);\n  EXPECT_EQ(result.size(), 2);\n  EXPECT_EQ(result[0], 1);\n  EXPECT_EQ(result[1], 2);\n  EXPECT_EQ(queue.Size(), 1);\n\n  // Get remaining element\n  EXPECT_EQ(queue.Get(), 3);\n  EXPECT_EQ(queue.Size(), 0);\n\n  // Only now should Get() throw when queue is empty and closed\n  EXPECT_THROW(queue.Get(1), QueueClosedException);\n}\n\n// Test producer token mechanism specifically\nTEST_F(QueueTest, ProducerTokenMechanism) {\n  Queue<int> queue(5);\n\n  // Create multiple producers\n  auto producer1 = queue.CreateProducer();\n  auto producer2 = queue.CreateProducer();\n\n  // Both should be able to put items\n  producer1.Put(1);\n  producer2.Put(2);\n  
EXPECT_EQ(queue.Size(), 2);\n\n  // Destroy one producer - queue should still be open\n  {\n    auto temp = std::move(producer1);\n  }  // producer1 is destroyed here\n  producer2.Put(3);\n  EXPECT_EQ(queue.Size(), 3);\n\n  // Destroy last producer - queue should close\n  {\n    auto temp = std::move(producer2);\n  }  // producer2 is destroyed here\n\n  // Should still be able to get existing items\n  EXPECT_EQ(queue.Get(), 1);\n  EXPECT_EQ(queue.Get(), 2);\n  EXPECT_EQ(queue.Get(), 3);\n\n  // But trying to get more should throw\n  EXPECT_THROW(queue.Get(), QueueClosedException);\n}\n\nTEST_F(QueueTest, ProducerMoveSemantics) {\n  Queue<int> queue(5);\n\n  auto producer1 = queue.CreateProducer();\n  producer1.Put(42);\n\n  // Move constructor\n  auto producer2 = std::move(producer1);\n  producer2.Put(43);\n  EXPECT_EQ(queue.Size(), 2);\n\n  // Create another producer and use move assignment\n  auto producer3 = queue.CreateProducer();\n  producer3 = std::move(producer2);\n  producer3.Put(44);\n  EXPECT_EQ(queue.Size(), 3);\n\n  // Destroy the last producer\n  {\n    auto temp = std::move(producer3);\n  }  // producer3 is destroyed here\n\n  // Should be able to get all items\n  EXPECT_EQ(queue.Get(), 42);\n  EXPECT_EQ(queue.Get(), 43);\n  EXPECT_EQ(queue.Get(), 44);\n  EXPECT_THROW(queue.Get(), QueueClosedException);\n}\n\n// Tests for Put operations on closed queue\nTEST_F(QueueTest, PutOnClosedQueueThrowsException) {\n  Queue<int> queue(5);\n\n  // Create a producer, then close the queue\n  auto producer = queue.CreateProducer();\n  queue.Close();\n\n  // All Put operations should throw on closed queue\n  EXPECT_THROW(producer.Put(42), QueueClosedException);\n  int rvalue = 43;\n  EXPECT_THROW(producer.Put(std::move(rvalue)), QueueClosedException);\n\n  std::vector<int> items = {1, 2, 3};\n  EXPECT_THROW(producer.Put(absl::Span<const int>(items)),\n               QueueClosedException);\n  EXPECT_THROW(producer.Put(absl::Span<int>(items)), QueueClosedException);\n}\n\nTEST_F(QueueTest, 
PutOnClosedQueueAfterProducerDestruction) {\n  Queue<int> queue(5);\n\n  // Create producer, add item, then close by destroying all producers\n  auto producer = queue.CreateProducer();\n  producer.Put(1);\n  {\n    auto temp_producer = std::move(producer);\n  }  // All producers destroyed, queue closed\n\n  // Try to create new producer after close\n  EXPECT_THROW(queue.CreateProducer(), QueueClosedException);\n}\n\nTEST_F(QueueTest, BatchPutOnClosedQueueThrowsException) {\n  Queue<int> queue(10);\n\n  auto producer = queue.CreateProducer();\n  queue.Close();\n\n  // Batch put operations should throw on closed queue\n  std::vector<int> items = {1, 2, 3, 4, 5};\n  EXPECT_THROW(producer.Put(absl::Span<const int>(items)),\n               QueueClosedException);\n\n  std::vector<int> mutable_items = {6, 7, 8};\n  EXPECT_THROW(producer.Put(absl::Span<int>(mutable_items)),\n               QueueClosedException);\n}\n\nTEST_F(QueueTest, PublicCloseMethod) {\n  Queue<int> queue(5);\n\n  auto producer = queue.CreateProducer();\n  producer.Put(1);\n  producer.Put(2);\n\n  // Explicitly close the queue using public Close() method\n  queue.Close();\n\n  // Put operations should now throw\n  EXPECT_THROW(producer.Put(3), QueueClosedException);\n\n  // But Get operations should still work for existing items\n  EXPECT_EQ(queue.Get(), 1);\n  EXPECT_EQ(queue.Get(), 2);\n\n  // Get should throw when queue is empty and closed\n  EXPECT_THROW(queue.Get(), QueueClosedException);\n}\n\nTEST_F(QueueTest, CloseUnblocksWaitingSinglePut) {\n  Queue<int> queue(2);  // Small capacity\n  auto producer = queue.CreateProducer();\n\n  // Fill the queue\n  producer.Put(1);\n  producer.Put(2);\n\n  std::promise<void> about_to_block;\n  std::future<void> about_to_block_future = about_to_block.get_future();\n  std::atomic<bool> exception_thrown{false};\n\n  std::thread blocker([&producer, &about_to_block, &exception_thrown]() {\n    about_to_block.set_value();  // Signal we're about to block\n    try 
{\n      producer.Put(3);  // This should block since queue is full\n    } catch (const QueueClosedException&) {\n      exception_thrown = true;\n    }\n  });\n\n  // Wait for thread to signal it's about to block\n  about_to_block_future.wait();\n  EXPECT_FALSE(exception_thrown);\n\n  // Close the queue - this should unblock the waiting Put()\n  queue.Close();\n\n  blocker.join();\n  EXPECT_TRUE(exception_thrown);\n}\n\nTEST_F(QueueTest, CloseUnblocksWaitingBatchPut) {\n  Queue<int> queue(3);  // Small capacity\n  auto producer = queue.CreateProducer();\n\n  // Fill the queue partially\n  producer.Put(1);\n  producer.Put(2);\n\n  std::promise<void> about_to_block;\n  std::future<void> about_to_block_future = about_to_block.get_future();\n  std::atomic<bool> exception_thrown{false};\n\n  std::thread blocker([&producer, &about_to_block, &exception_thrown]() {\n    about_to_block.set_value();  // Signal we're about to block\n    try {\n      std::vector<int> items = {3, 4, 5};  // Need 3 slots but only 1 available\n      producer.Put(absl::Span<const int>(items));  // This should block\n    } catch (const QueueClosedException&) {\n      exception_thrown = true;\n    }\n  });\n\n  // Wait for thread to signal it's about to block\n  about_to_block_future.wait();\n  EXPECT_FALSE(exception_thrown);\n\n  // Close the queue - this should unblock the waiting batch Put()\n  queue.Close();\n\n  blocker.join();\n  EXPECT_TRUE(exception_thrown);\n}\n\n// Tests for new wait functions\nTEST_F(QueueTest, WaitForRoomAtLeast) {\n  Queue<int> queue(5);\n  auto producer = queue.CreateProducer();\n\n  // Initially empty queue should have room >= 5\n  queue.WaitForRoomAtLeast(5);\n  EXPECT_EQ(queue.Size(), 0);\n\n  // Fill queue partially\n  producer.Put(1);\n  producer.Put(2);\n\n  // Should have room >= 3\n  queue.WaitForRoomAtLeast(3);\n  EXPECT_EQ(queue.Size(), 2);\n\n  // Fill more\n  producer.Put(3);\n  producer.Put(4);\n\n  // Should have room >= 1\n  
queue.WaitForRoomAtLeast(1);\n  EXPECT_EQ(queue.Size(), 4);\n\n  // Test blocking behavior\n  producer.Put(5);  // Queue is now full\n\n  std::promise<void> wait_started;\n  std::future<void> wait_started_future = wait_started.get_future();\n  std::atomic<bool> wait_completed{false};\n\n  std::thread waiter([&queue, &wait_started, &wait_completed]() {\n    wait_started.set_value();\n    queue.WaitForRoomAtLeast(2);  // Should block until 2 slots are free\n    wait_completed = true;\n  });\n\n  wait_started_future.wait();\n  std::this_thread::sleep_for(std::chrono::milliseconds(10));\n  EXPECT_FALSE(wait_completed);\n\n  // Free up space\n  queue.Get();\n  queue.Get();\n\n  waiter.join();\n  EXPECT_TRUE(wait_completed);\n}\n\nTEST_F(QueueTest, WaitForRoomAtMost) {\n  Queue<int> queue(5);\n  auto producer = queue.CreateProducer();\n\n  // Fill queue partially\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n\n  // Should wait until room <= 2 (currently room = 2)\n  queue.WaitForRoomAtMost(2);\n  EXPECT_EQ(queue.Size(), 3);\n\n  // Test blocking behavior\n  std::promise<void> wait_started;\n  std::future<void> wait_started_future = wait_started.get_future();\n  std::atomic<bool> wait_completed{false};\n\n  std::thread waiter([&queue, &wait_started, &wait_completed]() {\n    wait_started.set_value();\n    queue.WaitForRoomAtMost(1);  // Should block until room <= 1\n    wait_completed = true;\n  });\n\n  wait_started_future.wait();\n  std::this_thread::sleep_for(std::chrono::milliseconds(10));\n  EXPECT_FALSE(wait_completed);\n\n  // Add one more item to make room = 1\n  producer.Put(4);\n\n  waiter.join();\n  EXPECT_TRUE(wait_completed);\n  EXPECT_EQ(queue.Size(), 4);\n}\n\nTEST_F(QueueTest, WaitForSizeAtLeast) {\n  Queue<int> queue(5);\n  auto producer = queue.CreateProducer();\n\n  // Test blocking behavior on empty queue\n  std::promise<void> wait_started;\n  std::future<void> wait_started_future = wait_started.get_future();\n  std::atomic<bool> 
wait_completed{false};\n\n  std::thread waiter([&queue, &wait_started, &wait_completed]() {\n    wait_started.set_value();\n    queue.WaitForSizeAtLeast(3);  // Should block until size >= 3\n    wait_completed = true;\n  });\n\n  wait_started_future.wait();\n  std::this_thread::sleep_for(std::chrono::milliseconds(10));\n  EXPECT_FALSE(wait_completed);\n\n  // Add items\n  producer.Put(1);\n  producer.Put(2);\n  std::this_thread::sleep_for(std::chrono::milliseconds(10));\n  EXPECT_FALSE(wait_completed);\n\n  producer.Put(3);  // Now size = 3\n\n  waiter.join();\n  EXPECT_TRUE(wait_completed);\n  EXPECT_EQ(queue.Size(), 3);\n}\n\nTEST_F(QueueTest, WaitForSizeAtMost) {\n  Queue<int> queue(5);\n  auto producer = queue.CreateProducer();\n\n  // Initially empty, size <= 3\n  queue.WaitForSizeAtMost(3);\n  EXPECT_EQ(queue.Size(), 0);\n\n  // Fill queue\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n  producer.Put(4);\n  producer.Put(5);\n\n  // Test blocking behavior\n  std::promise<void> wait_started;\n  std::future<void> wait_started_future = wait_started.get_future();\n  std::atomic<bool> wait_completed{false};\n\n  std::thread waiter([&queue, &wait_started, &wait_completed]() {\n    wait_started.set_value();\n    queue.WaitForSizeAtMost(2);  // Should block until size <= 2\n    wait_completed = true;\n  });\n\n  wait_started_future.wait();\n  std::this_thread::sleep_for(std::chrono::milliseconds(10));\n  EXPECT_FALSE(wait_completed);\n\n  // Remove items\n  queue.Get();\n  queue.Get();\n  std::this_thread::sleep_for(std::chrono::milliseconds(10));\n  EXPECT_FALSE(wait_completed);\n\n  queue.Get();  // Now size = 2\n\n  waiter.join();\n  EXPECT_TRUE(wait_completed);\n  EXPECT_EQ(queue.Size(), 2);\n}\n\nTEST_F(QueueTest, WaitFunctionsEdgeCases) {\n  Queue<int> queue(3);\n  auto producer = queue.CreateProducer();\n\n  // Wait for room = 0 should work when queue is full\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n  
queue.WaitForRoomAtMost(0);\n  EXPECT_EQ(queue.Size(), 3);\n\n  // Wait for size = 0 should work when queue is empty\n  queue.Get();\n  queue.Get();\n  queue.Get();\n  queue.WaitForSizeAtMost(0);\n  EXPECT_EQ(queue.Size(), 0);\n\n  // Wait for room >= capacity should always succeed\n  queue.WaitForRoomAtLeast(3);\n  EXPECT_EQ(queue.Size(), 0);\n\n  // Wait for size >= 0 should always succeed\n  queue.WaitForSizeAtLeast(0);\n  EXPECT_EQ(queue.Size(), 0);\n}\n\n// Tests for gradual large range operations\n\nTEST_F(QueueTest, BatchPutAtCapacityWorks) {\n  Queue<int> queue(3);\n  auto producer = queue.CreateProducer();\n\n  // Putting exactly capacity worth of items should work\n  std::vector<int> items = {1, 2, 3};\n  producer.Put(absl::Span<const int>(items));\n  EXPECT_EQ(queue.Size(), 3);\n}\n\nTEST_F(QueueTest, BatchGetAtCapacityWorks) {\n  Queue<int> queue(3);\n  auto producer = queue.CreateProducer();\n\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n\n  // Getting exactly capacity worth of items should work\n  auto result = queue.Get(3);\n  EXPECT_EQ(result.size(), 3);\n  EXPECT_EQ(result[0], 1);\n  EXPECT_EQ(result[1], 2);\n  EXPECT_EQ(result[2], 3);\n}\n\nTEST_F(QueueTest, LargeRangePutGetGradual) {\n  Queue<int> queue(3);  // Small capacity\n  auto producer = queue.CreateProducer();\n\n  // Put more items than capacity - should work gradually\n  std::vector<int> large_items = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};\n\n  std::thread put_thread([&producer, &large_items]() {\n    producer.Put(absl::Span<const int>(large_items));\n  });\n\n  // Consume items as they become available\n  std::vector<int> consumed;\n  for (int i = 0; i < 10; ++i) {\n    consumed.push_back(queue.Get());\n  }\n\n  put_thread.join();\n\n  // Verify all items were transferred correctly\n  EXPECT_EQ(consumed.size(), 10);\n  for (int i = 0; i < 10; ++i) {\n    EXPECT_EQ(consumed[i], i + 1);\n  }\n  EXPECT_EQ(queue.Size(), 0);\n}\n\nTEST_F(QueueTest, LargeRangePutMove) {\n  
Queue<std::unique_ptr<int>> queue(2);  // Very small capacity\n  auto producer = queue.CreateProducer();\n\n  // Create large batch of move-only items\n  std::vector<std::unique_ptr<int>> large_items;\n  for (int i = 1; i <= 5; ++i) {\n    large_items.push_back(std::make_unique<int>(i));\n  }\n\n  std::thread put_thread([&producer, &large_items]() {\n    producer.Put(absl::Span<std::unique_ptr<int>>(large_items));\n  });\n\n  // Consume items as they become available\n  std::vector<std::unique_ptr<int>> consumed;\n  for (int i = 0; i < 5; ++i) {\n    consumed.push_back(queue.Get());\n  }\n\n  put_thread.join();\n\n  // Verify all items were transferred correctly\n  EXPECT_EQ(consumed.size(), 5);\n  for (int i = 0; i < 5; ++i) {\n    EXPECT_EQ(*consumed[i], i + 1);\n  }\n  EXPECT_EQ(queue.Size(), 0);\n}\n\nTEST_F(QueueTest, LargeRangeGetGradual) {\n  Queue<int> queue(3);  // Small capacity\n  auto producer = queue.CreateProducer();\n\n  // Start a thread that will gradually produce items\n  std::thread producer_thread([&producer]() {\n    for (int i = 1; i <= 10; ++i) {\n      producer.Put(i);\n      std::this_thread::sleep_for(std::chrono::milliseconds(1));\n    }\n  });\n\n  // Get more items than capacity - should work gradually\n  auto result = queue.Get(10);\n\n  producer_thread.join();\n\n  // Verify all items were retrieved correctly\n  EXPECT_EQ(result.size(), 10);\n  for (int i = 0; i < 10; ++i) {\n    EXPECT_EQ(result[i], i + 1);\n  }\n  EXPECT_EQ(queue.Size(), 0);\n}\n\nTEST_F(QueueTest, LargeRangePutGetConcurrent) {\n  Queue<int> queue(5);  // Medium capacity\n\n  constexpr int total_items = 100;\n  constexpr int batch_size = 25;\n\n  auto producer1 = queue.CreateProducer();\n  auto producer2 = queue.CreateProducer();\n\n  std::vector<int> batch1, batch2;\n  for (int i = 0; i < batch_size; ++i) {\n    batch1.push_back(i);\n    batch2.push_back(i + batch_size);\n  }\n\n  std::atomic<int> items_consumed{0};\n  std::vector<int> all_consumed;\n  
all_consumed.reserve(total_items);\n\n  // Multiple producers\n  std::thread producer_thread1([&producer1, &batch1]() {\n    producer1.Put(absl::Span<const int>(batch1));\n    producer1.Put(absl::Span<const int>(batch1));  // Put twice\n  });\n\n  std::thread producer_thread2([&producer2, &batch2]() {\n    producer2.Put(absl::Span<const int>(batch2));\n    producer2.Put(absl::Span<const int>(batch2));  // Put twice\n  });\n\n  // Consumer getting in large batches\n  std::thread consumer_thread([&queue, &all_consumed, &items_consumed]() {\n    while (items_consumed < total_items) {\n      try {\n        auto batch = queue.Get(std::min(15, total_items - items_consumed));\n        for (const auto& item : batch) {\n          all_consumed.push_back(item);\n        }\n        items_consumed += batch.size();\n      } catch (const QueueClosedException&) {\n        break;\n      }\n    }\n  });\n\n  producer_thread1.join();\n  producer_thread2.join();\n\n  // Close both producer handles so the consumer observes the queue closing\n  producer1.Close();\n  producer2.Close();\n\n  consumer_thread.join();\n\n  EXPECT_EQ(all_consumed.size(), total_items);\n  EXPECT_EQ(queue.Size(), 0);\n}\n\nTEST_F(QueueTest, GradualOperationsWithQueueClosure) {\n  Queue<int> queue(2);  // Very small capacity\n  auto producer = queue.CreateProducer();\n\n  std::vector<int> large_batch = {1, 2, 3, 4, 5};\n  std::atomic<bool> exception_caught{false};\n\n  std::thread producer_thread([&producer, &large_batch, &exception_caught]() {\n    try {\n      producer.Put(absl::Span<const int>(large_batch));\n    } catch (const QueueClosedException&) {\n      exception_caught = true;\n    }\n  });\n\n  // Let producer start putting items\n  std::this_thread::sleep_for(std::chrono::milliseconds(10));\n\n  // Consume a couple items\n  queue.Get();  // Should get 1\n  queue.Get();  // Should get 2\n\n  // Close the queue while producer is still trying to put items\n  queue.Close();\n\n  producer_thread.join();\n\n  
EXPECT_TRUE(exception_caught);\n  // Queue might have some items that were put before closure\n  // but we can't predict exactly how many due to timing\n}\n\n// Tests for total put count functionality\n\nTEST_F(QueueTest, GetTotalPutCountBasic) {\n  Queue<int> queue(5);\n  EXPECT_EQ(queue.GetTotalPutCount(), 0);\n\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(1);\n    EXPECT_EQ(queue.GetTotalPutCount(), 1);\n\n    producer.Put(2);\n    producer.Put(3);\n    EXPECT_EQ(queue.GetTotalPutCount(), 3);\n  }\n\n  // Count should persist after producer destruction\n  EXPECT_EQ(queue.GetTotalPutCount(), 3);\n\n  // Count should persist after getting items\n  queue.Get();\n  queue.Get();\n  EXPECT_EQ(queue.GetTotalPutCount(), 3);\n}\n\nTEST_F(QueueTest, GetTotalPutCountBatch) {\n  Queue<int> queue(10);\n  auto producer = queue.CreateProducer();\n\n  std::vector<int> batch1 = {1, 2, 3};\n  std::vector<int> batch2 = {4, 5};\n\n  producer.Put(absl::Span<const int>(batch1));\n  EXPECT_EQ(queue.GetTotalPutCount(), 3);\n\n  producer.Put(absl::Span<const int>(batch2));\n  EXPECT_EQ(queue.GetTotalPutCount(), 5);\n\n  // Single put after batch\n  producer.Put(6);\n  EXPECT_EQ(queue.GetTotalPutCount(), 6);\n}\n\nTEST_F(QueueTest, GetTotalPutCountReset) {\n  Queue<int> queue(5);\n  auto producer = queue.CreateProducer();\n\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n  EXPECT_EQ(queue.GetTotalPutCount(), 3);\n\n  // Reset and verify return value\n  EXPECT_EQ(queue.GetTotalPutCount(true), 3);\n  EXPECT_EQ(queue.GetTotalPutCount(), 0);\n\n  // Add more items\n  producer.Put(4);\n  EXPECT_EQ(queue.GetTotalPutCount(), 1);\n\n  // Non-reset call should not affect counter\n  EXPECT_EQ(queue.GetTotalPutCount(false), 1);\n  EXPECT_EQ(queue.GetTotalPutCount(), 1);\n}\n\nTEST_F(QueueTest, GetTotalPutCountThreadSafe) {\n  Queue<int> queue(50);  // Large capacity to avoid blocking\n  constexpr int items_per_thread = 10;\n  constexpr int num_threads = 2;\n\n  
std::vector<std::thread> threads;\n  std::vector<Queue<int>::Producer> producers;\n\n  // Create producers for each thread\n  for (int t = 0; t < num_threads; ++t) {\n    producers.push_back(queue.CreateProducer());\n  }\n\n  for (int t = 0; t < num_threads; ++t) {\n    threads.emplace_back([&producers, t]() {\n      for (int i = 0; i < items_per_thread; ++i) {\n        producers[t].Put(t * items_per_thread + i);\n      }\n    });\n  }\n\n  for (auto& thread : threads) {\n    thread.join();\n  }\n\n  EXPECT_EQ(queue.GetTotalPutCount(), items_per_thread * num_threads);\n}\n\nTEST_F(QueueTest, GetTotalPutCountBatchThreadSafe) {\n  Queue<int> queue(100);  // Large capacity to avoid blocking\n\n  std::vector<std::thread> threads;\n  std::vector<Queue<int>::Producer> producers;\n\n  // Create producers for each thread\n  for (int t = 0; t < 2; ++t) {\n    producers.push_back(queue.CreateProducer());\n  }\n\n  for (int t = 0; t < 2; ++t) {\n    threads.emplace_back([&producers, t]() {\n      std::vector<int> batch;\n      int batch_size = (t + 1) * 5;  // 5, 10 items\n      for (int i = 0; i < batch_size; ++i) {\n        batch.push_back(t * 100 + i);\n      }\n      producers[t].Put(absl::Span<const int>(batch));\n    });\n  }\n\n  for (auto& thread : threads) {\n    thread.join();\n  }\n\n  EXPECT_EQ(queue.GetTotalPutCount(), 15);  // 5 + 10\n}\n\nTEST_F(QueueTest, GetTotalPutCountWithMoveSemantics) {\n  Queue<std::unique_ptr<int>> queue(5);\n  auto producer = queue.CreateProducer();\n\n  // Single move put\n  auto ptr1 = std::make_unique<int>(42);\n  producer.Put(std::move(ptr1));\n  EXPECT_EQ(queue.GetTotalPutCount(), 1);\n\n  // Batch move put\n  std::vector<std::unique_ptr<int>> batch;\n  for (int i = 0; i < 3; ++i) {\n    batch.push_back(std::make_unique<int>(i));\n  }\n  producer.Put(absl::Span<std::unique_ptr<int>>(batch));\n  EXPECT_EQ(queue.GetTotalPutCount(), 4);\n}\n\nTEST_F(QueueTest, GetTotalPutCountEmptyBatch) {\n  Queue<int> queue(5);\n  auto producer = 
queue.CreateProducer();\n\n  std::vector<int> empty_batch;\n  producer.Put(absl::Span<const int>(empty_batch));\n  EXPECT_EQ(queue.GetTotalPutCount(), 0);\n\n  producer.Put(1);\n  EXPECT_EQ(queue.GetTotalPutCount(), 1);\n\n  producer.Put(absl::Span<const int>(empty_batch));\n  EXPECT_EQ(queue.GetTotalPutCount(), 1);\n}\n\n// Tests for DROP_NEW overflow behavior\n\nTEST_F(QueueTest, DropNewBasicBehavior) {\n  Queue<int> queue(3, Queue<int>::OverflowBehavior::DROP_NEW);\n  auto producer = queue.CreateProducer();\n\n  // Fill queue to capacity\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n  EXPECT_EQ(queue.Size(), 3);\n  EXPECT_EQ(queue.GetTotalPutCount(), 3);\n  EXPECT_EQ(queue.GetTotalDropCount(), 0);\n\n  // Additional puts should be dropped\n  producer.Put(4);\n  producer.Put(5);\n  EXPECT_EQ(queue.Size(), 3);\n  EXPECT_EQ(queue.GetTotalPutCount(), 5);\n  EXPECT_EQ(queue.GetTotalDropCount(), 2);\n\n  // Verify original items are still there\n  EXPECT_EQ(queue.Get(), 1);\n  EXPECT_EQ(queue.Get(), 2);\n  EXPECT_EQ(queue.Get(), 3);\n}\n\nTEST_F(QueueTest, DropNewBatchBehavior) {\n  Queue<int> queue(3, Queue<int>::OverflowBehavior::DROP_NEW);\n  auto producer = queue.CreateProducer();\n\n  // Fill queue partially\n  producer.Put(1);\n  EXPECT_EQ(queue.Size(), 1);\n\n  // Try to put more than capacity allows\n  std::vector<int> large_batch = {2, 3, 4, 5, 6};\n  producer.Put(absl::Span<const int>(large_batch));\n\n  // Only first 2 should fit\n  EXPECT_EQ(queue.Size(), 3);\n  EXPECT_EQ(queue.GetTotalPutCount(), 6);   // 1 + 5 attempted\n  EXPECT_EQ(queue.GetTotalDropCount(), 3);  // 4, 5, 6 were dropped\n\n  // Verify what's in the queue\n  EXPECT_EQ(queue.Get(), 1);\n  EXPECT_EQ(queue.Get(), 2);\n  EXPECT_EQ(queue.Get(), 3);\n}\n\nTEST_F(QueueTest, DropNewThreadSafety) {\n  Queue<int> queue(5, Queue<int>::OverflowBehavior::DROP_NEW);\n  constexpr int num_threads = 3;\n  constexpr int items_per_thread = 10;\n\n  std::vector<std::thread> threads;\n  
std::vector<Queue<int>::Producer> producers;\n\n  for (int t = 0; t < num_threads; ++t) {\n    producers.push_back(queue.CreateProducer());\n  }\n\n  for (int t = 0; t < num_threads; ++t) {\n    threads.emplace_back([&producers, t]() {\n      for (int i = 0; i < items_per_thread; ++i) {\n        producers[t].Put(t * items_per_thread + i);\n      }\n    });\n  }\n\n  for (auto& thread : threads) {\n    thread.join();\n  }\n\n  // Queue should have at most capacity items\n  EXPECT_LE(queue.Size(), 5);\n  // All puts should be counted\n  EXPECT_EQ(queue.GetTotalPutCount(), num_threads * items_per_thread);\n  // Some items should have been dropped\n  EXPECT_GT(queue.GetTotalDropCount(), 0);\n  // Put count = successful puts + drops\n  EXPECT_EQ(queue.GetTotalPutCount(), queue.Size() + queue.GetTotalDropCount());\n}\n\n// Tests for KEEP_NEWEST overflow behavior\n\nTEST_F(QueueTest, KeepNewestBasicBehavior) {\n  Queue<int> queue(3, Queue<int>::OverflowBehavior::KEEP_NEWEST);\n  auto producer = queue.CreateProducer();\n\n  // Fill queue to capacity\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n  EXPECT_EQ(queue.Size(), 3);\n  EXPECT_EQ(queue.GetTotalPutCount(), 3);\n  EXPECT_EQ(queue.GetTotalDropCount(), 0);\n\n  // Additional puts should replace oldest items\n  producer.Put(4);\n  EXPECT_EQ(queue.Size(), 3);\n  EXPECT_EQ(queue.GetTotalPutCount(), 4);\n  EXPECT_EQ(queue.GetTotalDropCount(), 1);\n\n  producer.Put(5);\n  EXPECT_EQ(queue.Size(), 3);\n  EXPECT_EQ(queue.GetTotalPutCount(), 5);\n  EXPECT_EQ(queue.GetTotalDropCount(), 2);\n\n  // Verify newest items are kept (3, 4, 5)\n  EXPECT_EQ(queue.Get(), 3);\n  EXPECT_EQ(queue.Get(), 4);\n  EXPECT_EQ(queue.Get(), 5);\n}\n\nTEST_F(QueueTest, KeepNewestBatchBehavior) {\n  Queue<int> queue(3, Queue<int>::OverflowBehavior::KEEP_NEWEST);\n  auto producer = queue.CreateProducer();\n\n  // Fill queue\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n\n  // Put large batch that exceeds capacity\n  
std::vector<int> large_batch = {4, 5, 6, 7, 8};\n  producer.Put(absl::Span<const int>(large_batch));\n\n  // Queue should still have capacity items\n  EXPECT_EQ(queue.Size(), 3);\n  EXPECT_EQ(queue.GetTotalPutCount(), 8);\n  EXPECT_EQ(queue.GetTotalDropCount(), 5);  // 1, 2, 3, 4, 5 dropped\n\n  // Should have the newest 3 items (6, 7, 8)\n  EXPECT_EQ(queue.Get(), 6);\n  EXPECT_EQ(queue.Get(), 7);\n  EXPECT_EQ(queue.Get(), 8);\n}\n\nTEST_F(QueueTest, KeepNewestLargeBatch) {\n  Queue<int> queue(2, Queue<int>::OverflowBehavior::KEEP_NEWEST);\n  auto producer = queue.CreateProducer();\n\n  // Put batch larger than capacity\n  std::vector<int> large_batch = {1, 2, 3, 4, 5};\n  producer.Put(absl::Span<const int>(large_batch));\n\n  // Should keep only the last 2 items\n  EXPECT_EQ(queue.Size(), 2);\n  EXPECT_EQ(queue.GetTotalPutCount(), 5);\n  EXPECT_EQ(queue.GetTotalDropCount(), 3);\n\n  EXPECT_EQ(queue.Get(), 4);\n  EXPECT_EQ(queue.Get(), 5);\n}\n\n// Tests for counter functionality\n\nTEST_F(QueueTest, GetTotalGetCountBasic) {\n  Queue<int> queue(5);\n  auto producer = queue.CreateProducer();\n\n  EXPECT_EQ(queue.GetTotalGetCount(), 0);\n\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n  EXPECT_EQ(queue.GetTotalGetCount(), 0);\n\n  queue.Get();\n  EXPECT_EQ(queue.GetTotalGetCount(), 1);\n\n  queue.Get();\n  queue.Get();\n  EXPECT_EQ(queue.GetTotalGetCount(), 3);\n}\n\nTEST_F(QueueTest, GetTotalGetCountBatch) {\n  Queue<int> queue(5);\n  auto producer = queue.CreateProducer();\n\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n  producer.Put(4);\n  producer.Put(5);\n\n  queue.Get(3);\n  EXPECT_EQ(queue.GetTotalGetCount(), 3);\n\n  queue.Get();\n  EXPECT_EQ(queue.GetTotalGetCount(), 4);\n\n  queue.Get(1);\n  EXPECT_EQ(queue.GetTotalGetCount(), 5);\n}\n\nTEST_F(QueueTest, GetTotalGetCountReset) {\n  Queue<int> queue(3);\n  auto producer = queue.CreateProducer();\n\n  producer.Put(1);\n  producer.Put(2);\n\n  queue.Get();\n  queue.Get();\n  
EXPECT_EQ(queue.GetTotalGetCount(), 2);\n\n  EXPECT_EQ(queue.GetTotalGetCount(true), 2);\n  EXPECT_EQ(queue.GetTotalGetCount(), 0);\n}\n\nTEST_F(QueueTest, GetTotalDropCountBasic) {\n  Queue<int> queue(2, Queue<int>::OverflowBehavior::DROP_NEW);\n  auto producer = queue.CreateProducer();\n\n  EXPECT_EQ(queue.GetTotalDropCount(), 0);\n\n  producer.Put(1);\n  producer.Put(2);\n  EXPECT_EQ(queue.GetTotalDropCount(), 0);\n\n  producer.Put(3);  // Should be dropped\n  EXPECT_EQ(queue.GetTotalDropCount(), 1);\n\n  producer.Put(4);  // Should be dropped\n  EXPECT_EQ(queue.GetTotalDropCount(), 2);\n}\n\nTEST_F(QueueTest, GetTotalDropCountKeepNewest) {\n  Queue<int> queue(2, Queue<int>::OverflowBehavior::KEEP_NEWEST);\n  auto producer = queue.CreateProducer();\n\n  producer.Put(1);\n  producer.Put(2);\n  EXPECT_EQ(queue.GetTotalDropCount(), 0);\n\n  producer.Put(3);  // Should drop 1\n  EXPECT_EQ(queue.GetTotalDropCount(), 1);\n\n  producer.Put(4);  // Should drop 2\n  EXPECT_EQ(queue.GetTotalDropCount(), 2);\n}\n\nTEST_F(QueueTest, GetTotalDropCountReset) {\n  Queue<int> queue(1, Queue<int>::OverflowBehavior::DROP_NEW);\n  auto producer = queue.CreateProducer();\n\n  producer.Put(1);\n  producer.Put(2);  // Dropped\n  producer.Put(3);  // Dropped\n  EXPECT_EQ(queue.GetTotalDropCount(), 2);\n\n  EXPECT_EQ(queue.GetTotalDropCount(true), 2);\n  EXPECT_EQ(queue.GetTotalDropCount(), 0);\n}\n\n// MaybeGet() tests\n\nTEST_F(QueueTest, MaybeGetOnEmptyQueue) {\n  Queue<int> queue(5);\n  auto result = queue.MaybeGet();\n  EXPECT_FALSE(result.has_value());\n}\n\nTEST_F(QueueTest, MaybeGetOnNonEmptyQueue) {\n  Queue<int> queue(5);\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(42);\n  }\n\n  auto result = queue.MaybeGet();\n  ASSERT_TRUE(result.has_value());\n  EXPECT_EQ(*result, 42);\n  EXPECT_EQ(queue.Size(), 0);\n}\n\nTEST_F(QueueTest, MaybeGetMultipleValues) {\n  Queue<int> queue(5);\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(1);\n    
producer.Put(2);\n    producer.Put(3);\n  }\n\n  auto result1 = queue.MaybeGet();\n  ASSERT_TRUE(result1.has_value());\n  EXPECT_EQ(*result1, 1);\n\n  auto result2 = queue.MaybeGet();\n  ASSERT_TRUE(result2.has_value());\n  EXPECT_EQ(*result2, 2);\n\n  auto result3 = queue.MaybeGet();\n  ASSERT_TRUE(result3.has_value());\n  EXPECT_EQ(*result3, 3);\n\n  auto result4 = queue.MaybeGet();\n  EXPECT_FALSE(result4.has_value());\n}\n\nTEST_F(QueueTest, MaybeGetWithMoveOnlyType) {\n  Queue<std::unique_ptr<int>> queue(5);\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(std::make_unique<int>(42));\n  }\n\n  auto result = queue.MaybeGet();\n  ASSERT_TRUE(result.has_value());\n  ASSERT_NE(*result, nullptr);\n  EXPECT_EQ(**result, 42);\n}\n\nTEST_F(QueueTest, MaybeGetUpdatesGetCount) {\n  Queue<int> queue(5);\n  {\n    auto producer = queue.CreateProducer();\n    producer.Put(1);\n    producer.Put(2);\n  }\n\n  EXPECT_EQ(queue.GetTotalGetCount(), 0);\n\n  queue.MaybeGet();\n  EXPECT_EQ(queue.GetTotalGetCount(), 1);\n\n  queue.MaybeGet();\n  EXPECT_EQ(queue.GetTotalGetCount(), 2);\n\n  queue.MaybeGet();                        // Empty queue.\n  EXPECT_EQ(queue.GetTotalGetCount(), 2);  // Count not incremented.\n}\n\n// Tests for stop_token cancellation\n\nTEST_F(QueueTest, StopTokenCancelsPut) {\n  Queue<int> queue(2);\n  auto producer = queue.CreateProducer();\n\n  // Fill the queue\n  producer.Put(1);\n  producer.Put(2);\n\n  std::stop_source stop_source;\n  stop_source.request_stop();\n\n  // Should immediately throw without blocking since token is already stopped\n  EXPECT_THROW(producer.Put(3, stop_source.get_token()), QueueRequestCancelled);\n  EXPECT_EQ(queue.Size(), 2);\n}\n\nTEST_F(QueueTest, StopTokenCancelsGet) {\n  Queue<int> queue(5);\n  auto producer = queue.CreateProducer();\n\n  std::stop_source stop_source;\n  stop_source.request_stop();\n\n  // Should immediately throw without blocking since token is already stopped\n  
EXPECT_THROW(queue.Get(stop_source.get_token()), QueueRequestCancelled);\n  EXPECT_EQ(queue.Size(), 0);\n}\n\nTEST_F(QueueTest, StopTokenCancelsBatchPut) {\n  Queue<int> queue(2);\n  auto producer = queue.CreateProducer();\n\n  // Fill queue completely\n  producer.Put(1);\n  producer.Put(2);\n\n  std::stop_source stop_source;\n  stop_source.request_stop();\n\n  // Should immediately throw without blocking since token is already stopped\n  std::vector<int> items = {3, 4, 5};\n  EXPECT_THROW(\n      producer.Put(absl::Span<const int>(items), stop_source.get_token()),\n      QueueRequestCancelled);\n}\n\nTEST_F(QueueTest, StopTokenCancelsBatchGet) {\n  Queue<int> queue(5);\n  auto producer = queue.CreateProducer();\n\n  std::stop_source stop_source;\n  stop_source.request_stop();\n\n  // Should immediately throw without blocking since token is already stopped\n  EXPECT_THROW(queue.Get(10, stop_source.get_token()), QueueRequestCancelled);\n}\n\nTEST_F(QueueTest, StopTokenCancelsWaitForRoomAtLeast) {\n  Queue<int> queue(3);\n  auto producer = queue.CreateProducer();\n\n  // Fill queue\n  producer.Put(1);\n  producer.Put(2);\n  producer.Put(3);\n\n  std::stop_source stop_source;\n  stop_source.request_stop();\n\n  // Should immediately throw without blocking since token is already stopped\n  EXPECT_THROW(queue.WaitForRoomAtLeast(2, stop_source.get_token()),\n               QueueRequestCancelled);\n}\n\nTEST_F(QueueTest, StopTokenCancelsWaitForSizeAtLeast) {\n  Queue<int> queue(5);\n  auto producer = queue.CreateProducer();\n\n  std::stop_source stop_source;\n  stop_source.request_stop();\n\n  // Should immediately throw without blocking since token is already stopped\n  EXPECT_THROW(queue.WaitForSizeAtLeast(3, stop_source.get_token()),\n               QueueRequestCancelled);\n}\n\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/stream_shuffler.cc",
    "content": "#include \"utils/stream_shuffler.h\"\n\nnamespace lczero {\nnamespace training {\n\nvoid StreamShuffler::SetUpperBound(size_t upper_bound) {\n  assert(upper_bound >= upper_bound_);\n  stream_size_ += upper_bound - upper_bound_;\n  while (upper_bound_ < upper_bound) {\n    if (buckets_.empty() || buckets_.back().GetRemainingCapacity() == 0) {\n      buckets_.emplace_back(upper_bound_, bucket_size_);\n    }\n    upper_bound_ = std::min(\n        upper_bound, upper_bound_ + buckets_.back().GetRemainingCapacity());\n    buckets_.back().Extend(upper_bound_);\n  }\n}\n\nvoid StreamShuffler::SetLowerBound(size_t lower_bound) {\n  assert(lower_bound >= lower_bound_);\n  lower_bound_ = lower_bound;\n  if (lower_bound >= upper_bound_) {\n    upper_bound_ = lower_bound;\n    stream_size_ = 0;\n    buckets_.clear();\n    return;\n  }\n  while (!buckets_.empty() && buckets_.front().upper_bound() <= lower_bound_) {\n    stream_size_ -= buckets_.front().size();\n    buckets_.pop_front();\n  }\n  if (!buckets_.empty()) {\n    auto old_size = buckets_.front().size();\n    buckets_.front().DeclareLowerBound(lower_bound_);\n    stream_size_ -= old_size - buckets_.front().size();\n  }\n}\n\nstd::optional<size_t> StreamShuffler::GetNextItem() {\n  auto try_fetch = [&]() -> size_t {\n    size_t item_idx = absl::Uniform(gen_, size_t{0}, stream_size_);\n    --stream_size_;\n    for (auto& bucket : buckets_) {\n      if (item_idx < bucket.size()) return bucket.Fetch(item_idx);\n      item_idx -= bucket.size();\n    }\n    throw std::logic_error(\"StreamShuffler: item index out of bounds\");\n  };\n\n  while (stream_size_ > 0) {\n    if (auto item = try_fetch(); item >= lower_bound_) return item;\n  }\n\n  return std::nullopt;\n}\n\nvoid StreamShuffler::Reset(size_t lower_bound, size_t upper_bound) {\n  // Reset all internal state\n  buckets_.clear();\n  stream_size_ = 0;\n  upper_bound_ = lower_bound;\n  lower_bound_ = lower_bound;\n\n  // Establish the bounds, which will 
build the buckets with fresh data.\n  if (upper_bound > lower_bound) {\n    SetUpperBound(upper_bound);\n  }\n}\n\nStreamShuffler::Bucket::Bucket(size_t lower_bound, size_t capacity)\n    : upper_bound_(lower_bound), items_(capacity) {}\n\nsize_t StreamShuffler::Bucket::GetRemainingCapacity() const {\n  return items_.size() - items_count_;\n}\n\nvoid StreamShuffler::Bucket::Extend(size_t new_upper_bound) {\n  assert(new_upper_bound >= upper_bound_);\n  const size_t increase = new_upper_bound - upper_bound_;\n  assert(increase <= GetRemainingCapacity());\n  std::iota(items_.begin() + items_count_,\n            items_.begin() + items_count_ + increase, upper_bound_);\n  items_count_ += increase;\n  upper_bound_ = new_upper_bound;\n}\n\nsize_t StreamShuffler::Bucket::Fetch(size_t item_idx) {\n  assert(item_idx < items_count_);\n  size_t item = items_[item_idx];\n  std::swap(items_[item_idx], items_[--items_count_]);\n  return item;\n}\n\nvoid StreamShuffler::Bucket::DeclareLowerBound(size_t new_lower_bound) {\n  // Prune only when the bucket holds far more items than the allowed range\n  // can contain, so that many items are guaranteed to be out of range and it\n  // makes sense to sort and remove them.\n  if (items_count_ < 2 * (upper_bound_ - new_lower_bound)) return;\n\n  std::sort(items_.begin(), items_.begin() + items_count_,\n            std::greater<size_t>());\n  // Find the first item that is under the new lower bound.\n  auto it = std::upper_bound(items_.begin(), items_.begin() + items_count_,\n                             new_lower_bound, std::greater<size_t>());\n  items_count_ = it - items_.begin();\n}\n\n}  // namespace training\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/stream_shuffler.h",
    "content": "#pragma once\n\n#include <absl/container/fixed_array.h>\n#include <absl/random/random.h>\n\n#include <cstddef>\n#include <deque>\n#include <numeric>\n#include <optional>\n\nnamespace lczero {\nnamespace training {\n\n// Returns a number between [lower_bound, upper_bound) in shuffled order.\n// Both bounds can be changed at any time, and the stream will adapt\n// accordingly. Not thread-safe.\nclass StreamShuffler {\n public:\n  // Sets the upper bound (exclusive). Can only be increased.\n  void SetUpperBound(size_t upper_bound);\n\n  // Sets the lower bound (inclusive). Can only be increased.\n  void SetLowerBound(size_t lower_bound);\n\n  // Sets the bucket size for internal storage optimization.\n  void SetBucketSize(size_t bucket_size) { bucket_size_ = bucket_size; }\n\n  // Returns the next item in shuffled order, or nullopt if exhausted.\n  std::optional<size_t> GetNextItem();\n\n  // Resets the shuffler to restart iteration with specified bounds.\n  void Reset(size_t lower_bound, size_t upper_bound);\n\n private:\n  class Bucket {\n   public:\n    Bucket(size_t lower_bound, size_t capacity);\n    size_t GetRemainingCapacity() const;\n    void Extend(size_t new_upper_bound);\n    size_t Fetch(size_t item_idx);\n    void DeclareLowerBound(size_t new_lower_bound);\n\n    size_t upper_bound() const { return upper_bound_; }\n    size_t size() const { return items_count_; }\n\n   private:\n    size_t upper_bound_ = 0;\n    size_t items_count_ = 0;\n    absl::FixedArray<size_t> items_;\n  };\n\n  absl::BitGen gen_;\n  std::deque<Bucket> buckets_;\n  size_t stream_size_ = 0;\n  size_t upper_bound_ = 0;\n  size_t lower_bound_ = 0;\n  size_t bucket_size_ = 524288;\n};\n\n}  // namespace training\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/stream_shuffler_test.cc",
    "content": "#include \"utils/stream_shuffler.h\"\n\n#include <absl/container/flat_hash_set.h>\n#include <absl/random/random.h>\n#include <gtest/gtest.h>\n\n#include <set>\n#include <vector>\n\nnamespace lczero {\nnamespace training {\n\nclass StreamShufflerTest : public ::testing::Test {\n protected:\n  void SetUp() override { shuffler_.SetBucketSize(4); }\n\n  StreamShuffler shuffler_;\n};\n\nTEST_F(StreamShufflerTest, EmptyRangeReturnsNullopt) {\n  shuffler_.SetUpperBound(10);\n  shuffler_.SetLowerBound(10);\n  EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);\n}\n\nTEST_F(StreamShufflerTest, SingleItemRange) {\n  shuffler_.SetUpperBound(1);\n  shuffler_.SetLowerBound(0);\n\n  auto item = shuffler_.GetNextItem();\n  ASSERT_TRUE(item.has_value());\n  EXPECT_EQ(item.value(), 0);\n\n  EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);\n}\n\nTEST_F(StreamShufflerTest, BasicRangeGeneration) {\n  shuffler_.SetUpperBound(5);\n  shuffler_.SetLowerBound(0);\n\n  std::set<size_t> received;\n  for (int i = 0; i < 5; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n    EXPECT_GE(item.value(), 0);\n    EXPECT_LT(item.value(), 5);\n    EXPECT_TRUE(received.insert(item.value()).second);\n  }\n\n  EXPECT_EQ(received.size(), 5);\n  EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);\n}\n\nTEST_F(StreamShufflerTest, HeadAdvancesByBucketMultiples) {\n  shuffler_.SetUpperBound(4);\n  shuffler_.SetLowerBound(0);\n\n  std::set<size_t> received;\n  for (int i = 0; i < 4; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n    received.insert(item.value());\n  }\n  EXPECT_EQ(received.size(), 4);\n\n  shuffler_.SetUpperBound(8);\n  for (int i = 0; i < 4; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n    EXPECT_GE(item.value(), 0);\n    EXPECT_LT(item.value(), 8);\n    EXPECT_TRUE(received.insert(item.value()).second);\n  }\n  EXPECT_EQ(received.size(), 8);\n  
EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);\n}\n\nTEST_F(StreamShufflerTest, HeadAdvancesByNonMultiples) {\n  shuffler_.SetUpperBound(3);\n  shuffler_.SetLowerBound(0);\n\n  std::set<size_t> received;\n  for (int i = 0; i < 3; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n    received.insert(item.value());\n  }\n\n  shuffler_.SetUpperBound(7);\n  for (int i = 0; i < 4; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n    EXPECT_GE(item.value(), 0);\n    EXPECT_LT(item.value(), 7);\n    EXPECT_TRUE(received.insert(item.value()).second);\n  }\n  EXPECT_EQ(received.size(), 7);\n  EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);\n}\n\nTEST_F(StreamShufflerTest, TailAdvancesByBucketMultiples) {\n  shuffler_.SetUpperBound(12);\n  shuffler_.SetLowerBound(0);\n\n  std::set<size_t> all_received;\n  for (int i = 0; i < 4; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n    EXPECT_GE(item.value(), 0);\n    EXPECT_LT(item.value(), 12);\n    EXPECT_TRUE(all_received.insert(item.value()).second);\n  }\n\n  shuffler_.SetLowerBound(4);\n  std::optional<size_t> item;\n  while ((item = shuffler_.GetNextItem()).has_value()) {\n    EXPECT_GE(item.value(), 4);\n    EXPECT_LT(item.value(), 12);\n    EXPECT_TRUE(all_received.insert(item.value()).second);\n  }\n\n  // Verify all items in range [4, 12) were eventually fetched\n  for (size_t i = 4; i < 12; ++i) {\n    EXPECT_TRUE(all_received.count(i) > 0)\n        << \"Item \" << i << \" was never fetched\";\n  }\n}\n\nTEST_F(StreamShufflerTest, TailAdvancesByNonMultiples) {\n  shuffler_.SetUpperBound(10);\n  shuffler_.SetLowerBound(0);\n\n  std::set<size_t> all_received;\n  for (int i = 0; i < 3; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n    EXPECT_GE(item.value(), 0);\n    EXPECT_LT(item.value(), 10);\n    EXPECT_TRUE(all_received.insert(item.value()).second);\n  }\n\n  
shuffler_.SetLowerBound(3);\n  std::optional<size_t> item;\n  while ((item = shuffler_.GetNextItem()).has_value()) {\n    EXPECT_GE(item.value(), 3);\n    EXPECT_LT(item.value(), 10);\n    EXPECT_TRUE(all_received.insert(item.value()).second);\n  }\n\n  // Verify all items in range [3, 10) were eventually fetched\n  for (size_t i = 3; i < 10; ++i) {\n    EXPECT_TRUE(all_received.count(i) > 0)\n        << \"Item \" << i << \" was never fetched\";\n  }\n}\n\nTEST_F(StreamShufflerTest, BothBoundsSlideSimultaneously) {\n  shuffler_.SetUpperBound(10);\n  shuffler_.SetLowerBound(0);\n\n  std::set<size_t> all_received;\n  for (int i = 0; i < 5; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n    EXPECT_GE(item.value(), 0);\n    EXPECT_LT(item.value(), 10);\n    EXPECT_TRUE(all_received.insert(item.value()).second);\n  }\n\n  shuffler_.SetUpperBound(15);\n  shuffler_.SetLowerBound(5);\n\n  std::optional<size_t> item;\n  while ((item = shuffler_.GetNextItem()).has_value()) {\n    EXPECT_GE(item.value(), 5);\n    EXPECT_LT(item.value(), 15);\n    EXPECT_TRUE(all_received.insert(item.value()).second);\n  }\n\n  // Verify all items in range [5, 15) were eventually fetched\n  for (size_t i = 5; i < 15; ++i) {\n    EXPECT_TRUE(all_received.count(i) > 0)\n        << \"Item \" << i << \" was never fetched\";\n  }\n}\n\nTEST_F(StreamShufflerTest, ComplexSlidingWindow) {\n  std::set<size_t> all_received;\n\n  shuffler_.SetUpperBound(6);\n  shuffler_.SetLowerBound(0);\n\n  for (int i = 0; i < 3; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n    all_received.insert(item.value());\n  }\n\n  shuffler_.SetUpperBound(11);\n  for (int i = 0; i < 2; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n    all_received.insert(item.value());\n  }\n\n  shuffler_.SetLowerBound(2);\n  shuffler_.SetUpperBound(14);\n\n  std::set<size_t> final_received;\n  std::optional<size_t> item;\n  
while ((item = shuffler_.GetNextItem()).has_value()) {\n    EXPECT_GE(item.value(), 2);\n    EXPECT_LT(item.value(), 14);\n    EXPECT_TRUE(final_received.insert(item.value()).second);\n  }\n\n  for (const auto& val : final_received) {\n    EXPECT_GE(val, 2);\n    EXPECT_LT(val, 14);\n  }\n}\n\nTEST_F(StreamShufflerTest, UniquenessAcrossMultipleBuckets) {\n  shuffler_.SetUpperBound(20);\n  shuffler_.SetLowerBound(0);\n\n  std::set<size_t> received;\n  std::optional<size_t> item;\n  while ((item = shuffler_.GetNextItem()).has_value()) {\n    EXPECT_GE(item.value(), 0);\n    EXPECT_LT(item.value(), 20);\n    EXPECT_TRUE(received.insert(item.value()).second);\n  }\n\n  EXPECT_EQ(received.size(), 20);\n}\n\nTEST_F(StreamShufflerTest, TailCatchesUpToHead) {\n  shuffler_.SetUpperBound(8);\n  shuffler_.SetLowerBound(0);\n\n  for (int i = 0; i < 3; ++i) {\n    auto item = shuffler_.GetNextItem();\n    ASSERT_TRUE(item.has_value());\n  }\n\n  shuffler_.SetLowerBound(8);\n  EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);\n}\n\nTEST_F(StreamShufflerTest, ResetAllowsIterationRestart) {\n  shuffler_.SetUpperBound(5);\n  shuffler_.SetLowerBound(0);\n\n  // Exhaust all items\n  absl::flat_hash_set<size_t> first_round;\n  std::optional<size_t> item;\n  while ((item = shuffler_.GetNextItem()).has_value()) {\n    first_round.insert(item.value());\n  }\n\n  // Should have gotten all 5 items\n  EXPECT_EQ(first_round.size(), 5);\n\n  // Shuffler should be exhausted\n  EXPECT_EQ(shuffler_.GetNextItem(), std::nullopt);\n\n  // Reset the shuffler\n  shuffler_.Reset(0, 5);\n\n  // Should be able to get items again\n  absl::flat_hash_set<size_t> second_round;\n  int count = 0;\n  while ((item = shuffler_.GetNextItem()).has_value() && count < 10) {\n    second_round.insert(item.value());\n    count++;\n  }\n\n  // Should get all items again\n  EXPECT_EQ(second_round.size(), 5);\n\n  // Both rounds should contain the same set of items\n  EXPECT_EQ(first_round, second_round);\n}\n\n}  // 
namespace training\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/tensor.h",
    "content": "#pragma once\n\n#include <stdexcept>\n#include <string>\n#include <type_traits>\n#include <vector>\n\n#include \"absl/algorithm/container.h\"\n#include \"absl/container/fixed_array.h\"\n#include \"absl/types/span.h\"\n\nnamespace lczero {\n\n// Class that holds tensor which will be exposed through pybind11.\nclass TensorBase {\n public:\n  virtual ~TensorBase() = default;\n  virtual void* data() = 0;\n  virtual const void* data() const = 0;\n  virtual const std::vector<ssize_t>& shape() const = 0;\n  virtual const std::vector<ssize_t>& strides() const = 0;\n  virtual size_t element_size() const = 0;\n  virtual std::string py_format() const = 0;\n};\n\ntemplate <typename T>\nclass TypedTensor : public TensorBase {\n public:\n  TypedTensor(std::initializer_list<size_t> shape)\n      : data_(CalculateTotalSize(shape)), shape_(shape.begin(), shape.end()) {\n    // Calculate strides in row-major order (in bytes).\n    strides_.resize(shape_.size());\n    size_t total_size = 1;\n    for (int i = static_cast<int>(shape_.size()) - 1; i >= 0; --i) {\n      strides_[i] = total_size * sizeof(T);\n      total_size *= shape_[i];\n    }\n  }\n\n  void* data() override { return data_.data(); }\n\n  const void* data() const override { return data_.data(); }\n\n  const std::vector<ssize_t>& shape() const override { return shape_; }\n\n  const std::vector<ssize_t>& strides() const override { return strides_; }\n\n  size_t element_size() const override { return sizeof(T); }\n\n  std::string py_format() const override {\n    if constexpr (std::is_same_v<T, float>) {\n      return \"f\";\n    } else if constexpr (std::is_same_v<T, double>) {\n      return \"d\";\n    } else if constexpr (std::is_same_v<T, int32_t>) {\n      return \"i\";\n    } else if constexpr (std::is_same_v<T, int64_t>) {\n      return \"q\";\n    } else {\n      static_assert(std::is_same_v<T, void>, \"Unsupported tensor type\");\n    }\n  }\n\n  T& operator[](absl::Span<const ssize_t> dims) {\n    
if (dims.size() != shape_.size()) {\n      throw std::invalid_argument(\n          \"Number of dimensions must match tensor rank\");\n    }\n    return data_[CalculateOffset(dims)];\n  }\n\n  const T& operator[](absl::Span<const ssize_t> dims) const {\n    if (dims.size() != shape_.size()) {\n      throw std::invalid_argument(\n          \"Number of dimensions must match tensor rank\");\n    }\n    return data_[CalculateOffset(dims)];\n  }\n\n  absl::Span<T> slice(absl::Span<const ssize_t> dims) {\n    if (dims.size() > shape_.size()) {\n      throw std::invalid_argument(\n          \"Number of dimensions cannot exceed tensor rank\");\n    }\n    return absl::Span<T>(data_.data() + CalculateOffset(dims),\n                         CalculateSliceSize(dims.size()));\n  }\n\n  absl::Span<const T> slice(absl::Span<const ssize_t> dims) const {\n    if (dims.size() > shape_.size()) {\n      throw std::invalid_argument(\n          \"Number of dimensions cannot exceed tensor rank\");\n    }\n    return absl::Span<const T>(data_.data() + CalculateOffset(dims),\n                               CalculateSliceSize(dims.size()));\n  }\n\n private:\n  size_t CalculateOffset(absl::Span<const ssize_t> dims) const {\n    return absl::c_inner_product(\n        dims, strides_, size_t{0}, std::plus<>{},\n        [](ssize_t dim, ssize_t stride) { return dim * stride / sizeof(T); });\n  }\n\n  size_t CalculateSliceSize(size_t dims_size) const {\n    return absl::c_accumulate(absl::MakeConstSpan(shape_).subspan(dims_size),\n                              size_t{1}, std::multiplies<size_t>{});\n  }\n\n  static size_t CalculateTotalSize(std::initializer_list<size_t> shape) {\n    return absl::c_accumulate(shape, size_t{1}, std::multiplies<size_t>{});\n  }\n\n  absl::FixedArray<T> data_;\n  std::vector<ssize_t> shape_;\n  std::vector<ssize_t> strides_;\n};\n\nusing TensorTuple = std::vector<std::unique_ptr<TensorBase>>;\n\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/tensor_test.cc",
    "content": "// ABOUTME: Unit tests for tensor classes and their data access methods.\n// ABOUTME: Tests construction, element access, slicing, and error conditions.\n\n#include \"utils/tensor.h\"\n\n#include <gtest/gtest.h>\n\nnamespace lczero {\nnamespace {\n\nTEST(TypedTensorTest, ConstructorAndBasicProperties) {\n  TypedTensor<float> tensor({2, 3, 4});\n\n  // Check shape\n  EXPECT_EQ(tensor.shape().size(), 3);\n  EXPECT_EQ(tensor.shape()[0], 2);\n  EXPECT_EQ(tensor.shape()[1], 3);\n  EXPECT_EQ(tensor.shape()[2], 4);\n\n  // Check strides (in bytes, row-major order)\n  EXPECT_EQ(tensor.strides().size(), 3);\n  EXPECT_EQ(tensor.strides()[0], 12 * sizeof(float));  // 3 * 4 elements\n  EXPECT_EQ(tensor.strides()[1], 4 * sizeof(float));   // 4 elements\n  EXPECT_EQ(tensor.strides()[2], 1 * sizeof(float));   // 1 element\n\n  // Check element size\n  EXPECT_EQ(tensor.element_size(), sizeof(float));\n\n  // Check py_format\n  EXPECT_EQ(tensor.py_format(), \"f\");\n\n  // Check data pointer is valid\n  EXPECT_NE(tensor.data(), nullptr);\n}\n\nTEST(TypedTensorTest, PyFormatForDifferentTypes) {\n  TypedTensor<float> float_tensor({2});\n  EXPECT_EQ(float_tensor.py_format(), \"f\");\n\n  TypedTensor<double> double_tensor({2});\n  EXPECT_EQ(double_tensor.py_format(), \"d\");\n\n  TypedTensor<int32_t> int32_tensor({2});\n  EXPECT_EQ(int32_tensor.py_format(), \"i\");\n\n  TypedTensor<int64_t> int64_tensor({2});\n  EXPECT_EQ(int64_tensor.py_format(), \"q\");\n}\n\nTEST(TypedTensorTest, ElementAccess) {\n  TypedTensor<int> tensor({2, 3});\n\n  // Set some values\n  tensor[{0, 0}] = 10;\n  tensor[{0, 1}] = 11;\n  tensor[{0, 2}] = 12;\n  tensor[{1, 0}] = 20;\n  tensor[{1, 1}] = 21;\n  tensor[{1, 2}] = 22;\n\n  // Check values\n  EXPECT_EQ((tensor[{0, 0}]), 10);\n  EXPECT_EQ((tensor[{0, 1}]), 11);\n  EXPECT_EQ((tensor[{0, 2}]), 12);\n  EXPECT_EQ((tensor[{1, 0}]), 20);\n  EXPECT_EQ((tensor[{1, 1}]), 21);\n  EXPECT_EQ((tensor[{1, 2}]), 22);\n}\n\nTEST(TypedTensorTest, 
ConstElementAccess) {\n  TypedTensor<int> tensor({2, 2});\n  tensor[{0, 0}] = 1;\n  tensor[{0, 1}] = 2;\n  tensor[{1, 0}] = 3;\n  tensor[{1, 1}] = 4;\n\n  const auto& const_tensor = tensor;\n  EXPECT_EQ((const_tensor[{0, 0}]), 1);\n  EXPECT_EQ((const_tensor[{0, 1}]), 2);\n  EXPECT_EQ((const_tensor[{1, 0}]), 3);\n  EXPECT_EQ((const_tensor[{1, 1}]), 4);\n}\n\nTEST(TypedTensorTest, SliceAccess) {\n  TypedTensor<int> tensor({2, 3, 4});\n\n  // Fill with test data\n  for (int i = 0; i < 2; ++i) {\n    for (int j = 0; j < 3; ++j) {\n      for (int k = 0; k < 4; ++k) {\n        tensor[{i, j, k}] = i * 100 + j * 10 + k;\n      }\n    }\n  }\n\n  // Test 1D slice (fix first dimension)\n  auto slice1d = tensor.slice({1});\n  EXPECT_EQ(slice1d.size(), 12);  // 3 * 4 elements\n  EXPECT_EQ(slice1d[0], 100);     // tensor[{1, 0, 0}]\n  EXPECT_EQ(slice1d[4], 110);     // tensor[{1, 1, 0}]\n\n  // Test 2D slice (fix first two dimensions)\n  auto slice2d = tensor.slice({0, 1});\n  EXPECT_EQ(slice2d.size(), 4);  // 4 elements\n  EXPECT_EQ(slice2d[0], 10);     // tensor[{0, 1, 0}]\n  EXPECT_EQ(slice2d[1], 11);     // tensor[{0, 1, 1}]\n  EXPECT_EQ(slice2d[2], 12);     // tensor[{0, 1, 2}]\n  EXPECT_EQ(slice2d[3], 13);     // tensor[{0, 1, 3}]\n\n  // Test full tensor slice (no dimensions fixed)\n  auto full_slice = tensor.slice({});\n  EXPECT_EQ(full_slice.size(), 24);  // 2 * 3 * 4 elements\n}\n\nTEST(TypedTensorTest, ConstSliceAccess) {\n  TypedTensor<int> tensor({2, 2});\n  tensor[{0, 0}] = 1;\n  tensor[{0, 1}] = 2;\n  tensor[{1, 0}] = 3;\n  tensor[{1, 1}] = 4;\n\n  const auto& const_tensor = tensor;\n  auto slice = const_tensor.slice({0});\n  EXPECT_EQ(slice.size(), 2);\n  EXPECT_EQ(slice[0], 1);\n  EXPECT_EQ(slice[1], 2);\n}\n\nTEST(TypedTensorTest, ElementAccessWrongDimensions) {\n  TypedTensor<int> tensor({2, 3});\n\n  EXPECT_THROW((tensor[{0}]), std::invalid_argument);\n  EXPECT_THROW((tensor[{0, 1, 2}]), std::invalid_argument);\n}\n\nTEST(TypedTensorTest, 
SliceAccessTooManyDimensions) {\n  TypedTensor<int> tensor({2, 3});\n\n  EXPECT_THROW((tensor.slice({0, 1, 2})), std::invalid_argument);\n}\n\nTEST(TypedTensorTest, OneDimensionalTensor) {\n  TypedTensor<float> tensor({5});\n\n  EXPECT_EQ(tensor.shape().size(), 1);\n  EXPECT_EQ(tensor.shape()[0], 5);\n  EXPECT_EQ(tensor.strides()[0], sizeof(float));\n\n  tensor[{0}] = 1.0f;\n  tensor[{4}] = 5.0f;\n\n  EXPECT_EQ((tensor[{0}]), 1.0f);\n  EXPECT_EQ((tensor[{4}]), 5.0f);\n\n  auto slice = tensor.slice({});\n  EXPECT_EQ(slice.size(), 5);\n}\n\n}  // namespace\n}  // namespace lczero"
  },
  {
    "path": "csrc/utils/thread_pool.h",
    "content": "#pragma once\n\n#include <algorithm>\n#include <cstddef>\n#include <deque>\n#include <exception>\n#include <functional>\n#include <future>\n#include <iostream>\n#include <stop_token>\n#include <thread>\n#include <vector>\n\n#include \"absl/functional/any_invocable.h\"\n#include \"absl/synchronization/mutex.h\"\n\nnamespace lczero {\n\nstruct ThreadPoolOptions {\n  // If true, starts a new thread when a task is enqueued and no threads are\n  // idle.\n  bool grow_automatically = false;\n};\n\nclass ThreadPool {\n public:\n  ThreadPool(size_t initial_threads = 0,\n             const ThreadPoolOptions& options = ThreadPoolOptions(),\n             std::stop_source stop_source = std::stop_source());\n\n  // Blocks until all tasks are completed and threads are joined.\n  ~ThreadPool();\n\n  // Returns the stop_token for this thread pool.\n  std::stop_token stop_token() const;\n\n  // Enqueues a task for execution and returns a std::future.\n  // If the provided function accepts a std::stop_token as its first argument,\n  // one will be passed to it from the thread pool's stop source.\n  template <typename F, typename... Args>\n  auto Enqueue(F&& f, Args&&... 
args) {\n    // This lambda captures the common queuing logic.\n    auto enqueue_common = [&](auto&& task_to_enqueue, auto&& future_to_return) {\n      {\n        absl::MutexLock lock(&mutex_);\n        running_tasks_ += 1;\n        while (options_.grow_automatically &&\n               running_tasks_ >= threads_.size()) {\n          StartWorkerThread();\n        }\n        pending_tasks_.emplace_back(\n            [task = std::move(task_to_enqueue)]() mutable { task(); });\n        work_available_.Signal();\n      }\n      return future_to_return;\n    };\n\n    if constexpr (std::is_invocable_v<F, std::stop_token, Args...>) {\n      using ReturnType = std::invoke_result_t<F, std::stop_token, Args...>;\n      std::packaged_task<ReturnType()> task(\n          std::bind(std::forward<F>(f), stop_source_.get_token(),\n                    std::forward<Args>(args)...));\n      std::future<ReturnType> future = task.get_future();\n      return enqueue_common(std::move(task), std::move(future));\n    } else {\n      using ReturnType = std::invoke_result_t<F, Args...>;\n      std::packaged_task<ReturnType()> task(\n          std::bind(std::forward<F>(f), std::forward<Args>(args)...));\n      std::future<ReturnType> future = task.get_future();\n      return enqueue_common(std::move(task), std::move(future));\n    }\n  }\n\n  // Waits for all tasks to complete.\n  void WaitAll();\n\n  // Waits for at least one thread to become available, i.e. 
no tasks are\n  // pending, and number of running tasks is less than the number of threads.\n  void WaitForAvailableThread();\n\n  // Waits until the number of queued but not yet started tasks is below\n  // the specified threshold.\n  void WaitForPendingTasksBelow(size_t threshold);\n\n  // Number of tasks that are not yet started.\n  size_t num_pending_tasks() const;\n\n  // Number of tasks that are currently running.\n  size_t num_running_tasks() const;\n\n  // Number of worker threads (busy or not).\n  size_t num_threads() const;\n\n  // Signal workers to terminate and join all threads.\n  void Shutdown();\n\n private:\n  void WorkerLoop();\n  void WorkerEntryPoint();\n  void StartWorkerThread() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);\n  bool AllTasksCompletedCond() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {\n    return pending_tasks_.empty() && running_tasks_ == 0;\n  }\n  bool ThreadAvailableCond() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {\n    return pending_tasks_.empty() && running_tasks_ < threads_.size();\n  }\n\n  ThreadPool(const ThreadPool&) = delete;\n  ThreadPool& operator=(const ThreadPool&) = delete;\n  ThreadPool(ThreadPool&&) = delete;\n  ThreadPool& operator=(ThreadPool&&) = delete;\n\n  ThreadPoolOptions options_;\n  mutable absl::Mutex mutex_;\n  absl::CondVar work_available_;\n  absl::CondVar work_done_;\n\n  std::stop_source stop_source_;\n  std::vector<std::jthread> threads_ ABSL_GUARDED_BY(mutex_);\n  std::deque<absl::AnyInvocable<void()>> pending_tasks_ ABSL_GUARDED_BY(mutex_);\n  size_t running_tasks_ ABSL_GUARDED_BY(mutex_) = 0;\n};\n\ninline ThreadPool::ThreadPool(size_t initial_threads,\n                              const ThreadPoolOptions& options,\n                              std::stop_source stop_source)\n    : options_(options), stop_source_(std::move(stop_source)) {\n  absl::MutexLock lock(&mutex_);\n  for (size_t i = 0; i < initial_threads; ++i) {\n    StartWorkerThread();\n  }\n}\n\ninline ThreadPool::~ThreadPool() { 
Shutdown(); }\n\ninline std::stop_token ThreadPool::stop_token() const {\n  return stop_source_.get_token();\n}\n\ninline void ThreadPool::WorkerLoop() {\n  while (true) {\n    absl::AnyInvocable<void()> task;\n    {\n      absl::MutexLock lock(&mutex_);\n      while (!stop_source_.stop_requested() && pending_tasks_.empty()) {\n        work_available_.Wait(&mutex_);\n      }\n      if (stop_source_.stop_requested() && pending_tasks_.empty()) return;\n      task = std::move(pending_tasks_.front());\n      pending_tasks_.pop_front();\n    }\n\n    std::move(task)();\n\n    {\n      absl::MutexLock lock(&mutex_);\n      running_tasks_ -= 1;\n      work_done_.SignalAll();\n      if (!pending_tasks_.empty()) work_available_.Signal();\n    }\n  }\n}\n\ninline void ThreadPool::WorkerEntryPoint() {\n  try {\n    WorkerLoop();\n  } catch (const std::exception& exception) {\n    std::cerr << \"ThreadPool worker exited due to uncaught exception: \"\n              << exception.what() << std::endl;\n    throw;\n  } catch (...) 
{\n    std::cerr << \"ThreadPool worker exited due to unknown exception.\"\n              << std::endl;\n    throw;\n  }\n}\n\ninline void ThreadPool::WaitAll() {\n  absl::MutexLock lock(&mutex_);\n  while (!AllTasksCompletedCond()) {\n    work_done_.Wait(&mutex_);\n  }\n}\n\ninline void ThreadPool::WaitForAvailableThread() {\n  absl::MutexLock lock(&mutex_);\n  while (!ThreadAvailableCond()) {\n    work_done_.Wait(&mutex_);\n  }\n}\n\ninline void ThreadPool::WaitForPendingTasksBelow(size_t threshold) {\n  absl::MutexLock lock(&mutex_);\n  while (pending_tasks_.size() >= threshold) {\n    work_done_.Wait(&mutex_);\n  }\n}\n\ninline void ThreadPool::StartWorkerThread()\n    ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {\n  threads_.emplace_back(&ThreadPool::WorkerEntryPoint, this);\n}\n\ninline size_t ThreadPool::num_pending_tasks() const {\n  absl::MutexLock lock(&mutex_);\n  return pending_tasks_.size();\n}\n\ninline size_t ThreadPool::num_running_tasks() const {\n  absl::MutexLock lock(&mutex_);\n  // running_tasks_ also counts tasks still in the queue, so subtract them.\n  return running_tasks_ - pending_tasks_.size();\n}\n\ninline size_t ThreadPool::num_threads() const {\n  absl::MutexLock lock(&mutex_);\n  return threads_.size();\n}\n\ninline void ThreadPool::Shutdown() {\n  {\n    absl::MutexLock lock(&mutex_);\n    if (!stop_source_.stop_requested()) stop_source_.request_stop();\n    work_available_.SignalAll();\n    work_done_.SignalAll();\n  }\n  threads_.clear();\n}\n\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/utils/training_data_printer.cc",
    "content": "#include \"utils/training_data_printer.h\"\n\n#include <absl/strings/str_format.h>\n\n#include <algorithm>\n#include <iostream>\n\n#include \"chess/board.h\"\n#include \"neural/decoder.h\"\n#include \"trainingdata/reader.h\"\n\nnamespace lczero {\nnamespace training {\n\nvoid PrintFloatArray(const float* data, size_t size, absl::string_view name,\n                     int64_t per_line) {\n  per_line = std::max<int64_t>(1, per_line);\n  std::cout << \"  \" << name << \":\\n\";\n  for (size_t i = 0; i < size; ++i) {\n    if (i % per_line == 0) {\n      std::cout << \"    [\" << absl::StrFormat(\"%4zu\", i) << \"]: \";\n    }\n    std::cout << absl::StrFormat(\"% .6g\", data[i]);\n    if ((i + 1) % per_line == 0 || i + 1 == size) {\n      std::cout << \"\\n\";\n    } else {\n      std::cout << \", \";\n    }\n  }\n}\n\nvoid PrintUint64Array(const uint64_t* data, size_t size, absl::string_view name,\n                      int64_t per_line) {\n  per_line = std::max<int64_t>(1, per_line);\n  std::cout << \"  \" << name << \":\\n\";\n  for (size_t i = 0; i < size; ++i) {\n    if (i % per_line == 0) {\n      std::cout << \"    [\" << absl::StrFormat(\"%3zu\", i) << \"]: \";\n    }\n    std::cout << absl::StrFormat(\"0x%016x\", data[i]);\n    if ((i + 1) % per_line == 0 || i + 1 == size) {\n      std::cout << \"\\n\";\n    } else {\n      std::cout << \", \";\n    }\n  }\n}\n\nstd::string DecodeInvarianceInfo(uint8_t invariance_info) {\n  return absl::StrFormat(\n      \"flip=%d, mirror=%d, transpose=%d, best_move_proven=%d, \"\n      \"max_length=%d, adjudicated=%d, rescorer_deleted=%d, side_to_move=%d\",\n      invariance_info & 0x1, (invariance_info >> 1) & 0x1,\n      (invariance_info >> 2) & 0x1, (invariance_info >> 3) & 0x1,\n      (invariance_info >> 4) & 0x1, (invariance_info >> 5) & 0x1,\n      (invariance_info >> 6) & 0x1, (invariance_info >> 7) & 0x1);\n}\n\nstd::string TrainingDataToFen(const FrameType& entry) {\n  InputPlanes planes = 
PlanesFromTrainingData(entry);\n  ChessBoard board;\n  int rule50 = 0;\n  int gameply = 0;\n  PopulateBoard(\n      static_cast<pblczero::NetworkFormat::InputFormat>(entry.input_format),\n      planes, &board, &rule50, &gameply);\n  std::string fen = BoardToFen(board);\n  fen += \" \" + std::to_string(rule50);\n  fen += \" \" + std::to_string((gameply / 2) + 1);\n  return fen;\n}\n\nvoid PrintTrainingDataEntry(const FrameType& entry,\n                            absl::string_view header_text,\n                            int64_t float_per_line, int64_t plane_per_line) {\n  std::cout << header_text << \"\\n\";\n  std::cout << \"  FEN: \" << TrainingDataToFen(entry) << \"\\n\";\n  std::cout << \"  version: \" << entry.version << \"\\n\";\n  std::cout << \"  input_format: \" << entry.input_format << \"\\n\";\n  std::cout << \"  castling_us_ooo: \" << static_cast<int>(entry.castling_us_ooo)\n            << \"\\n\";\n  std::cout << \"  castling_us_oo: \" << static_cast<int>(entry.castling_us_oo)\n            << \"\\n\";\n  std::cout << \"  castling_them_ooo: \"\n            << static_cast<int>(entry.castling_them_ooo) << \"\\n\";\n  std::cout << \"  castling_them_oo: \"\n            << static_cast<int>(entry.castling_them_oo) << \"\\n\";\n  std::cout << \"  side_to_move_or_enpassant: \"\n            << static_cast<int>(entry.side_to_move_or_enpassant) << \"\\n\";\n  std::cout << \"  rule50_count: \" << static_cast<int>(entry.rule50_count)\n            << \"\\n\";\n  std::cout << \"  invariance_info: \" << static_cast<int>(entry.invariance_info)\n            << \" (\" << DecodeInvarianceInfo(entry.invariance_info) << \")\\n\";\n  std::cout << \"  dummy: \" << static_cast<int>(entry.dummy) << \"\\n\";\n  std::cout << \"  root_q: \" << entry.root_q << \"\\n\";\n  std::cout << \"  best_q: \" << entry.best_q << \"\\n\";\n  std::cout << \"  root_d: \" << entry.root_d << \"\\n\";\n  std::cout << \"  best_d: \" << entry.best_d << \"\\n\";\n  std::cout << \"  root_m: \" << 
entry.root_m << \"\\n\";\n  std::cout << \"  best_m: \" << entry.best_m << \"\\n\";\n  std::cout << \"  plies_left: \" << entry.plies_left << \"\\n\";\n  std::cout << \"  result_q: \" << entry.result_q << \"\\n\";\n  std::cout << \"  result_d: \" << entry.result_d << \"\\n\";\n  std::cout << \"  played_q: \" << entry.played_q << \"\\n\";\n  std::cout << \"  played_d: \" << entry.played_d << \"\\n\";\n  std::cout << \"  played_m: \" << entry.played_m << \"\\n\";\n  std::cout << \"  orig_q: \" << entry.orig_q << \"\\n\";\n  std::cout << \"  orig_d: \" << entry.orig_d << \"\\n\";\n  std::cout << \"  orig_m: \" << entry.orig_m << \"\\n\";\n  std::cout << \"  visits: \" << entry.visits << \"\\n\";\n  std::cout << \"  played_idx: \" << entry.played_idx << \"\\n\";\n  std::cout << \"  best_idx: \" << entry.best_idx << \"\\n\";\n  std::cout << \"  policy_kld: \" << entry.policy_kld << \"\\n\";\n  PrintFloatArray(entry.probabilities, std::size(entry.probabilities),\n                  \"probabilities\", float_per_line);\n  PrintUint64Array(entry.planes, std::size(entry.planes), \"planes\",\n                   plane_per_line);\n  std::cout << std::flush;\n}\n\n}  // namespace training\n}  // namespace lczero\n"
  },
  {
    "path": "csrc/utils/training_data_printer.h",
    "content": "#ifndef LCZERO_TRAINING_UTILS_TRAINING_DATA_PRINTER_H_\n#define LCZERO_TRAINING_UTILS_TRAINING_DATA_PRINTER_H_\n\n#include <absl/strings/string_view.h>\n\n#include <cstddef>\n#include <cstdint>\n#include <string>\n\n#include \"loader/frame_type.h\"\n#include \"trainingdata/trainingdata_v6.h\"\n\nnamespace lczero {\nnamespace training {\n\n// Prints a float array with configurable number of values per line.\nvoid PrintFloatArray(const float* data, size_t size, absl::string_view name,\n                     int64_t per_line);\n\n// Prints a uint64 array with configurable number of values per line.\nvoid PrintUint64Array(const uint64_t* data, size_t size, absl::string_view name,\n                      int64_t per_line);\n\n// Decodes the invariance_info byte into a human-readable string.\nstd::string DecodeInvarianceInfo(uint8_t invariance_info);\n\n// Converts a training data frame to FEN (Forsyth-Edwards Notation).\nstd::string TrainingDataToFen(const FrameType& entry);\n\n// Prints a training data frame with a custom header and formatting options.\nvoid PrintTrainingDataEntry(const FrameType& entry,\n                            absl::string_view header_text,\n                            int64_t float_per_line, int64_t plane_per_line);\n\n}  // namespace training\n}  // namespace lczero\n\n#endif  // LCZERO_TRAINING_UTILS_TRAINING_DATA_PRINTER_H_\n"
  },
  {
    "path": "docs/README.md",
    "content": "# Running \"new\" training pipeline\n\nNote that the code is still in active development, so things change a lot.\nThis document was last updated on 2025-11-30.\n\n## Building\n\nThe new training pipeline is located in `src/` (Python part) and `csrc/` (C++\npart).\n\n* Python code uses `uv`. Install it as described in the\n  [uv installation guide](https://docs.astral.sh/uv/#installation).\n* Many steps are run via `just`. `just` is just a glorified shell script runner.\n  So either look into [`Justfile`](../justfile) or install `just` as described\n  in the [just installation guide](https://github.com/casey/just#installation).\n* You'll need a recent protobuf compiler (`protoc`).\n* You'll need a C++ compiler. In this example we use `clang`.\n\n```bash\ncd <repo-root>\nuv python install 3.12\nuv venv\nuv sync\nuv pip install meson ruff\ngit submodule update --init --recursive\nCXX=clang++ CC=clang uv run meson setup build/release \\\n   --buildtype=release --native-file=native.ini\nuv run meson configure build/release \\\n   -Dcpp_args='-Wno-error=deprecated-declarations'\njust build\ncd src/lczero_training\nln -sfT ../../build/release/_lczero_training.cpython-*-x86_64-linux-gnu.so _lczero_training.so\njust build-proto\n```\n\n## Training a model\n\nTo train a model you need to:\n\n* Obtain training data.\n* Write a configuration file.\n* Create a checkpoint.\n* Run the pipeline.\n\n### Training data\n\nUnlike the old training pipeline, the new one doesn't need `.tar` files to be\nunpacked. While it does support plain `.gz` chunk files, that's not efficient,\nas it stores each individual file name in memory. 
So instead, use `.tar` files; the tool can index and seek inside them.\n\nThe tool watches a directory (and its subdirectories) for new files.\n\nTerms used:\n\n* **Chunk**/Game: A single training game, an individual `.gz` file.\n* **Chunk source**: A file (`.tar` or `.gz`) containing multiple chunks.\n* **Frame**/Record/Position: A single training position inside a chunk.\n* **Training tensor**: A single batch of inputs/outputs encoded in NN format for\n  one training step.\n\nIncoming data comes as chunks, but for training we need frames from\ndifferent games.\n\n### Note on RL vs SL training\n\nThe tool supports both supervised learning (SL) and reinforcement learning (RL).\nHere is an overview of the configuration differences:\n\n| RL Training | SL Training |\n| ----------- | ----------- |\n| `shuffling_chunk_pool` has a relatively small `chunk_pool_size`, which is used as a sliding window. | `chunk_pool_size` should be larger than all data, so that all data is used for training. |\n| `training.schedule.chunks_per_network` is non-zero: the number of new chunks to wait for before starting a new epoch. | `chunks_per_network` is zero. Once an epoch is done, a new one starts immediately. |\n| RL currently uses \"hanse sampling\", where the entire chunk is loaded and rescored to use just one position from it. The reservoir `shuffling_frame_sampler` is not used in this case. This is currently slow (until we implement caching), so it limits the throughput. | For SL, it is better to use two-stage sampling: `shuffling_chunk_pool` in non-hanse mode, and then `shuffling_frame_sampler` to shuffle positions within chunks. |\n\n### Creating a checkpoint\n\nTo create a fresh checkpoint, you'll need the `model` and `training` sections of\nthe configuration file to be filled in.\n\nThen run:\n\n```bash\nuv run lc0-init --config <your_config>.textproto --lczero_model <model>.pb.gz\n```\n\nThe `--lczero_model` parameter is optional. If not given, the network is\ninitialized with random weights.\n\n### Training\n\nRun:\n\n```bash\nCUDA_VISIBLE_DEVICES=0 uv run lc0-tui --config <your_config>.textproto --logfile train.log\n```\n\nNotes:\n\n* There are log file settings both in the `tui` invocation and in the\n  configuration file. They are slightly different; the TUI log is usually more\n  useful.\n* In the tool, you can press `q` to quit, and `Ctrl+p` to get a command palette.\n  There, you have one useful command: \"Start training immediately\".\n* For multi-GPU training, ensure your batch size is divisible by the number of\n  GPUs.\n* The overfit utility (`uv run overfit`) does not support multi-GPU. 
Set `CUDA_VISIBLE_DEVICES` to use only one GPU when running overfit.\n* Also note that the TUI is 100% vibe coded, so you'll see lots of mocks in the\n  UI. :-P\n\n## Tools\n\nThe repository consists of a set of tools, mostly written in Python, but some\nare in C++. To run a Python tool, use `uv run <tool>`. C++ tools are binaries in\n`build/release`. Most tools need a configuration file as a parameter (see\nbelow).\n\n### Python tools\n\n| Tool | Description |\n| ---- | ----------- |\n| lc0-daemon | The main training daemon. It acts as a JSONL server, so it's not usable directly from the command line yet. |\n| lc0-tui | A terminal user interface that runs the training daemon. This is what you normally run. |\n| lc0-init | Initializes a new training run/checkpoint. |\n| lc0-migrate-checkpoint | Migrates a JAX/Orbax checkpoint after model/training configuration changes. |\n| lc0-overfit | Runs an overfitting test: takes one batch from the data loader and repeatedly trains on it. |\n| lc0-eval | Evaluates batches from the data loader on a given checkpoint; can dump inputs/outputs in various formats. |\n| lc0-leela2jax | |\n| lc0-describe | |\n| lc0-test-dataloader | |\n| lc0-tune-lr | Trains with an exponentially increasing learning rate, and outputs losses into a CSV file. Useful for picking an LR. |\n| lc0-backfill-metrics | Loads older checkpoints, computes metrics for them, and exports them to TensorBoard. |\n| lc0-train | Trains a single epoch (doesn't save or export the model though). Used for benchmarking. |\n| lc0-weights | Manipulates weight files: arithmetic operations, grafting components, format conversion. See [weights_tool.md](weights_tool.md). |\n\n### C++ tools\n\n| Tool | Description |\n| ---- | ----------- |\n| rescore_chunk | Runs the rescorer on a single chunk. |\n| startpos_policy_distribution | |\n| result_distribution | |\n| filter_chunks | |\n| dump_chunk | Dumps the content of a chunk file for debugging purposes. |\n\n## Configuration\n\nThe configuration is a text protobuf file, with the following sections:\n\n| Section | Description |\n| ------- | ----------- |\n| data_loader | Configuration for the data loader. |\n| model | Model architecture configuration. |\n| training | Configuration for the training process. |\n| metrics | Metrics to export into tensorboard. |\n| export | Configuration for exporting the trained model. |\n\nIt also has a `log_filename` field (where to write the log) and a `name` field\n(must be `little-teapot`).\n\nIt's recommended to use existing configuration files as a starting point.\n\n### Data loader configuration\n\nThe data loader is a pipeline that consists of pluggable stages. Here are the\nstages that are currently implemented:\n\n| Stage type | Description | Input | Output |\n| ---------- | ----------- | ----- | ------ |\n| `file_path_provider` | Watches a directory for existing and new files. | None | Filenames |\n| `chunk_source_reader` | Reads and indexes chunk source files (`.tar` or `.gz`). | Filenames | ChunkSources |\n| `chunk_source_splitter` | Splits chunk sources into smaller chunk sources given the proportions (used for test/train split). 
| ChunkSources | multiple ChunkSources |\n| `shuffling_chunk_pool` | Accumulates chunk sources and outputs chunks in shuffled order. | ChunkSources | Chunks |\n| `simple_chunk_extractor` | Unpacks chunks from chunk sources. | ChunkSources | Chunks |\n| `chunk_rescorer` | Rescores chunks. | Chunks | Chunks |\n| `chunk_unpacker` | Extracts positions from chunks. | Chunks | Frames |\n| `shuffling_frame_sampler` | Outputs frames in shuffled order. | Frames | Frames |\n| `tensor_generator` | Converts frames into training batches in numpy tensor format. | Frames | Training Tensors |\n\nThe pipeline ends with one or more outputs, which provide tuples of batched\ntensors for training.\n\n> [!NOTE]\n> The current format of the training batch is\n>\n> * `inputs`: float32 tensor of shape `[batch_size, 112, 8, 8]`\n> * `policy_target`: float32 tensor of shape `[batch_size, 1862]`\n> * `value_target`: float32 tensor of shape `[batch_size, 6, 3]`, where the 6\n>   rows are sources of the value (`result`, `best`, `played`, `orig`, `root`\n>   and `st`), and the 3 columns are (`q` (w-l), `draw`, `movesleft`).\n\nEvery stage must have a unique name (which may or may not be the same as the\nstage type), and an arbitrary number of inputs (depending on the stage type;\nmost have one input).\n\nHere is the structure of the data loader configuration:\n\n```textproto\nstage {\n  name: \"file_provider\"\n  file_path_provider {\n    # ...\n  }\n}\nstage {\n  name: \"loader\"\n  input: \"file_provider\"\n  chunk_source_reader {\n    # ...\n    
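# A stage may declare named outputs; downstream stages then reference\n    # them as \"<stage_name>.<output_name>\" (e.g. \"loader.myoutput\" below).\n    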
output { name: \"myoutput\" }\n  }\n}\nstage {\n  name: \"chunk_shuffler\"\n  input: \"loader.myoutput\"\n  shuffling_chunk_pool {\n    # ...\n  }\n}\n# ...\nstage {\n  name: \"tensor_gen\"\n  input: \"sampler\"\n  tensor_generator {\n    batch_size: 256\n    # ...\n  }\n}\noutput: \"tensor_gen\"  # unnamed output\noutput: \"test:test_tensor_gen\"  # named output\n```\n\n#### Stage output configuration\n\nEvery stage provides one or more outputs. The configuration of an output looks\nlike this (all fields are optional):\n\n```textproto\noutput {\n  name: \"myoutput\"\n  queue_capacity: 8  # default: 4\n  overflow_behavior: BLOCK\n}\n```\n\n* By default, outputs are not named, but you can name them.\n* A higher `queue_capacity` allows you to \"pre-cache\" data, so that when the\n  rate of the producer stage is spiky, the pipeline is not blocked. On the other\n  hand, the data in the queue may be \"stale\" (i.e. when you train a new network,\n  the data in the queue is still for the old network).\n* `overflow_behavior` controls what happens when the output queue is full:\n  * `BLOCK`: the default, and what's needed for most stages. The producer stage\n    is blocked until there is space in the queue.\n  * `DROP_NEW` and `KEEP_NEWEST` drop data from the queue (either the incoming\n    item, or the oldest item in the queue, respectively). These are useful e.g.\n    for an auxiliary output of a stage (e.g. validation), so that the auxiliary\n    pipeline doesn't block the main pipeline.\n\n### Stage configurations\n\n#### file_path_provider\n\nWatches a directory for existing and new files. First it sends all existing\nfiles, then sends a special \"Initial Scan Done\" event, and then watches for new\nfiles.\n\n#### chunk_source_reader\n\nTakes the filenames from the input, and loads them as chunk sources. 
Skips files\nwhich are not chunk sources.\n\n* `frame_format`: `V6TrainingData` (default) or `V7TrainingData`.\n\n#### shuffling_chunk_pool\n\nThe shuffling chunk pool is the central part of the data loader. In most cases,\nit is the only stage responsible for shuffling the data. In some cases, you may\nwant to have a secondary `shuffling_frame_sampler` after it (e.g. for SL\ntraining).\n\nEvery chunk source has a \"sort key\" (currently, it's the file name without the\npath). It's needed to determine the order of chunks to use for the sliding\nwindow.\n\n* `chunk_pool_size`: The size of the training window, in number of chunks. Even\n  when there are not enough chunks yet, the stage will output chunks from what\n  it has. It will not start producing data until the \"initial scan done\" event\n  is received from the `file_path_provider`.\n  * For RL training, typical values are 250k to 5M.\n  * For SL training, it should be larger than all data, so that all data is used\n    for training."
  },
  {
    "path": "docs/architecture.md",
    "content": "# Architecture Overview\n\nThis document outlines the architecture of the new Leela Chess Zero training\nsystem. The training process involves a Reinforcement Learning (RL) pipeline\nwhere new training data is continuously generated.\n\nThe old script required fresh starts to train a new network on fresher data.\nFurthermore, the use of TensorFlow involved a very long model compilation time\n(approx. 1 hour), which dominated the actual training time for a single epoch.\n\nThe new script will be a single, long-running Python application that automates\nthe entire cycle:\n\n1. Monitors for a sufficient amount of new training data.\n2. Triggers and executes the training of a new network.\n3. Exports the trained network for use.\n\nThe core training loop will be implemented in JAX. The data\n[loading and preprocessing pipeline](loader.md) is C++ code exposed to Python\nvia pybind11. This C++ library is internally multi-threaded but exposes a simple\nAPI to Python (GetNextBatch(), GetStats()).\n\nWe'd like to have a fancy TUI dashboard to monitor the loading and training\nprocess.\n\n## TrainingDaemon\n\nThe central class of the training code is `TrainingDaemon`, which:\n\n* Is a JSONL server operating through stdin/stdout, implementing the protocol\n  described in [JSONL IPC Protocol](jsonl.md).\n  * Receives a command to start training with the location of the config file.\n    * Other commands are possible in the future.\n  * Sends periodic progress notifications.\n* Owns the data loader.\n* Waits for the data loader to ingest enough new chunks.\n* Starts the training loop when enough data is available.\n* Finalizes and uploads the trained network.\n\n## Configuration\n\nThe configuration is a large nested dataclass structure, which covers:\n\n* [Data loader](../src/lczero_training/config/data_loader_config.py) to be\n  passed to the C++ data loader.\n* Information for the training daemon, e.g. 
how many chunks to wait for\n  before starting the training.\n* Model definition, for the model builder.\n* Training parameters, such as batch size, number of epochs, etc.\n* Export parameters, such as the path to export the trained model to.\n\nFrom the user's perspective, the configuration is a YAML file, which is parsed\ninto the dataclass structure.\n\n## Data Loader\n\nThe internals of the data loader are described in detail in\n[Data Loader](loader.md).\n\nFrom the Python perspective, it has the following interface:\n\n* Constructor takes a `DataLoaderConfig` dataclass.\n* `GetNextBatch()` returns a tuple of buffer-protocol-compliant tensors.\n  * Later it will have a parameter that specifies whether we need a training,\n    test, or validation batch.\n* `GetStats()` returns a dataclass (exact structure TBD) with the current\n  statistics of the data loader.\n\n## TUI\n\nThe TUI is a separate frontend app, implemented as `TrainingTuiApp`, which runs\n`TrainingDaemon` as a subprocess (and communicates through JSONL via\nstdin/stdout).\n\n* The TUI is a `Textual`-based app.\n* Located in `src/lczero_training/tui/`.\n* UI ideas are described in [TUI](tui.md).\n* The TUI will have a log pane which shows the stderr of `TrainingDaemon`.\n"
  },
  {
    "path": "docs/checkpoint_migration.md",
    "content": "# Checkpoint Migration\n\nWhen part of the model or training setup changes, JAX training state checkpoints\nmay become incompatible with the new setup.\n\nFor this, we provide a utility to help migrate checkpoints to the new setup.\n\nThe underlying implementation lives in\n`src/lczero_training/training/migrate_checkpoint.py` and is exposed via the CLI\ncommand `migrate-checkpoint` registered in `pyproject.toml`.\n\n## Command line arguments\n\n* `--config`: Path to the RootConfig textproto config. `model`/`training`\n  sections of this config will be used to initialize the new training state.\n  `checkpoint` is the location of the old checkpoint to migrate.\n* `--new_checkpoint`: Path to save the new checkpoint to. If not set, the tool\n  only checks whether the migration rules fully cover the differences between\n  the old and new training states.\n* `--overwrite`: If set, allows overwriting existing checkpoint. Also in this\n  case, if `--new_checkpoint` is not set, old checkpoint is used as the new\n  checkpoint.\n* `--rules_file`: Path to a CheckpointMigrationConfig textproto file containing\n  the migration rules. See below for the format of this file. If not set, no\n  migration rules will be applied (used for debugging, or to check that the\n  old and new states are identical).\n* `--serialized-model`: If set, use serialized state for a model. Checkpoint\n  already loads serialized as we do not provide schema. This is needed to avoid\n  `GetAttrKey`s.\n* `--checkpoint_step`: If set, use this step when loading from old checkpoint\n  instead of the latest.\n* `--new_checkpoint_step`: If set, use this step when saving the new checkpoint\n  instead of copying the old step.\n\n## Creation of the state to compare\n\nThe checkpoint is loaded into a raw pytree, i.e. 
a template is not passed to it\n(an old checkpoint with a new model would otherwise fail to load).\n\nRoughly, like this:\n\n```python\nimport orbax.checkpoint as ocp\n\nmanager = ocp.CheckpointManager(\n    filepath,\n    options=ocp.CheckpointManagerOptions(create=False),\n)\nstate = manager.restore(step=None)\n```\n\nThe model state is created using `TrainingState.new_from_config`\n(`src/lczero_training/training/state.py`). If `--serialized-model` is passed,\nthe model state is serialized using `flax.serialization.to_state_dict`.\n\n## Migration rules format\n\nThe migration rules are specified in a `CheckpointMigrationConfig` textproto\nfile. `CheckpointMigrationConfig` is defined in\n`proto/checkpoint_migration_config.proto`.\n\nIt contains a list of `CheckpointMigrationRule` messages.\n\nEvery rule has a `from_path` and a `to_path` field (both optional, but at least\none must be set). These are string fields containing a JSON list of strings and\nintegers, which is mapped to a pytree `KeyPath`:\n\n* Integers are used as `SequenceKey` (list/tuple indices).\n* Strings are used as `DictKey` (dict keys).\n* If any other type is met in the path, an error is raised.\n* If other key types are met in a state's `KeyPath`, an error is also raised.\n\n* By default, all keys that are present both in the old and new state are\n  preserved (copied from old to new).\n* If both `from_path` and `to_path` are set, they must be different. The old\n  values at `from_path` are copied to `to_path` in the new state.\n* If only `to_path` is set, it means that the new state at `to_path` is\n  taken from the newly initialized state.\n* If only `from_path` is set, it means that the old state at `from_path` is not\n  present in the new state, and is ignored. 
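\n\n  For illustration, the JSON-path-to-`KeyPath` mapping described above could be\n  sketched in Python like this (`parse_path` is a hypothetical helper, and the\n  key types here are stand-ins for the real `jax.tree_util` ones):\n\n  ```python\n  import json\n  from collections import namedtuple\n\n  # Stand-ins for jax.tree_util.SequenceKey / DictKey (illustration only).\n  SequenceKey = namedtuple(\"SequenceKey\", [\"idx\"])\n  DictKey = namedtuple(\"DictKey\", [\"key\"])\n\n\n  def parse_path(path_json):\n      \"\"\"Maps a rule's JSON path string to a pytree KeyPath prefix.\"\"\"\n      keys = []\n      for part in json.loads(path_json):\n          # bool is a subclass of int in Python, so reject it explicitly.\n          if isinstance(part, bool) or not isinstance(part, (int, str)):\n              raise ValueError(f\"Unsupported path element: {part!r}\")\n          keys.append(SequenceKey(part) if isinstance(part, int)\n                      else DictKey(part))\n      return tuple(keys)\n\n\n  # A rule with only from_path set: old values under this prefix are ignored.\n  ignored_prefix = parse_path('[\"opt_state\", 0, \"mu\"]')\n  ```\n\n  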
Without this rule, the migration\n  would fail if the old state has keys that are not present in the new state.\n* If we want to keep an initialized (\"new\") value at a path that is also\n  present in the old state, we use two rules: one with only `to_path` to\n  initialize it, and one with only `from_path` to ignore the old value at\n  that path.\n\nNote that `from_path` and `to_path` are prefixes of the actual paths, so the\nactual subtrees are copied/ignored.\n\n## Working of the migration\n\n1. Both the new and old states are created.\n2. Both of them are flattened using `jax.tree_util.tree_flatten_with_path` into\n   a list of `(KeyPath, value)` pairs and a treedef.\n3. The lists of `(KeyPath, value)` pairs are converted to dicts mapping\n   `KeyPath` to `value`.\n4. We build a set of `source_paths` (keys of the old state).\n5. We build a set of `dest_paths` (keys of the new state).\n6. Rules are applied in arbitrary order as follows:\n   * If both `from_path` and `to_path` are set, the values prefixed by\n     `from_path` are copied to the corresponding `to_path` in the new state. Of\n     course, `to_path` in the new state must exist (i.e. we never create new\n     keys). The copied source paths are removed from `source_paths`.\n     The corresponding destination paths are removed from `dest_paths`.\n   * If only `to_path` is set, we delete all paths prefixed by `to_path` from\n     `dest_paths`. This means that we keep the initialized value at this path.\n   * If only `from_path` is set, we delete all paths prefixed by `from_path`\n     from `source_paths`. This means that we ignore the old value at this path.\n7. After all rules are applied, we check that `source_paths` and `dest_paths`\n   are equal. If not, the migration is incomplete, and we raise an error.\n8. After this, we copy all remaining `source_paths` (i.e. those that were not\n   mentioned in any rule) to the new state. 
This means that by default, all\n   keys that are present both in the old and new state are preserved (copied\n   from old to new).\n9. Finally, we unflatten the new state dict back to a pytree using the new\n   treedef.\n\nIf `--new_checkpoint` is set, we save the new state to the specified path.\nOtherwise, we just print that the migration is possible with the given rules.\n\nInstead of stopping at the first error, the tool should print ALL errors it\nfinds. This is helpful for fixing the rules.\n\n## Implementation notes\n\n* There is a `justfile`, e.g. `just build-proto` to build the protos.\n* We use `uv`. Example usage:\n\n```bash\nuv run migrate-checkpoint --config=~/tmp/lc0/config/overfit.textproto\n```\n"
  },
  {
    "path": "docs/example.textproto",
    "content": "# Example configuration file for lczero-training\n# This file demonstrates all available configuration options with their default\n# values and explanations of what each setting controls.\n\nname: \"little-teapot\"\ndata_loader {\n  stage {\n    name: \"file_path_provider\"\n    file_path_provider {\n      # Directory with training data files.\n      directory: \"/home/crem/tmp/2025-07/lczero-training/data2\"\n      output {\n        queue_capacity: 16 # Internal file queue size\n      }\n    }\n  }\n  stage {\n    name: \"chunk_source_loader\"\n    input: \"file_path_provider\"\n    chunk_source_loader {\n      threads: 1  # Threads for loading chunks\n      frame_format: V6TrainingData  # Training data format (V6TrainingData or V7TrainingData)\n      output {\n        queue_capacity: 16  # Output queue for chunk sources\n      }\n    }\n  }\n  stage {\n    name: \"shuffling_chunk_pool\"\n    input: \"chunk_source_loader\"\n    shuffling_chunk_pool {\n      chunk_pool_size: 50000        # Shuffle buffer size (chunks in memory)\n      source_ingestion_threads: 1   # Threads for ingesting new sources\n      chunk_loading_threads: 4      # Threads for loading chunk data\n      output {\n        queue_capacity: 16  # Output queue for shuffled chunks\n      }\n    }\n  }\n  stage {\n    name: \"chunk_rescorer\"\n    input: \"shuffling_chunk_pool\"\n    chunk_rescorer {\n      threads: 1  # Threads for chunk rescoring\n      output {\n        queue_capacity: 16  # Output queue for rescored chunks\n      }\n      syzygy_paths: \"/path/to/tb\"  # Tablebase search paths (comma-separated)\n      dist_temp: 1.0        # Policy temperature applied during rescoring\n      dist_offset: 0.0      # Policy offset applied during rescoring\n      dtz_boost: 0.0        # DTZ boost for endgame policy tuning\n      new_input_format: -1  # Keep original input format (-1 disables change)\n      deblunder_threshold: 0.10  # Threshold for policy deblundering adjustments\n   
   deblunder_width: 0.06      # Width controlling smoothing around threshold\n    }\n  }\n  stage {\n    name: \"chunk_unpacker\"\n    input: \"chunk_rescorer\"\n    chunk_unpacker {\n      threads: 1  # Threads for unpacking chunks\n      # Probability of sampling each position within a chunk.\n      position_sampling_rate: 0.03\n      output {\n        queue_capacity: 16  # Output queue for unpacked frames\n      }\n    }\n  }\n  stage {\n    name: \"shuffling_frame_sampler\"\n    input: \"chunk_unpacker\"\n    shuffling_frame_sampler {\n      threads: 1                          # Threads for frame sampling\n      reservoir_size_per_thread: 1000000  # Sampling reservoir per thread\n      output {\n        queue_capacity: 16  # Output queue for sampled frames\n      }\n    }\n  }\n  stage {\n    name: \"tensor_generator\"\n    input: \"shuffling_frame_sampler\"\n    tensor_generator {\n      threads: 1       # Threads for tensor generation\n      batch_size: 128  # Batch size for tensors\n      output {\n        queue_capacity: 8  # Output queue for batched tensors\n      }\n    }\n  }\n  output: \"tensor_generator\"\n}\nmodel {\n  defaults {\n    compute_dtype: F32\n    activation: ACTIVATION_MISH\n    ffn_activation: ACTIVATION_MISH\n  }\n  embedding { dense_size: 512 embedding_size: 1024 dff: 1536 }\n  encoder {\n    num_blocks: 15\n    dff: 1536\n    d_model: 1024\n    heads: 32\n    smolgen {\n      hidden_channels: 32\n      hidden_size: 256\n      gen_size: 256\n      activation: ACTIVATION_SWISH\n    }\n  }\n  policy_head { name: \"vanilla\" embedding_size: 1024 d_model: 1024 }\n  value_head { name: \"winner\" num_channels: 128 }\n  movesleft_head { name: \"main\" num_channels: 32 }\n}\ntraining {\n  schedule {\n    steps_per_network: 250\n    chunks_per_network: 50000\n  }\n  lr_schedule {\n    starting_step: 0\n    duration_steps: 0  # 0 means indefinite duration\n    lr: 0.001\n  }\n  # Example multi-phase schedule with warmup:\n  # lr_schedule {\n  #   
starting_step: 0\n  #   duration_steps: [1500, 500, 0]  # Warmup, transition, then indefinite\n  #   lr: [0.0, 0.0005, 0.0005]       # Start at 0, ramp to 0.0005, hold\n  #   transition: [LINEAR, CONSTANT]  # Linear warmup, then constant\n  #   # Missing transitions default to CONSTANT. Last duration 0 = indefinite.\n  # }\n  checkpoint {\n    path: \"/home/crem/tmp/2025-09/lc0_training/checkpoint\"\n    max_to_keep: 5\n  }\n  optimizer {\n    nadamw {\n      beta_1: 0.9  beta_2: 0.98  epsilon: 1e-7  weight_decay: 0.0001\n      # Rule order matters: first match wins.\n      decay_selector {\n        rule { match: \"**/bias\" include: false }\n        rule { match: \"**/ln*/**\" include: false }\n        rule { match: \"**/embedding/embedding/**\" include: false }\n        rule { match: \"**/policy_heads/**\" include: true }\n        rule { match: \"**/value_heads/**\" include: true }\n        rule { match: \"**/movesleft_heads/**\" include: true }\n        otherwise_include: false\n      }\n    }\n    # Alternative optimizers:\n    # nadam { beta_1: 0.9 beta_2: 0.999 epsilon: 1e-8 }\n    # sgd { momentum: 0.9 nesterov: true }\n    # freeze_selector freezes matching weights (they receive no gradient\n    # updates).\n    # freeze_selector {\n    #   rule { match: \"**/embedding/**\" include: true }\n    #   otherwise_include: false\n    # }\n  }\n  max_grad_norm: 10.0  # Global gradient-norm clip; omit or set to 0 to disable.\n  losses {\n    policy {\n      head_name: \"vanilla\"\n      metric_name: \"main_ce\"\n      weight: 1.0\n      illegal_moves: MASK\n      type: CROSS_ENTROPY\n    }\n    policy {\n      head_name: \"vanilla\"\n      metric_name: \"main_kl\"\n      weight: 1.0\n      illegal_moves: MASK\n      type: KL\n      temperature: 1.0  # Softmax temperature applied before KL evaluation\n    }\n    value { head_name: \"winner\" weight: 1.0 }\n    movesleft { head_name: \"main\" weight: 1.0 }\n  }\n}\nmetrics {\n  tensorboard_path: 
\"/tmp/tensorboard/myrun\"\n}\nexport {\n  destination_filename: \"/home/crem/tmp/2025-08/lc0_training/exported_models/lc0-{datetime}-{step:08d}.pb.gz\"\n  upload_training_run: 3\n}\n"
  },
  {
    "path": "docs/heads.md",
    "content": "# Neural Network Heads Documentation\n\nThis document describes the various policy and value heads used in the network, their training targets, and any specific scaling or transformations applied during training.\n\n## Policy Heads\n\n### 1. Vanilla Policy (`vanilla`)\n*   **Description**: The standard policy head predicting the best move.\n*   **Training Target**: The `probabilities` vector from the training data (MCTS visit counts).\n*   **Scaling/Transformation**: No specific scaling is applied to the target. The loss function compares the network output (logits) directly against the target probability distribution.\n*   **Loss Function**: Cross-entropy loss (Kullback-Leibler divergence).\n\n### 2. Optimistic Short-Term Policy (`optimistic_st`)\n*   **Description**: A policy head trained to be \"optimistic\" about the outcome, focusing more on moves that lead to better short-term evaluations. It uses a weighted loss function where positions that the network underestimates (target > prediction) are weighted more heavily.\n*   **Training Target**: The same `probabilities` vector as the `vanilla` head.\n*   **Scaling/Transformation**:\n    *   **Mechanism**: The \"optimism\" is applied as a **sample weight** in the loss function.\n    *   **Computation Guide**:\n        1.  **Inputs**:\n            *   `v_st_target`: Short-term value target (scalar, from `st` head target).\n            *   `v_st_pred`: Short-term value prediction (scalar, from `st` head output).\n            *   `v_st_err_pred`: Predicted squared error of the short-term value (scalar, from `st_err` head output).\n        2.  **Standard Deviation Estimation**:\n            *   `sigma = sqrt(v_st_err_pred)`\n        3.  **Z-Score Calculation**:\n            *   `z = (v_st_target - v_st_pred) / (sigma + 1e-5)`\n            *   This measures how many standard deviations the target is away from the prediction. 
A positive `z` means the position is better than predicted (underestimated).\n        4.  **Weight Calculation**:\n            *   `strength = 2.0` (default configuration)\n            *   `weight = sigmoid((z - strength) * 3)`\n    *   **Interpretation**:\n        *   If `z` is large (positive), meaning the target is much higher than predicted (highly underestimated), the weight approaches 1.\n        *   If `z` is small or negative (overestimated or accurately predicted), the weight approaches 0.\n        *   The `strength` parameter shifts the sigmoid, controlling the threshold of \"optimism\" required to trigger training.\n*   **Loss Function**: Weighted Cross-entropy loss. `loss = weight * CrossEntropy(target, output)`.\n\n### 3. Soft Policy (`soft`)\n*   **Description**: A policy head trained on a \"softened\" version of the MCTS probabilities, encouraging exploration or capturing more of the distribution's shape.\n*   **Training Target**: The `probabilities` vector from the training data.\n*   **Scaling/Transformation**:\n    *   **Mechanism**: **Temperature scaling** is applied to the target probabilities **inside the loss function**.\n    *   **Calculation**: `target = target^(1/temperature)`.\n    *   **Location**: This transformation happens in the loss calculation logic (specifically in `correct_policy` helper in `tfprocess.py`), **not** at the tensor generation stage. The tensor generator provides the raw `probabilities`.\n*   **Loss Function**: Cross-entropy loss against the temperature-scaled target.\n\n### 4. Opponent Policy (`opponent`)\n*   **Description**: Predicts the move the opponent actually played in the game.\n*   **Training Target**: `opp_played_idx` (Integer index of the move played by the opponent).\n*   **Scaling/Transformation**: The integer index is converted to a one-hot vector inside the loss function.\n*   **Loss Function**: Cross-entropy loss.\n\n## Value Heads\n\n### 1. 
Winner (`winner`)\n*   **Description**: Predicts the final game outcome (Win/Draw/Loss).\n*   **Training Target**: A 3-element probability vector derived from `result_q` and `result_d` in the training data.\n    *   `Win = (1 + result_q - result_d) / 2`\n    *   `Loss = (1 - result_q - result_d) / 2`\n    *   `Draw = result_d`\n*   **Scaling/Transformation**: None.\n*   **Loss Function**: Cross-entropy or MSE depending on configuration.\n\n### 2. Q-Value (`q`)\n*   **Description**: Predicts the expected value of the position based on the MCTS search (best Q).\n*   **Training Target**: A 3-element probability vector derived from `best_q` and `best_d` in the training data.\n    *   `Win = (1 + best_q - best_d) / 2`\n    *   `Loss = (1 - best_q - best_d) / 2`\n    *   `Draw = best_d`\n*   **Scaling/Transformation**: None.\n*   **Loss Function**: MSE or Cross-entropy.\n\n### 3. Q-Value Error (`q_err`)\n*   **Description**: Predicts the squared error of the `q` head prediction compared to the target.\n*   **Training Target**: `(q_target - q_pred)^2`.\n*   **Scaling/Transformation**: None.\n*   **Loss Function**: MSE.\n\n### 4. Short-Term Value (`st`)\n*   **Description**: Predicts the short-term evaluation of the position (e.g., from a shallow search or static eval).\n*   **Training Target**: A 3-element probability vector derived from `q_st` and `d_st` in the training data.\n    *   `Win = (1 + q_st - d_st) / 2`\n    *   `Loss = (1 - q_st - d_st) / 2`\n    *   `Draw = d_st`\n*   **Scaling/Transformation**: None.\n*   **Loss Function**: MSE or Cross-entropy.\n\n### 5. Short-Term Value Error (`st_err`)\n*   **Description**: Predicts the squared error of the `st` head prediction compared to the target. Used for calculating uncertainty/variance for the `optimistic_st` head.\n*   **Training Target**: `(st_target - st_pred)^2`.\n*   **Scaling/Transformation**: None.\n*   **Loss Function**: MSE.\n\n## Auxiliary Heads\n\n### 1. 
Moves Left (`moves_left`)\n*   **Description**: Predicts the number of moves remaining in the game.\n*   **Training Target**: `plies_left` (Float).\n*   **Scaling/Transformation**:\n    *   **Mechanism**: The target and output are scaled down by a factor (e.g., 20.0) **inside the loss function** to bring the loss magnitude into a similar range as other losses.\n*   **Loss Function**: Huber loss.\n"
  },
  {
    "path": "docs/index.md",
    "content": "# Index\n\n* [Overview, glossary and file formats](overview.md) — A an overview of the\n  project, including definitions of some terms and file formats used.\n* [Data Loader](loader.md) — A module for loading, preprocessing, shuffling, and\n  feeding training data.\n* [Training Tuple Format](training_tuple.md) — A description of the training\n  tuple format used in the project. This is an interface between data loader and\n  model training.\n* [Command-Line Tools](cli.md) — How to run the tools via `uv run`.\n"
  },
  {
    "path": "docs/loader.md",
    "content": "# Data Loader\n\nThe Data Loader is a C++ module (exposed to Python via pybind11) that handles\nloading, preprocessing, shuffling, and feeding training data for the Leela Chess\nZero training process.\n\n## Python Integration\n\nThe loader has been exposed to Python through pybind11, allowing direct use of\nthe C++ `DataLoader` from Python code. Key aspects:\n\n* **Configuration**: Generated protobufs (for example `DataLoaderConfig`) are\n  passed directly to the binding, or via the convenience wrapper\n  `lczero_training.dataloader.make_dataloader`.\n* **Control Plane**: Use `DataLoader.send_control_message()` with\n  `proto.stage_control_pb2.StageControlRequest` to fan out commands such as\n  chunk-pool anchor updates.\n* **Memory Management**: Uses `unique_ptr::release()` with `py::return_value_policy::take_ownership` for efficient tensor ownership transfer\n* **Output Format**: Returns tuple of numpy arrays compatible with JAX through the buffer protocol\n* **Usage**: `from lczero_training.dataloader import make_dataloader`\n\n## High-Level Overview\n\nThe Data Loader consists of the following stages connected through a\n[Queue](../csrc/utils/queue.h):\n\n* [FilePathProvider](../csrc/loader/chunk_feed/file_path_provider.h) — Training\n  data discovery worker (watches a directory and provides feed of filenames)\n* [ChunkSourceLoader](../csrc/loader/chunk_feed/chunk_source_loader.h) — Reads\n  chunks from files, providing a stream of chunks.\n* [ShufflingChunkPool](../csrc/loader/chunk_feed/shuffling_chunk_pool.h) — Keeps\n  a set of chunks, managing the last `num_chunks` available and removing old\n  ones, and outputting them in shuffled order.\n* (skip for now) [ChunkValidator](../csrc/loader/chunk_feed/chunk_validator.h) —\n  Filters the chunk stream, filtering out invalid chunks.\n* [ChunkRescorer](../csrc/loader/stages/chunk_rescorer.h) — Rescores chunks\n  using Syzygy tablebases and configurable policy adjustments.\n* 
[ChunkUnpacker](../csrc/loader/chunk_feed/chunk_unpacker.h) — Unpacks\n  chunks into frames, which are then processed by the next stages.\n* [ShufflingFrameSampler](../csrc/loader/shuffling_frame_sampler.h) — Takes a\n  stream of frames and provides shuffled batches of frames for training, using\n  reservoir sampling.\n* [TensorGenerator](../csrc/loader/tensor_generator.h) — Takes frames and\n  provides tensor buffers for the training process.\n\n## Metrics\n\nAll stages expose the following metrics in\n[DataLoaderMetricsProto](../proto/training_config.proto):\n\n* load — measures how much time the threads spend working vs. idle.\n* queue — monitors queue statistics.\n\nThere are the following exceptions:\n\n* ShufflingChunkPool\n  * Has two thread pools (indexing and chunk loading), so needs two `load`\n    metrics.\n  * Needs a metric (statistics metric) for the current number of chunk sources.\n  * Needs a metric (simple value) for the current number of chunks in the pool.\n  * Needs a metric (simple value) for the pool capacity.\n\n* ChunkUnpacker\n  * Needs to track the number of bad chunks (statistics metric).\n\n* ShufflingFrameSampler\n  * Needs the capacity of the reservoir (simple value).\n  * Needs the current size of the reservoir (simple value).\n\n### Adding a new metric\n\nTo add a new metric, you need to:\n\n* Add a field in the relevant proto in\n  [training_config.proto](../proto/training_config.proto).\n* Add a field in the relevant stage to collect the metric (if needed).\n* Implement the FlushMetrics() function in your stage. 
It should reset internal\n  state metrics to zero, and return what it accumulated (or the latest value).\n* In [DataLoader::MetricsThread](../csrc/loader/data_loader.cc), call\n  FlushMetrics() of the relevant stage and assign the result to the proper\n  proto field.\n* In the proper\n  [Queue or Stage widget](../src/lczero_training/tui/stage_widgets.py), create\n  the proper UI elements and update the update_metrics() function to update\n  them.\n\n* For load metrics, you have to create a LoadMetricPauser per thread; see,\n  e.g., [ShufflingChunkPool](../csrc/loader/chunk_feed/shuffling_chunk_pool.cc).\n* For queue metrics, just collect the queue statistics in FlushMetrics().\n* For all metrics, both \"all time\" and \"during last second\" values are\n  provided. The latter is useful to determine a rate (items per second).\n* In general, use StatisticsProtoInt64 (or StatisticsProtoDouble) for\n  distribution metrics, and a simple number field for additive metrics like\n  counts.\n\n## TensorGenerator\n\nThe batch size is configurable in the stage options.\n\nThe `TensorGenerator` stage takes frames from the input queue and produces a\ntuple of tensors, returned as a [TensorTuple](../csrc/utils/tensor.h).\nThe first dimension of every tensor in the tuple is the batch size; the\nrest are described in the [Training Tuple Format](training_tuple.md).\n\n## Stage interface\n\nAll stages implement a similar API and structure, although they do not share a\ncommon base class.\n\nAll stages/workers (except pure producers) wait for the input queue to Close(),\nthen Close() the output Queue.\n\n```cpp\nclass Stage {\n public:\n  using InputType = ...;   // Type of input data for this stage.\n  using OutputType = ...;  // Type of output data from this stage.\n  // input_queue is omitted in producer stages like FilePathProvider.\n  Stage(Queue<InputType>* input_queue, /* other params */);\n  Queue<OutputType>* output();\n\n private:\n  ThreadPool thread_pool_;\n  Queue<InputType>* input_queue_;\n  Queue<OutputType> 
output_queue_;\n};\n```\n\n## Chunk Set\n\nThe Chunk Set (implemented as the ShufflingChunkPool) takes the feed of chunk\nsources, indexes them, and assigns the chunk range [base, base + num_chunks) to\neach chunk source. It aims to keep the newest chunk sources that cover the last\n`chunk_window_` chunks, and removes old chunk sources when new ones are added.\n\nOn the output side, it returns a stream of chunks within the\n(`last - chunk_window_`, `last`) range, without repetitions. To do that, it\nutilizes a [StreamShuffler](../csrc/loader/stream_shuffler.h), which provides\na shuffled stream of numbers within the (dynamic) range.\n\nThe Chunk Set gets the stream of chunk sources (initial chunks are read in the\nconstructor), and then starts:\n\n* The input indexing worker pool, which calls `Index()` on each chunk source,\n  and then appends the chunk source to the `chunk_sources_` deque (under a\n  mutex).\n\n* The chunk output worker pool, which fetches the next number from\n  `stream_shuffler_` under a mutex, then reads the chunk from the chunk source\n  using a per-source mutex.\n\nIf `stream_shuffler_` runs out of numbers, it is reset to the range\n(`last - chunk_window_`, `last`) (and a warning message is logged).\n\n## ShufflingFrameSampler\n\nThe sampler uses reservoir sampling:\n\n* It has a reservoir of predefined size (1000000 is quite typical).\n* Initially it just fills the reservoir with frames from the input queue until\n  it's full.\n* After that, it picks random frames from the reservoir and outputs them,\n  refilling the used spot from the input queue.\n* It closes the output queue when either an explicit Close() is called or the\n  input queue is closed.\n* `using FrameType = V6TrainingData;`, use `absl::FixedArray<FrameType>` for\n  the reservoir.\n\n## Anchors in ShufflingChunkPool\n\nThe training pipeline aims to start training a new epoch when a certain number\nof new chunks is available.\nTo do that, we reset the running counter of chunks in the ShufflingChunkPool\nwhen we start waiting for new 
chunks, and then wait\nuntil the counter reaches the desired number.\nHowever, when the script restarts, we also want to know approximately how many\nnew chunks to wait for before starting the training.\n\nTo do that, we use the concept of \"anchors\". An anchor is just the\nGetChunkSortKey() of a given chunk that we remember.\n\nMore specifically, we add the following functions to ShufflingChunkPool:\n\n* `std::string ResetAnchor()` — resets the anchor to the latest chunk seen so\n  far, and returns its sort key. Also resets the internal counter of chunks seen\n  since the anchor.\n* `int ChunksSinceAnchor()` — returns the number of chunks seen since the\n  anchor.\n* `std::string CurrentAnchor()` — returns the current anchor sort key.\n* `void SetAnchor(std::string_view)` — is usually called BEFORE starting to\n  process chunks. Does not reset the counter, but sets the anchor to the given\n  value. When a read chunk has the same key as the anchor, the counter is reset\n  to zero.\n\nPython clients access this functionality by issuing\n`StageControlRequest` messages through\n`DataLoader.send_control_message()`. The daemon pipeline demonstrates how the\nfirst chunk-pool response is used to update anchor state.\n\nThe anchor functionality works differently during initial load vs. ongoing\nprocessing:\n\n**Initial Load (backward processing):**\n\n* Chunks are processed in newest-first order during the initial scan.\n* If no anchor is set: count all chunks.\n* If an anchor is set: only count chunks newer than the anchor; reset the\n  counter to 0 when the anchor is encountered.\n\n**Ongoing Processing (after initial load):**\n\n* New chunks are processed as they arrive.\n* The counter increments by GetChunkCount() for each new chunk source.\n* The counter resets to 0 when a chunk source matching the anchor is\n  encountered.\n\nMetrics:\nIn [ShufflingChunkPoolMetricsProto](../proto/training_metrics.proto) we add:\n\n* `int32 chunks_since_anchor` — number of chunks seen since the anchor. 
Simple\n  numerical field.\n* `string anchor` — current anchor sort key.\n\nIn the [UI](../src/lczero_training/tui/stage_widgets.py) we:\n\n* Update the last_chunk_key display to have a label \"Last:\".\n* Add a new row \"⚓:\" for the anchor key.\n* Add a new row \"Since ⚓:\" for chunks_since_anchor.\n"
  },
  {
    "path": "docs/new_stage.md",
    "content": "# Writing a New Data Loader Stage\n\nThis guide walks through the lifecycle of adding another stage to the dynamic\ndata loader pipeline. It assumes you are working in the C++ orchestrator (under\n`csrc/loader/stages/`) and that the Python bindings already consume staged\nconfigurations.\n\n## 1. Design the Stage Surface\n\n- **Purpose and data flow**: Decide whether the stage produces new data (no\n  upstream input) or transforms items from an existing queue.\n- **Configuration shape**: Determine which knobs are required during\n  construction (thread counts, capacities, etc.). These become fields on the\n  stage-specific protobuf message.\n- **Outputs and control hooks**: Clarify what queue type the stage emits and\n  whether it needs control-plane messages.\n\n## 2. Extend the Protobufs\n\n- Add a new `message <YourStage>Config` to `proto/data_loader_config.proto`.\n- For single-output stages, add `optional QueueConfig output = N` to configure\n  the output queue. `QueueConfig` provides `queue_capacity` (default 4),\n  `overflow_behavior` (BLOCK, DROP_NEW, KEEP_NEWEST), and optional `name`.\n- For multi-output stages, use `repeated QueueConfig output` with parallel\n  configuration arrays (see `ChunkSourceSplitterConfig` for reference).\n- Update `StageConfig` with an `optional <YourStage>Config` entry so the stage\n  can be referenced from the `repeated stage` list.\n- If the stage emits custom metrics, extend `StageMetricProto` in\n  `proto/training_metrics.proto`. Prefer the existing `load_metrics`,\n  `queue_metrics`, and `count_metrics` collections when possible.\n- When the stage needs control requests or responses, extend\n  `proto/stage_control.proto` so they can be carried through\n  `StageControlRequest`/`StageControlResponse`.\n- Regenerate protobufs (`meson compile -C builddir` or `just build-proto`).\n\n## 3. Choose a Base Class\n\n- **Use `SingleInputStage<ConfigT, InputT>`** when the stage consumes exactly\n  one upstream queue. 
The helper provides `input_queue()` to access the typed\n  `Queue<InputT>*` and implements `SetInputs()` to wire the input during\n  initialization.\n- **Use `SingleOutputStage<OutputT>`** when the stage produces exactly one\n  output queue. The helper manages the output queue, implements `GetOutput()`\n  with name validation, and surfaces the typed `Queue<OutputT>*` via\n  `output_queue()`.\n- **Most stages inherit from both** `SingleInputStage` and `SingleOutputStage`\n  using virtual inheritance (both base classes virtually inherit from `Stage`\n  to avoid the diamond problem). Example:\n  ```cpp\n  class MyStage : public SingleInputStage<MyStageConfig, InputType>,\n                  public SingleOutputStage<OutputType> {\n   public:\n    explicit MyStage(const MyStageConfig& config)\n        : SingleInputStage<MyStageConfig, InputType>(config),\n          SingleOutputStage<OutputType>(config.output()) {}\n  };\n  ```\n- **Inherit `Stage` directly** when the stage has multiple inputs, multiple\n  outputs, or manages more complex wiring. In that case you must implement\n  `SetInputs()`, input/output discovery, and `GetOutput()` yourself.\n- Place declarations in `csrc/loader/stages/<stage_name>.h` and definitions in\n  the matching `.cc` file.\n\n## 4. Implement the Stage API\n\n- **Constructor**: Initialize base classes with config and `config.output()`.\n  Store additional config fields and initialize worker pools.\n  Avoid starting threads here.\n- **`SetInputs(absl::Span<QueueBase* const> inputs)`**: Only implement if you\n  inherit from `Stage` directly. `SingleInputStage` provides this automatically\n  and validates that exactly one input is provided. For stages with no inputs,\n  validate that the span is empty.\n- **`Start()`**: Launch background work. Acquire `Queue::Producer` instances\n  from `output_queue()->CreateProducer()` for emitting data and honour\n  `stop_requested_` flags so shutdown is cooperative. 
The input queue is\n  available via `input_queue()` at this point.\n- **`Stop()`**: Close queues via `output_queue()->Close()`, signal workers to\n  exit, and join threads. Remember that downstream stages expect\n  `Queue::Close()` to signal completion.\n- **`GetOutput(std::string_view name)`**: Only implement if the stage has\n  multiple outputs. `SingleOutputStage` provides this automatically for\n  single-output stages, including name validation.\n- **`Control()`**: Handle relevant `StageControlRequest` sub-messages and return\n  a populated `StageControlResponse` wrapped in `std::optional`. Return\n  `std::nullopt` for requests the stage does not recognise.\n\n## 5. Report Metrics\n\n- **Accumulate state** while workers run (e.g., load metrics, counters,\n  queue statistics).\n- **`FlushMetrics()`** should snapshot the current values, reset internal\n  counters as needed, and populate `StageMetricProto`. Use helpers like\n  `MetricsFromQueue(\"output\", *output_queue())` to expose queue utilisation\n  under `queue_metrics`, and append load information via `load_metrics`.\n- For multiple queues or distinct metric groups, add additional entries with\n  meaningful names (`\"output\"`, `\"prefetch\"`, etc.) so downstream tooling can\n  pick the right series.\n- If you rename or split a metric, document the change and update dashboards.\n  For example, `ShufflingChunkPool` now emits `chunks_current` (window size)\n  and `chunks_total` (total indexed chunks) instead of a single `chunks`\n  series; the Grafana panels consuming the old series were repointed to\n  `chunks_current` so the graphs remain accurate.\n\n## 6. Register the Stage\n\n- Update `CreateStage` in `csrc/loader/stages/stage_factory.cc` to construct the\n  new class when its config is present. Enforce the \"exactly one sub-config\"\n  rule by keeping the existing `CountStageConfigs()` logic in sync.\n- Ensure `meson.build` lists the new source files so the static library rebuilds.\n\n## 7. 
Wire Up Tests\n\n- Add focused unit tests under `csrc/loader/stages/` validating constructor\n  errors, thread lifecycle, metric flushing, and (if applicable) control-plane\n  behaviour.\n- Provide integration coverage where the stage participates in a small pipeline\n  built from serialized `DataLoaderConfig` messages.\n- If Python bindings surface stage-specific behaviour, extend the relevant\n  `pytest` suites too.\n\n## 8. Update Documentation and Examples\n\n- Document new config fields in `docs/` (for example, augment `docs/loader.md`\n  or create stage-specific notes).\n- Add sample snippets or textproto fragments showing how to reference the\n  stage in a pipeline.\n- Mention any new control commands so the daemon/TUI maintainers know how to\n  surface them.\n\nFollowing these steps keeps the stage ecosystem consistent: configurations are\nvalidated at construction time, queues remain type-safe, metrics feed the UI,\nand Python clients continue to operate through the generic factory and control\nplane.\n"
  },
  {
    "path": "docs/overview.md",
    "content": "# Overview, Glossary, and File Formats\n\nThis document serves as a glossary of terms used in the project and describes\nthe file formats utilized.\n\n* Training data **frame** is a data structure that holds information needed for\n  the NN training about a single chess position. Currently it's a fixed sized\n  struct, e.g. [V6TrainingData](../libs/lc0/src/trainingdata/trainingdata.h).\n* **Chunk** is a sequence of frames from a single game. Currently, they are\n  stored in a gzipped file where frames are concatenated together.\n* **Chunk Source** is a file that contains one or more chunks. In older versions\n  of the code, it was a single gzipped file. In the new version, it also may be\n  an uncompressed .tar file.\n"
  },
  {
    "path": "docs/shuffling_pool_hanse_sampling.md",
    "content": "# Implement single position sampling in Shuffling Pool\n\nThis document defines a new way of sampling in\n[Shuffling Pool](../csrc/loader/stages/shuffling_chunk_pool.h) (and .cc).\n\n## reshuffle_count -> use_count\n\nIn `TrainingChunk` in\n[training_chunk.h](../csrc/loader/stages/training_chunk.h) the `reshuffle_count`\nshould be renamed to `use_count`. As with the new sampling method, usage is not\nnecessarily tied to reshuffling. Also the code that\n[uses it](../csrs/loader/stages/chunk_unpacker.cc) will need to be updated.\n\nIn `ChunkSourceItem`, instead of one `reshuffle_count` per entire chunk source,\nwe'll have a `std::vector<uint16_t> use_counts`, initially filled with zeros.\nInstead of updating it on reshuffling, we will update it for individual chunks\nwhen they are returned. In `GetNextChunkData()` we return old value (i.e. 0 for\nthe first time).\n\nLocal struct `ShufflingChunkPool::ChunkData` in\n../csrc/loader/stages/shuffling_chunk_pool.cc should also have `use_count`\ninstead of `reshuffle_count`.\n\n## Configuration changes\n\nIn the [config](../proto/data_loader_config.proto), in `ShufflingChunkPoolConfig`,\nwe add the following fields:\n\n```proto\nmessage ShufflingChunkPoolConfig {\n  // existing fields...\n  optional uint64 hanse_sampling_threshold = 6;  // by default, do not use new sampling.\n  optional double hanse_sampling_gamma = 7 [default = 1.0];\n}\n```\n\n## Algorithm changes\n\nIn addition to `use_counts`, `ChunkSourceItem` will have a new field:\n`std::vector<uint16_t> num_records;`. It will contain the number of records in\neach chunk, and will act as a cache. Initially, it's filled with zeros.\n\nWhen `record_bound` is not set, the sampling method is the same as before.\n\nWhen `record_bound` is set, we will use the new sampling method. It goes like\nthis:\n\n`GetNextChunkData()` doesn't call `LoadChunkData()` right away.\nInstead, it first checks if `num_records[chunk_index]` is zero. 
If so, it\ncalls `LoadChunkData()` to load the chunk, and counts the number of records in\nit (by dividing its size to `sizeof<FrameType>`).\n\nThen, we decide whether to return this chunk or to sample again. We do this by\ndrawing a random number `u` uniformly from `[0, 1)`, and comparing it to\n`p = min(1.0, num_records/hthreshold) ^ gamma`. If `u < p`, we return this chunk\n(we need to call `LoadChunkData()` if we didn't already) and increment\nuse_count. Otherwise, we sample again (i.e. pick a new `chunk_index` and repeat\nthe process) — without incrementing `use_count`."
  },
  {
    "path": "docs/training_tuple.md",
    "content": "# Training Tuple Format\n\nThe `convert_v6_to_tuple` function in `tf/chunkparser.py` processes training\ndata and produces a 5-element tuple:\n`(planes, probs, winner, best_q, plies_left)`.\n\nWhen these raw byte strings are interpreted as NumPy arrays, they have the\nfollowing shapes for each training example:\n\n1.  **`planes`**: `(112, 64)` as a `float32` array.\n    * This represents the board state as 112 feature planes, each of size 8x8\n        (64). The original 104 planes from the input are augmented with 8\n        additional planes for information like castling rights, side to move,\n        the rule 50 count, and board edge detection.\n\n2.  **`probs`**: `(1858,)` as a `float32` array.\n    * This corresponds to the `float probabilities[1858]` member in the\n        `V6TrainingData` C++ struct, representing the policy probabilities for\n        all possible moves.\n\n3.  **`winner`**: `(3,)` as a `float32` array.\n    * This holds the game's outcome from the current player's perspective,\n        representing the probabilities for a win, draw, and loss, respectively.\n\n4.  **`best_q`**: `(3,)` as a `float32` array.\n    * Similar to `winner`, this stores the value of the position after search\n        (the Q-value), also represented as win, draw, and loss probabilities.\n\n5.  **`plies_left`**: A scalar `float32`.\n    * This value represents the estimated number of plies remaining until the\n        end of the game.\n"
  },
  {
    "path": "docs/tui.md",
    "content": "# UI Design\n\nThe application will present a single-screen dashboard with a classic blue background, organized into several key, always-visible panes.\n\n## 1. Overall Layout\n\nThe screen is divided into four main sections:\n\n1. **Header Bar (Top):** A slim, single-line bar at the very top.\n2. **Data Pipeline Pane (Top-Left/Main):** The largest and most detailed pane.\n3. **JAX Training Status Pane (Right):** A pane dedicated to live training metrics.\n4. **Log Pane (Bottom):** A pane for displaying raw log output.\n\n## 2. Header Bar\n\nThis bar provides high-level, global status at a glance.\n\n- **Uptime:** Total wall-clock time since the script was launched.\n- **Overall Stage:** The current high-level state of the application. Will display one of: `WAITING FOR DATA`, `TRAINING`, `EXPORTING`, or `ERROR`.\n\n### 3. Data Pipeline Pane\n\nThis pane visualizes the flow of data through the C++ pipeline. The new\ndesign is a vertically scrollable list where every stage and queue consumes\none row (two only when extra detail is needed). Rows are rendered in the same\norder as traffic flows through the loader, preserving the mental model of the\npipeline without relying on borders or grid positioning.\n\n- **Pipeline Stages (Rows):** Each stage of the C++ pipeline is rendered as a\n  single line.\n  - **Stage Heading:** The label uses the canonical stage name reported in the\n    metrics, so newly added stages automatically appear without additional UI\n    wiring.\n  - **Load Metrics:** Stages with thread pools render load in `load\n    active/total` format.\n  - **Specific Stats:** Each metric is rendered as an individual \"chip\" widget,\n    e.g. 
skipped file counters or pool sizes, so additional detail can be added\n    without breaking alignment.\n\n- **Queues (Rows):** Each stage row is followed by a queue row that surfaces the\n  metrics of the outgoing queue.\n  - **Queue Heading:** The row label is the canonical stage name reported by the\n    daemon. The first chip shows the queue name when it differs from the\n    default.\n  - **Throughput:** A chip displays the 1-second `items/s` rate and turns red\n    when the rate drops to zero.\n  - **Totals:** A chip shows the lifetime count of elements that passed through\n    the queue, formatted with apostrophes.\n  - **Fill State:** A horizontal progress bar visualizes the average queue fill\n    against capacity, followed by a chip with the numeric `avg/capacity`\n    display. Unknown values fall back to `--`.\n\n- **Train/Validation/Test Splitter:**\n  - The pipeline view will show a \"Stream Splitter\" stage.\n  - After this stage, the UI will display the pipeline stats for **one** stream\n    at a time (defaulting to 'Training').\n  - Hotkeys (e.g., F1, F2, F3) will allow the user to instantly switch the view\n    to show the stats for the 'Training', 'Validation', or 'Test' streams.\n\n### 4. Training schedule/pipeline pane\n\nThe area below the Data Pipeline Pane is split horizontally into two\nsections: Training Schedule (left) and JAX Training Status (right).\n\nThe Training Schedule pane shows:\n\n- **Combined uptime and stage line**: \"Uptime: 2d 14:30:45   Stage: TRAINING\"\n  - Uptime includes days when >24 hours (format: \"2d 14:30:45\")\n  - Stage shows current training state (WAITING_FOR_DATA, TRAINING, EXPORTING, ERROR)\n- **Completed epochs**: Simple counter of epochs completed since daemon start\n- **New chunks progress bar**: Shows chunks collected since training start vs. target\n  - Indeterminate state when target is unknown (0)\n- **Training time progress bar**: Current training time vs. 
previous training duration\n  - Indeterminate state when no previous duration exists\n- **Cycle time progress bar**: Current cycle time vs. previous cycle duration\n  - Indeterminate state when no previous duration exists\n\n**Implementation details:**\n\n- **Header bar**: Completely empty (no content)\n- **Data structure**: `TrainingScheduleData` dataclass with all timing fields:\n  - `current_stage: TrainingStage` (enum)\n  - `completed_epochs_since_start: int`\n  - `new_chunks_since_training_start: int`\n  - `chunks_to_wait: int`\n  - `total_uptime_seconds: float`\n  - `current_training_time_seconds: float`\n  - `previous_training_time_seconds: float`\n  - `current_cycle_time_seconds: float`\n  - `previous_cycle_time_seconds: float`\n- **Timing computation**: All timing values computed in daemon, not TUI\n- **Progress bars**: Show indeterminate state when maximum values ≤ 0\n- **Layout**: Compact single-line widgets with no extra padding/margins\n- **Files**:\n  - `training_widgets.py`: Widget implementations\n  - `pipeline.py`: Training state tracking and data collection\n  - `daemon.py`: Metrics collection and transmission\n  - `messages.py`: Protocol definitions with enum serialization support\n\n### 5. JAX Training Status Pane\n\nThis pane is dedicated to the live status of an active JAX training run. It\nremains blank or shows summary info when the system is not actively training.\n\n- **Epoch Progress:** A progress bar showing completion of the current epoch\n  (`Step 12345 / 50000`).\n- **Performance Metrics:**\n  - Steps per second: `345.6 steps/s`.\n  - Estimated Time Remaining (ETR) for the current run.\n  - Total wall time spent on the current training run.\n- **Loss Values (Numerical Only):**\n  - A prominent display of the **Total Loss**.\n  - A compact 2-column grid displaying the individual values for the **7 Head\n    Losses**.\n\n### 6. 
Log Pane\n\nA pane across the bottom of the screen.\n\n- **Content:** A direct feed of all output sent to `stderr` from any part of the\n  application (Python or C++).\n- **Functionality:** The pane will be scrollable and will hold a fixed number of\n  lines (e.g., 1000) to prevent unbounded memory usage, discarding the oldest\n  lines as new ones arrive.\n"
  },
  {
    "path": "docs/weights_tool.md",
    "content": "# lc0-weights - Weight Manipulation Tool\n\n## Overview\n\n`lc0-weights` is a command-line tool and Python library for manipulating Leela Chess Zero neural network weight files. It enables arithmetic operations on networks, component grafting, and format conversion without requiring JAX or other heavy dependencies.\n\n**Key capabilities:**\n\n* **Arithmetic operations**: Add, subtract, multiply networks (e.g., model interpolation/averaging)\n* **Grafting**: Replace specific network components (policy heads, value heads, encoder layers)\n* **Format conversion**: Convert between LINEAR16, FLOAT16, and BFLOAT16 encodings\n* **Pure Python/NumPy**: No JAX or TensorFlow dependencies required\n\n## Quick Start\n\nInterpolate two networks with equal weights:\n\n```bash\n# Using inline expression\nuv run lc0-weights \\\n  --expr \"output = weights('network_a.pb.gz') * 0.5 + weights('network_b.pb.gz') * 0.5\" \\\n  --output interpolated.pb.gz\n\n# Using a script file\necho \"output = weights('network_a.pb.gz') * 0.5 + weights('network_b.pb.gz') * 0.5\" > interpolate.py\nuv run lc0-weights interpolate.py --output interpolated.pb.gz\n\n# Using stdin\necho \"output = weights('network_a.pb.gz') * 0.5 + weights('network_b.pb.gz') * 0.5\" | \\\n  uv run lc0-weights --output interpolated.pb.gz\n```\n\n## Command-Line Interface\n\n### Arguments\n\n| Argument     | Required | Description                                                          |\n| ------------ | -------- | -------------------------------------------------------------------- |\n| `--expr`     | No       | Python expression to execute                                         |\n| `script`     | No       | Path to Python script file (positional, used if --expr not given)   |\n| `--input`    | No       | Pre-load input as `NAME=PATH` (can be used multiple times)          |\n| `--output`   | No       | Output path (if `output` variable is set in expression)             |\n| `--encoding` | No       | Output 
encoding format: LINEAR16, FLOAT16 (default), or BFLOAT16    |\n\n**Note:** You must provide either `--expr`, a script file path, or pipe input via stdin.\n\n### Available Functions in --expr\n\nWhen executing expressions with `--expr`, the following are available:\n\n* `weights(path)`: Load a weight file\n* `save(net, path, encoding='FLOAT16')`: Save a network\n* `np`: NumPy module\n* `lc0`: Protobuf module (for accessing constants)\n\n### CLI Examples\n\n#### Simple Interpolation\n\n**Using inline expression:**\n```bash\nuv run lc0-weights \\\n  --expr \"output = weights('A.pb.gz') * 0.5 + weights('B.pb.gz') * 0.5\" \\\n  --output result.pb.gz\n```\n\n**Using a script file:**\n```bash\n# Create script file\ncat > interpolate.py << 'EOF'\nA = weights('A.pb.gz')\nB = weights('B.pb.gz')\noutput = A * 0.5 + B * 0.5\nEOF\n\n# Run script\nuv run lc0-weights interpolate.py --output result.pb.gz\n```\n\n**Using stdin:**\n```bash\necho \"output = weights('A.pb.gz') * 0.5 + weights('B.pb.gz') * 0.5\" | \\\n  uv run lc0-weights --output result.pb.gz\n```\n\n#### Using Input Aliases\n\nPre-load networks to simplify expressions:\n\n```bash\nuv run lc0-weights \\\n  --input A=network_a.pb.gz \\\n  --input B=network_b.pb.gz \\\n  --expr \"output = A * 0.9 + B * 0.1\" \\\n  --output result.pb.gz\n```\n\n#### Grafting a Policy Head\n\nReplace the policy head of one network with another:\n\n**Using inline expression:**\n```bash\nuv run lc0-weights --expr \"\nbase = weights('base_network.pb.gz')\ndonor = weights('network_with_better_policy.pb.gz')\nbase.weights.policy = donor.weights.policy\nbase.save('grafted.pb.gz')\n\"\n```\n\n**Using a script file (recommended for multi-line operations):**\n```bash\n# Create script\ncat > graft_policy.py << 'EOF'\nbase = weights('base_network.pb.gz')\ndonor = weights('network_with_better_policy.pb.gz')\nbase.weights.policy = donor.weights.policy\nbase.save('grafted.pb.gz')\nEOF\n\n# Run script\nuv run lc0-weights graft_policy.py\n```\n\n#### 
Format Conversion\n\nConvert a network to BFLOAT16 encoding:\n\n```bash\nuv run lc0-weights \\\n  --input net=network.pb.gz \\\n  --expr \"output = net\" \\\n  --output converted.pb.gz \\\n  --encoding BFLOAT16\n```\n\n#### Complex Expression\n\nWeighted average with custom formula:\n\n```bash\nuv run lc0-weights \\\n  --input A=net1.pb.gz \\\n  --input B=net2.pb.gz \\\n  --input C=net3.pb.gz \\\n  --expr \"output = A * 0.5 + B * 0.3 + C * 0.2\" \\\n  --output averaged.pb.gz\n```\n\n## Python Library Usage\n\n### Importing\n\n```python\nfrom lczero_training.tools import load_weights, save_weights\n```\n\n### Loading and Saving Weights\n\n```python\n# Load a network\nnet = load_weights(\"network.pb.gz\")\n\n# Save with different encoding\nsave_weights(net, \"output.pb.gz\", encoding=\"FLOAT16\")\n```\n\n### Arithmetic Operations\n\n```python\n# Load networks\nnet_a = load_weights(\"network_a.pb.gz\")\nnet_b = load_weights(\"network_b.pb.gz\")\n\n# Interpolation (model averaging)\ninterpolated = net_a * 0.7 + net_b * 0.3\n\n# Addition\ncombined = net_a + net_b\n\n# Subtraction\ndifference = net_a - net_b\n\n# Scalar multiplication\nscaled = net_a * 0.5\n\n# Save result\nsave_weights(interpolated, \"result.pb.gz\")\n```\n\n### Accessing and Modifying Components\n\n```python\nnet = load_weights(\"network.pb.gz\")\n\n# Access nested weight arrays\nq_weights = net.weights.encoder[0].mha.q_w.value  # Returns NumPy array\nprint(q_weights.shape)\n\n# Modify weights\nnet.weights.encoder[0].mha.q_w.value = q_weights * 1.1\n\n# Save modified network\nsave_weights(net, \"modified.pb.gz\")\n```\n\n### Grafting Components\n\n```python\nbase = load_weights(\"base.pb.gz\")\ndonor = load_weights(\"donor.pb.gz\")\n\n# Replace policy head\nbase.weights.policy = donor.weights.policy\n\n# Replace value head\nbase.weights.value_heads = donor.weights.value_heads\n\n# Replace specific encoder layer\nbase.weights.encoder[0] = donor.weights.encoder[0]\n\nsave_weights(base, 
\"grafted.pb.gz\")\n```\n\n## Weight Encoding Formats\n\nThe tool supports three encoding formats for weight storage:\n\n| Format   | Description                                                                      | Precision | File Size |\n| -------- | -------------------------------------------------------------------------------- | --------- | --------- |\n| LINEAR16 | Quantized 16-bit integer with min/max range. Default for Lc0.                   | ~4 digits | Smallest  |\n| FLOAT16  | Native IEEE 754 half-precision floating point.                                   | ~3 digits | Medium    |\n| BFLOAT16 | Brain float 16 (truncated float32). Better range than FLOAT16, less precision.  | ~2 digits | Medium    |\n\n**Notes:**\n\n* LINEAR16 provides good compression with acceptable precision for neural networks\n* FLOAT16 is the default for this tool (good balance of precision and size)\n* BFLOAT16 is useful when training range matters more than mantissa precision\n* All formats are converted to float32 when loaded for arithmetic operations\n\n## Common Use Cases\n\n### Network Interpolation (Model Averaging)\n\nCombine two networks to create a smoother model or blend different training runs:\n\n```python\nfrom lczero_training.tools import load_weights, save_weights\n\nnet1 = load_weights(\"run1_final.pb.gz\")\nnet2 = load_weights(\"run2_final.pb.gz\")\n\n# Average the networks\naveraged = net1 * 0.5 + net2 * 0.5\n\nsave_weights(averaged, \"averaged_network.pb.gz\")\n```\n\n### Exponential Moving Average (EMA)\n\nUpdate a running average network with a new checkpoint:\n\n```python\nema_net = load_weights(\"ema.pb.gz\")\nnew_net = load_weights(\"latest_checkpoint.pb.gz\")\n\n# EMA with decay 0.999\nema_updated = ema_net * 0.999 + new_net * 0.001\n\nsave_weights(ema_updated, \"ema.pb.gz\")\n```\n\n### Policy Head Replacement\n\nReplace a network's policy head (useful for policy distillation):\n\n```python\nstudent = load_weights(\"student_network.pb.gz\")\nteacher = 
load_weights(\"teacher_network.pb.gz\")\n\n# Replace student's policy head with teacher's\nstudent.weights.policy_heads = teacher.weights.policy_heads\n\nsave_weights(student, \"student_with_teacher_policy.pb.gz\")\n```\n\n### Extracting Network Statistics\n\n```python\nnet = load_weights(\"network.pb.gz\")\n\n# Get statistics from first encoder layer\nlayer = net.weights.encoder[0].mha.q_w.value\nprint(f\"Shape: {layer.shape}\")\nprint(f\"Mean: {layer.mean():.6f}\")\nprint(f\"Std: {layer.std():.6f}\")\nprint(f\"Min: {layer.min():.6f}\")\nprint(f\"Max: {layer.max():.6f}\")\n```\n\n### Format Conversion for Size Optimization\n\n```python\nfrom lczero_training.tools import load_weights, save_weights\n\n# Load network (any format)\nnet = load_weights(\"large_network.pb.gz\")\n\n# Save with more aggressive compression\nsave_weights(net, \"compressed_network.pb.gz\", encoding=\"LINEAR16\")\n```\n\n## Advanced Usage\n\n### Accessing Nested Structures\n\nThe weight wrapper provides Pythonic access to the nested protobuf structure:\n\n```python\nnet = load_weights(\"network.pb.gz\")\n\n# Access input embedding weights\ninput_weights = net.weights.input.weights.value\n\n# Access specific encoder layer\nencoder_layer_5 = net.weights.encoder[5]\n\n# Access multi-head attention components\nq_weights = net.weights.encoder[0].mha.q_w.value\nk_weights = net.weights.encoder[0].mha.k_w.value\nv_weights = net.weights.encoder[0].mha.v_w.value\n\n# Access policy head\npolicy_weights = net.weights.policy_heads.vanilla.ip_pol_w.value\n```\n\n### Complex Weighted Combinations\n\n```python\nnets = [load_weights(f\"checkpoint_{i}.pb.gz\") for i in range(5)]\nweights = [0.1, 0.15, 0.2, 0.25, 0.3]  # More weight to recent checkpoints\n\nresult = sum(net * w for net, w in zip(nets, weights))\nsave_weights(result, \"weighted_ensemble.pb.gz\")\n```\n\n### Selective Component Grafting\n\n```python\nbase = load_weights(\"base.pb.gz\")\ndonor = load_weights(\"donor.pb.gz\")\n\n# Replace only the 
first 10 encoder layers\nfor i in range(10):\n    base.weights.encoder[i] = donor.weights.encoder[i]\n\nsave_weights(base, \"partial_graft.pb.gz\")\n```\n\n### Working with NumPy Arrays Directly\n\nAll weight access returns NumPy arrays, allowing arbitrary transformations:\n\n```python\nnet = load_weights(\"network.pb.gz\")\n\n# Get layer weights\nlayer_weights = net.weights.encoder[0].mha.q_w.value\n\n# Apply custom transformation\nimport numpy as np\nlayer_weights_normalized = layer_weights / np.linalg.norm(layer_weights, axis=-1, keepdims=True)\n\n# Write back\nnet.weights.encoder[0].mha.q_w.value = layer_weights_normalized\n\nsave_weights(net, \"normalized.pb.gz\")\n```\n\n## Implementation Details\n\n### Lazy Loading\n\nWeights are decoded from their compressed format only when accessed, and cached for subsequent use. This makes the tool memory-efficient when working with large networks.\n\n### File Format Support\n\nBoth `.pb` (uncompressed protobuf) and `.pb.gz` (gzip-compressed) formats are supported. The tool automatically detects the format based on the file extension.\n\n### Arithmetic Semantics\n\n* Operations are element-wise across all matching layers\n* Networks must have compatible structures (same number of encoder layers, etc.)\n* Results are computed in float32 precision regardless of input encoding\n"
  },
  {
    "path": "init.sh",
    "content": "#!/usr/bin/env bash\n\nprotoc --proto_path=libs/lc0 --python_out=tf proto/net.proto\ntouch tf/proto/__init__.py\n"
  },
  {
    "path": "justfile",
    "content": "# List available commands\ndefault:\n    @just --list\n\n# Check if all C++ files in csrc/ are formatted according to clang-format\ncheck-cpp:\n    find csrc/ -name \"*.cpp\" -o -name \"*.cc\" -o -name \"*.cxx\" -o -name \"*.h\" -o -name \"*.hpp\" | xargs clang-format --dry-run --Werror\n\n# Format all C++ files in csrc/ using clang-format\nformat-cpp:\n    find csrc/ -name \"*.cpp\" -o -name \"*.cc\" -o -name \"*.cxx\" -o -name \"*.h\" -o -name \"*.hpp\" | xargs clang-format -i\n\n# Check if all protobuf files are formatted according to clang-format\ncheck-proto:\n    find proto/ -name \"*.proto\" | xargs clang-format --dry-run --Werror\n\n# Format all protobuf files using clang-format\nformat-proto:\n    find proto/ -name \"*.proto\" | xargs clang-format -i\n\n# Build Python protobuf files\nbuild-proto:\n    mkdir -p src/proto\n    touch src/proto/__init__.py\n    uv run python -m grpc_tools.protoc \\\n        --proto_path=. \\\n        --proto_path=libs/lc0 \\\n        --python_out=src/ \\\n        --pyi_out=src/ \\\n        proto/*.proto\n    uv run python -m grpc_tools.protoc \\\n        --proto_path=. 
\\\n        --proto_path=libs/lc0 \\\n        --python_out=src/ \\\n        --pyi_out=src/ \\\n        proto/net.proto \\\n        proto/onnx.proto \\\n        proto/hlo.proto\n\n# Check if all Python files in src/ are formatted according to ruff\ncheck-python:\n    uv run ruff check src/\n    uv run ruff check --select I src/\n    uv run ruff format --check src/\n    uv run mypy -p lczero_training --disallow-untyped-defs --disallow-incomplete-defs\n\n# Format all Python files in src/ using ruff\nformat-python:\n    uv run ruff check --fix --select I src/\n    uv run ruff format src/\n    uv run ruff check --fix src/\n\nformat: format-cpp format-proto format-python\n\n# Setup meson build directory with clang\nsetup-build:\n    CXX=clang++ CC=clang uv run meson setup build/release \\\n        --buildtype=release --native-file=native.ini\n\n# Build the project\nbuild:\n    uv run meson compile -C build/release/\n\n# Create symlink for the built extension module\nmksymlink:\n    cd src/lczero_training && ln -sfT ../../build/release/_lczero_training.cpython-*-x86_64-linux-gnu.so _lczero_training.so\n\n# Delete build/release, re-setup, rebuild, and re-link\nrebuild: && setup-build build mksymlink\n    rm -rf build/release\n\n# Run tests\ntest-cpp:\n    uv run meson test -C build/release/\n\ntest-python: \n    uv run pytest    \n\ntest: test-cpp test-python    \n\ncheck: check-cpp check-proto check-python\n\n# Run all checks (formatting, build, and tests)\npre-commit: build-proto check build test\n"
  },
  {
    "path": "meson.build",
    "content": "project(\n  'lczero-training',\n  'cpp',\n  version : '0.1',\n  meson_version : '>= 1.3.0',\n  default_options : [\n    'warning_level=3',\n    'cpp_std=c++20',\n    \n  ],\n)\n\n# Allow Clang nullability extensions when using Clang compiler\ncpp_compiler = meson.get_compiler('cpp')\nif cpp_compiler.get_id() == 'clang'\n  add_project_arguments('-Wno-nullability-extension', language : 'cpp')\nendif\n\n\n\n# External dependencies\nzlib_dep = dependency('zlib')\n\n# Python and PyBind11 dependencies for Python extension\npython3 = import('python').find_installation()\npybind11_dep = dependency('pybind11')\n\n# Abseil dependencies always resolved from the wrap subproject.\nabsl_proj = subproject('abseil-cpp')\nabsl_dep_specs = [\n  ['log', 'absl_log_dep'],\n  ['log_initialize', 'absl_log_dep'],\n  ['check', 'absl_log_dep'],\n  ['hash', 'absl_hash_dep'],\n  ['raw_hash_set', 'absl_container_dep'],\n  ['synchronization', 'absl_synchronization_dep'],\n  ['random_random', 'absl_random_dep'],\n  ['flags', 'absl_flags_dep'],\n  ['flags_parse', 'absl_flags_dep'],\n  ['throw_delegate', 'absl_base_dep'],\n]\nabsl_deps = {}\nforeach spec : absl_dep_specs\n  absl_deps += {\n    spec[0].underscorify() : absl_proj.get_variable(spec[1]).as_system()\n  }\nendforeach\n\n# Test dependencies\ngtest_dep = dependency('gtest').as_system()\ngtest_main_dep = dependency('gtest_main').as_system()\n\n# Gaviota tablebase subproject\ngaviota_dep = subproject('gaviotatb').get_variable('gaviotatb_dep')\n\n# Common dependency sets\nexternal_deps = [zlib_dep]\ncore_absl_deps = [\n  absl_deps['log'],\n  absl_deps['check'],\n  absl_deps['hash'],\n  absl_deps['raw_hash_set'],\n  absl_deps['synchronization'],\n  absl_deps['random_random'],\n]\nloader_deps = external_deps + core_absl_deps + [gaviota_dep]\nmain_deps = loader_deps + [absl_deps['log_initialize']]\n# test_deps will be defined after proto_dep\ncli_deps = [\n  absl_deps['log'],\n  absl_deps['log_initialize'],\n  
absl_deps['flags'],\n  absl_deps['flags_parse'],\n]\n\n# Protobuf compilation setup\ncompile_proto = find_program('libs/lc0/scripts/compile_proto.py')\nproto_gen = generator(compile_proto, output: ['@BASENAME@.pb.h'],\n  arguments : [\n    '--proto_path=@CURRENT_SOURCE_DIR@',\n    '--cpp_out=@BUILD_DIR@',\n    '@INPUT@'])\n\nlc0_proto_gen = generator(compile_proto, output: ['@BASENAME@.pb.h'],\n  arguments : [\n    '--proto_path=' + join_paths(meson.current_source_dir(), 'libs/lc0'),\n    '--cpp_out=@BUILD_DIR@',\n    '@INPUT@'])\n\nincludes = include_directories('csrc', 'libs/lc0/src')\n\nadd_project_arguments('-DNO_PEXT', language : 'cpp')\n\nrescorer_files = [\n  'libs/lc0/src/chess/board.cc',\n  'libs/lc0/src/chess/gamestate.cc',\n  'libs/lc0/src/chess/position.cc',\n  'libs/lc0/src/neural/decoder.cc',\n  'libs/lc0/src/neural/encoder.cc',\n  'libs/lc0/src/trainingdata/reader.cc',\n  'libs/lc0/src/trainingdata/trainingdata.cc',\n  'libs/lc0/src/trainingdata/writer.cc',\n  'libs/lc0/src/utils/commandline.cc',\n  'libs/lc0/src/utils/configfile.cc',\n  'libs/lc0/src/utils/esc_codes.cc',\n  'libs/lc0/src/utils/optionsdict.cc',\n  'libs/lc0/src/utils/optionsparser.cc',\n  'libs/lc0/src/utils/random.cc',\n  'libs/lc0/src/utils/string.cc',\n]\n\nif host_machine.system() == 'windows'\n  rescorer_files += 'libs/lc0/src/utils/filesystem.win32.cc'\nelse\n  rescorer_files += 'libs/lc0/src/utils/filesystem.posix.cc'\nendif\n\nfiles = [\n  'csrc/loader/chunk_source/debug_chunk_source.cc',\n  'csrc/loader/chunk_source/rawfile_chunk_source.cc',\n  'csrc/loader/chunk_source/tar_chunk_source.cc',\n  'csrc/loader/data_loader_metrics.cc',\n  'csrc/loader/data_loader.cc',\n  'csrc/loader/stages/chunk_rescorer.cc',\n  'csrc/loader/stages/chunk_source_loader.cc',\n  'csrc/loader/stages/chunk_source_splitter.cc',\n  'csrc/loader/stages/chunk_unpacker.cc',\n  'csrc/loader/stages/file_path_provider.cc',\n  'csrc/loader/stages/join_stage.cc',\n  
'csrc/loader/stages/position_sampling.cc',\n  'csrc/loader/stages/shuffling_chunk_pool.cc',\n  'csrc/loader/stages/shuffling_frame_sampler.cc',\n  'csrc/loader/stages/simple_chunk_extractor.cc',\n  'csrc/loader/stages/stage_factory.cc',\n  'csrc/loader/stages/stage.cc',\n  'csrc/loader/stages/tensor_generator.cc',\n  'csrc/utils/gz.cc',\n  'csrc/utils/stream_shuffler.cc',\n  'csrc/utils/training_data_printer.cc',\n  'libs/lc0/src/syzygy/syzygy.cc',\n  'libs/lc0/src/trainingdata/rescorer.cc',\n  'libs/lc0/src/utils/files.cc',\n  'libs/lc0/src/utils/logging.cc',\n  'libs/lc0/src/utils/protomessage.cc',\n] + rescorer_files\n\n# Process protobuf files for C++.\nproto_files = [\n  proto_gen.process('proto/data_loader_config.proto',\n    preserve_path_from : meson.current_source_dir()),\n  proto_gen.process('proto/training_metrics.proto',\n    preserve_path_from : meson.current_source_dir()),\n  proto_gen.process('proto/stage_control.proto',\n    preserve_path_from : meson.current_source_dir()),\n  lc0_proto_gen.process('libs/lc0/proto/net.proto',\n    preserve_path_from : join_paths(meson.current_source_dir(), 'libs/lc0')),\n  lc0_proto_gen.process('libs/lc0/proto/onnx.proto',\n    preserve_path_from : join_paths(meson.current_source_dir(), 'libs/lc0')),\n  lc0_proto_gen.process('libs/lc0/proto/hlo.proto',\n    preserve_path_from : join_paths(meson.current_source_dir(), 'libs/lc0'))\n]\n\n# Create a dependency for protobuf files that tests can use.\nproto_dep = declare_dependency(sources: proto_files)\ntest_deps = [gtest_dep, gtest_main_dep, proto_dep]\n\nfiles += proto_files\n\nloader_lib = static_library(\n  'loader',\n  files,\n  include_directories : includes,\n  dependencies : loader_deps,\n)\n\nexe = executable(\n  'loader',\n  'csrc/loader/loader_main.cpp',\n  include_directories : includes,\n  dependencies : cli_deps + [proto_dep],\n  link_with : loader_lib,\n)\n\nstream_shuffler_test = executable(\n  
'stream_shuffler_test',\n  'csrc/utils/stream_shuffler_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['random_random']],\n  link_with : loader_lib,\n)\n\nqueue_test = executable(\n  'queue_test',\n  'csrc/utils/queue_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],\n)\n\nfile_path_provider_test = executable(\n  'file_path_provider_test',\n  'csrc/loader/stages/file_path_provider_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],\n  link_with : loader_lib,\n)\n\nchunk_source_loader_test = executable(\n  'chunk_source_loader_test',\n  'csrc/loader/stages/chunk_source_loader_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization']],\n  link_with : loader_lib,\n)\n\nshuffling_chunk_pool_test = executable(\n  'shuffling_chunk_pool_test',\n  'csrc/loader/stages/shuffling_chunk_pool_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],\n  link_with : loader_lib,\n)\n\n# simple_chunk_extractor_test = executable(\n#   'simple_chunk_extractor_test',\n#   'csrc/loader/stages/simple_chunk_extractor_test.cc',\n#   include_directories : includes,\n#   dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],\n#   link_with : loader_lib,\n# )\n\nchunk_rescorer_test = executable(\n  'chunk_rescorer_test',\n  'csrc/loader/stages/chunk_rescorer_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],\n  link_with : loader_lib,\n)\n\nchunk_unpacker_test = executable(\n  'chunk_unpacker_test',\n  'csrc/loader/stages/chunk_unpacker_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],\n  link_with : 
loader_lib,\n)\n\nshuffling_frame_sampler_test = executable(\n  'shuffling_frame_sampler_test',\n  'csrc/loader/stages/shuffling_frame_sampler_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization'], absl_deps['random_random']],\n  link_with : loader_lib,\n)\n\ntensor_test = executable(\n  'tensor_test',\n  'csrc/utils/tensor_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['throw_delegate']],\n)\n\ntensor_generator_test = executable(\n  'tensor_generator_test',\n  'csrc/loader/stages/tensor_generator_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization']],\n  link_with : loader_lib,\n)\n\nstats_test = executable(\n  'stats_test',\n  'csrc/utils/metrics/stats_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization']],\n  link_with : loader_lib,\n)\n\nload_metric_test = executable(\n  'load_metric_test',\n  'csrc/utils/metrics/load_metric_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization']],\n  link_with : loader_lib,\n)\n\nstage_factory_test = executable(\n  'stage_factory_test',\n  'csrc/loader/stages/stage_factory_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],\n  link_with : loader_lib,\n)\n\ndata_loader_test = executable(\n  'data_loader_test',\n  'csrc/loader/data_loader_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],\n  link_with : loader_lib,\n)\n\njoin_stage_test = executable(\n  'join_stage_test',\n  'csrc/loader/stages/join_stage_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],\n  link_with : loader_lib,\n)\ntest('stream_shuffler_test', stream_shuffler_test)\ntest('queue_test', 
queue_test)\ntest('file_path_provider_test', file_path_provider_test)\ntest('chunk_source_loader_test', chunk_source_loader_test)\nchunk_source_splitter_test = executable(\n  'chunk_source_splitter_test',\n  'csrc/loader/stages/chunk_source_splitter_test.cc',\n  include_directories : includes,\n  dependencies : test_deps + [absl_deps['synchronization'], absl_deps['log']],\n  link_with : loader_lib,\n)\ntest('chunk_source_splitter_test', chunk_source_splitter_test)\ntest('shuffling_chunk_pool_test', shuffling_chunk_pool_test)\n# test('simple_chunk_extractor_test', simple_chunk_extractor_test)\ntest('chunk_rescorer_test', chunk_rescorer_test)\ntest('chunk_unpacker_test', chunk_unpacker_test)\ntest('shuffling_frame_sampler_test', shuffling_frame_sampler_test)\ntest('tensor_test', tensor_test)\ntest('tensor_generator_test', tensor_generator_test)\ntest('stats_test', stats_test)\ntest('load_metric_test', load_metric_test)\ntest('stage_factory_test', stage_factory_test)\ntest('data_loader_test', data_loader_test)\ntest('join_stage_test', join_stage_test)\n\nfile_path_provider_main = executable(\n  'file_path_provider_main',\n  'csrc/loader/stages/file_path_provider_main.cc',\n  include_directories : includes,\n  dependencies : cli_deps + [proto_dep],\n  link_with : loader_lib,\n)\n\ndump_chunk = executable(\n  'dump_chunk',\n  'csrc/tools/dump_chunk_main.cc',\n  include_directories : includes,\n  dependencies : cli_deps + [zlib_dep],\n  link_with : loader_lib,\n)\n\nfilter_chunks = executable(\n  'filter_chunks',\n  'csrc/tools/filter_chunks_main.cc',\n  include_directories : includes,\n  dependencies : cli_deps + [proto_dep, zlib_dep],\n  link_with : loader_lib,\n)\n\nresult_distribution = executable(\n  'result_distribution',\n  'csrc/tools/result_distribution_main.cc',\n  include_directories : includes,\n  dependencies : cli_deps + [proto_dep, absl_deps['synchronization']],\n  link_with : loader_lib,\n)\n\nstartpos_policy_distribution = executable(\n  
'startpos_policy_distribution',\n  'csrc/tools/startpos_policy_distribution_main.cc',\n  include_directories : includes,\n  dependencies : cli_deps + [proto_dep],\n  link_with : loader_lib,\n)\n\nrescore_chunk = executable(\n  'rescore_chunk',\n  'csrc/tools/rescore_chunk_main.cc',\n  include_directories : includes,\n  dependencies : cli_deps + [proto_dep],\n  link_with : loader_lib,\n)\n\nposition_weight_stats = executable(\n  'position_weight_stats',\n  'csrc/tools/position_weight_stats_main.cc',\n  include_directories : includes,\n  dependencies : cli_deps + [proto_dep],\n  link_with : loader_lib,\n)\n\n# Python extension module\npython3.extension_module(\n  '_lczero_training',\n  'csrc/loader/pybind_module.cc',\n  include_directories : includes,\n  dependencies : [pybind11_dep, proto_dep, absl_deps['log_initialize']] + loader_deps,\n  link_with : loader_lib,\n  install : true,\n  subdir : 'lczero_training',\n)\n"
  },
  {
    "path": "native.ini",
    "content": "[binaries]\npython = '@GLOBAL_SOURCE_ROOT@/.venv/bin/python'"
  },
  {
    "path": "proto/checkpoint_migration_config.proto",
    "content": "syntax = \"proto3\";\n\npackage lczero_training.proto;\n\nmessage CheckpointMigrationRule {\n  // Path in the old state pytree.\n  string from_path = 1;\n\n  // Path in the new state pytree.\n  string to_path = 2;\n}\n\nmessage CheckpointMigrationConfig {\n  repeated CheckpointMigrationRule rule = 1;\n}\n"
  },
  {
    "path": "proto/data_loader_config.proto",
    "content": "syntax = \"proto2\";\n\npackage lczero.training;\n\n// Configuration for output queue used by stages.\nmessage QueueConfig {\n  // Queue overflow behavior.\n  enum OverflowBehavior {\n    BLOCK = 0;\n    DROP_NEW = 1;\n    KEEP_NEWEST = 2;\n  }\n\n  // Optional name for the output (used for multi-output stages).\n  optional string name = 1;\n  // Size of the output queue.\n  optional uint64 queue_capacity = 2 [default = 4];\n  // Overflow behavior of the output queue.\n  optional OverflowBehavior overflow_behavior = 3;\n}\n\n// Configuration for file path provider that watches directories for new\n// training files. Maps to FilePathProviderOptions in\n// csrc/loader/chunk_feed/file_path_provider.h\nmessage FilePathProviderConfig {\n  // Path to directory containing training data files.\n  optional string directory = 1;\n  // Output queue configuration.\n  optional QueueConfig output = 2;\n}\n\n// Configuration for chunk source loader that converts file paths to chunk\n// sources. Maps to ChunkSourceLoaderOptions in\n// csrc/loader/chunk_feed/chunk_source_loader.h\nmessage ChunkSourceLoaderConfig {\n  enum FrameFormat {\n    V6TrainingData = 0;\n    V7TrainingData = 1;\n  }\n  // Number of worker threads for loading.\n  optional uint64 threads = 1 [default = 1];\n  // Output queue configuration.\n  optional QueueConfig output = 2;\n  // Training data frame format.\n  optional FrameFormat frame_format = 3;\n}\n\nmessage PositionSamplingConfig {\n  optional float diff_focus_q_weight = 1 [default = 0.0];\n  optional float diff_focus_pol_scale = 2 [default = 1.0];\n  optional float diff_focus_alpha = 3 [default = 1.0];\n  optional float diff_focus_beta = 4 [default = 0.0];\n  optional float diff_focus_gamma = 5 [default = 1.0];\n  optional float diff_focus_tau = 6 [default = 1.0];\n  optional float default_weight = 7 [default = 1.0];\n}\n\n// Configuration for shuffling chunk pool that manages chunk shuffling and\n// loading. 
Maps to ShufflingChunkPoolOptions in\n// csrc/loader/chunk_feed/shuffling_chunk_pool.h\nmessage ShufflingChunkPoolConfig {\n  // Size of the chunk shuffle buffer.\n  optional uint64 chunk_pool_size = 1 [default = 100000];\n  // Threads for ingesting new chunk sources (lightweight).\n  optional uint64 source_ingestion_threads = 3 [default = 1];\n  // Threads for loading chunk data.\n  optional uint64 chunk_loading_threads = 4 [default = 4];\n  // Output queue configuration.\n  optional QueueConfig output = 5;\n  // When set to a positive value, enable Hanse single-position sampling.\n  // Probability to accept a chunk is\n  //   min(1.0, num_records / hanse_sampling_threshold) ** hanse_sampling_gamma\n  // where num_records is the number of frames in the chunk.\n  optional uint64 hanse_sampling_threshold = 6;  // by default, disabled.\n  optional double hanse_sampling_gamma = 7 [default = 1.0];\n  optional PositionSamplingConfig position_sampling = 11;\n  // Optional output queue for cache hit frames.\n  optional QueueConfig cachehit_output = 8;\n  // Size of the position cache per chunk.\n  optional uint64 position_cache_size = 9;\n  // Threads for caching positions.\n  optional uint64 caching_threads = 10 [default = 1];\n}\n\n// Configuration for chunk rescorer that adjusts chunk metadata using\n// tablebases. 
Maps to ChunkRescorerOptions in\n// csrc/loader/stages/chunk_rescorer.h\nmessage ChunkRescorerConfig {\n  // Number of worker threads for rescoring.\n  optional uint64 threads = 1 [default = 1];\n  // Output queue configuration.\n  optional QueueConfig output = 2;\n  // Comma-separated list of Syzygy tablebase directories.\n  optional string syzygy_paths = 3;\n  // Policy reshaping temperature.\n  optional float dist_temp = 4 [default = 1.0];\n  // Policy offset applied during rescoring.\n  optional float dist_offset = 5 [default = 0.0];\n  // DTZ boost factor when policy adjustments apply.\n  optional float dtz_boost = 6 [default = 0.0];\n  // Optional conversion target for input format (-1 keeps original).\n  optional int32 new_input_format = 7 [default = -1];\n  // Optional deblunder threshold for policy adjustments.\n  optional float deblunder_threshold = 8;\n  // Optional deblunder width controlling smoothing around the threshold.\n  optional float deblunder_width = 9;\n  // Comma-separated list of Gaviota tablebase directories.\n  optional string gaviota_paths = 11;\n  // Soft threshold for st_q when converting v6 to v7.\n  optional float st_q_theta = 12 [default = 0.8333333333];\n}\n\n// Configuration for chunk unpacker that extracts frames from packed chunks.\n// Maps to ChunkUnpackerOptions in csrc/loader/chunk_feed/chunk_unpacker.h\nmessage ChunkUnpackerConfig {\n  // Number of worker threads for unpacking.\n  optional uint64 threads = 1 [default = 1];\n  // Probability of sampling each position within a chunk.\n  optional float position_sampling_rate = 2 [default = 1.0];\n  // Number of positions to take from each chunk.\n  optional uint32 position_count = 3;\n  // Output queue configuration.\n  optional QueueConfig output = 4;\n  // Number of positions to prefetch for caching.\n  optional uint32 prefetch_count = 5;\n  // Optional output queue for prefetch cache requests.\n  optional QueueConfig prefetch_output = 6;\n\n  optional PositionSamplingConfig 
position_sampling = 7;\n}\n\n// Configuration for shuffling frame sampler using reservoir sampling.\n// Maps to ShufflingFrameSamplerOptions in csrc/loader/shuffling_frame_sampler.h\nmessage ShufflingFrameSamplerConfig {\n  // Number of worker threads.\n  optional uint64 threads = 1 [default = 1];\n  // Size of sampling reservoir per thread.\n  optional uint64 reservoir_size_per_thread = 2 [default = 1000000];\n  // Output queue configuration.\n  optional QueueConfig output = 3;\n}\n\n// Configuration for tensor generator that converts frames to batched tensors.\n// Maps to TensorGeneratorOptions in csrc/loader/tensor_generator.h\nmessage TensorGeneratorConfig {\n  // Number of worker threads for tensor generation.\n  optional uint64 threads = 1 [default = 1];\n  // Batch size for tensor generation.\n  optional uint64 batch_size = 2 [default = 1024];\n  // Output queue configuration.\n  optional QueueConfig output = 3;\n}\n\n// Configuration for a stage that splits incoming ChunkSources into several\n// outputs based on a deterministic hash of (chunk source name, index within\n// source). 
For each configured output, provides a name, weight that controls\n// the proportion of indices assigned to it, and queue parameters.\nmessage ChunkSourceSplitterConfig {\n  // List of output queue configurations.\n  repeated QueueConfig output = 1;\n  // Relative weights for hash-based assignment (parallel to output).\n  repeated uint64 weight = 2;\n}\n\n// Configuration for simple chunk shuffler that processes one chunk source at a\n// time and outputs all chunks in shuffled order.\nmessage SimpleChunkExtractorConfig {\n  // Output queue configuration.\n  optional QueueConfig output = 1;\n}\n\n// Configuration for join stage that merges multiple input streams.\nmessage JoinPositionsConfig {\n  // Output queue configuration.\n  optional QueueConfig output = 1;\n}\n\n// Stage-level configuration providing a name and stage-specific options.\nmessage StageConfig {\n  // Unique name used to reference the stage output.\n  optional string name = 1;\n  // Names of upstream stages providing input to this stage.\n  repeated string input = 2;\n  // Field 3 reserved for future output configuration.\n  optional FilePathProviderConfig file_path_provider = 4;\n  optional ChunkSourceLoaderConfig chunk_source_loader = 5;\n  optional ShufflingChunkPoolConfig shuffling_chunk_pool = 6;\n  optional ChunkUnpackerConfig chunk_unpacker = 7;\n  optional ShufflingFrameSamplerConfig shuffling_frame_sampler = 8;\n  optional TensorGeneratorConfig tensor_generator = 9;\n  optional ChunkRescorerConfig chunk_rescorer = 10;\n  optional ChunkSourceSplitterConfig chunk_source_splitter = 11;\n  optional SimpleChunkExtractorConfig simple_chunk_extractor = 12;\n  optional JoinPositionsConfig join_positions = 13;\n}\n\n// Main configuration class for the DataLoader containing all component\n// configurations.\nmessage DataLoaderConfig {\n  // Ordered list of stage configurations comprising the pipeline.\n  repeated StageConfig stage = 1;\n  // Exposed outputs, each with an optional alias used by 
clients to fetch\n  // data. Expected format: \"alias:stage.output\", where \"alias:\" and \".output\"\n  // are optional.\n  repeated string output = 2;\n}\n"
  },
  {
    "path": "proto/export_config.proto",
    "content": "syntax = \"proto2\";\n\npackage lczero.training;\n\n// Configuration for model export settings.\nmessage ExportConfig {\n  // Destination filenames where exported models will be saved.\n  // Supports formatting with {datetime} (YYYYMMDD-HHMMSS) and {step} (int).\n  // Example: \"/path/to/models/lc0-{datetime}-{step:08d}.pb.gz\"\n  repeated string destination_filename = 1;\n  // Training run ID for uploading to training website. Only uploads when set.\n  optional int32 upload_training_run = 2;\n  // Whether to export the SWA model instead of the main model.\n  optional bool export_swa_model = 3;\n}\n"
  },
  {
    "path": "proto/metrics_config.proto",
    "content": "syntax = \"proto2\";\n\npackage lczero.training;\n\n// Sentinel message for training batch sample type.\nmessage TrainingBatch {}\n\n// Configuration for weight-based metrics.\nmessage WeightsMetric {\n  optional bool rms = 1;\n}\n\n// Configuration for individual metric collection.\nmessage MetricConfig {\n  // Name of the metric.\n  optional string name = 1;\n  // Whether to use the SWA model for this metric.\n  optional bool use_swa_model = 2;\n  // Period for metric collection.\n  optional int32 period = 3 [default = 1];\n  // Whether to use global steps instead of epoch-relative steps.\n  optional bool use_global_steps = 4 [default = false];\n  // Whether to collect metric after epoch completion.\n  optional bool after_epoch = 5;\n\n  // Source of data for metric calculation.\n  oneof sample {\n    // Use training batch data.\n    TrainingBatch training_batch = 6;\n    // Name of dataloader output to use.\n    string dataloader_output = 7;\n    // Path to numpy tensor file (.npz).\n    string npz_filename = 8;\n    // Weight-based metrics.\n    WeightsMetric weights = 9;\n  }\n}\n\n// Configuration for metrics collection and export.\nmessage MetricsConfig {\n  // Directory path where TensorBoard event files will be written.\n  optional string tensorboard_path = 1;\n  // List of metrics to collect during training.\n  repeated MetricConfig metric = 2;\n}\n"
  },
  {
    "path": "proto/model_config.proto",
    "content": "syntax = \"proto2\";\n\nimport \"proto/net.proto\";\nimport \"proto/hlo.proto\";\n\npackage lczero.training;\n\nmessage ModelConfig {\n  optional DefaultsConfig defaults = 4;\n  optional EmbeddingConfig embedding = 5;\n  optional EncoderConfig encoder = 6;\n  optional uint32 shared_policy_embedding_size = 7;\n  repeated PolicyHeadConfig policy_head = 8;\n  repeated ValueHeadConfig value_head = 9;\n  repeated MovesLeftHeadConfig movesleft_head = 10;\n}\n\nmessage DefaultsConfig {\n  optional pblczero.XlaShapeProto.Type compute_dtype = 2;\n  optional pblczero.NetworkFormat.ActivationFunction activation = 4;\n  optional pblczero.NetworkFormat.ActivationFunction ffn_activation = 5;\n}\n\nmessage EmbeddingConfig {\n  optional uint32 dense_size = 1;\n  optional uint32 embedding_size = 2;\n  optional uint32 dff = 3;\n}\n\nmessage EncoderConfig {\n  optional uint32 num_blocks = 3;\n  optional uint32 dff = 4;\n  optional uint32 d_model = 5;\n  optional uint32 heads = 6;\n  optional SmolgenConfig smolgen = 9;\n}\n\nmessage SmolgenConfig {\n  optional uint32 hidden_channels = 1;\n  optional uint32 hidden_size = 2;\n  optional uint32 gen_size = 3;\n  optional pblczero.NetworkFormat.ActivationFunction activation = 5;\n}\n\nmessage PolicyHeadConfig {\n  optional string name = 1;\n  optional uint32 embedding_size = 2;\n  optional uint32 d_model = 3;\n}\n\nmessage ValueHeadConfig {\n  optional string name = 1;\n  optional uint32 num_channels = 2;\n  optional bool has_error_output = 3;\n  optional uint32 num_categorical_buckets = 4;\n}\n\nmessage MovesLeftHeadConfig {\n  optional string name = 1;\n  optional uint32 num_channels = 2;\n}\n"
  },
  {
    "path": "proto/root_config.proto",
    "content": "syntax = \"proto2\";\n\npackage lczero.training;\n\nimport \"proto/data_loader_config.proto\";\nimport \"proto/model_config.proto\";\nimport \"proto/training_config.proto\";\nimport \"proto/metrics_config.proto\";\nimport \"proto/export_config.proto\";\n\n// Root configuration message containing all subsystem configurations.\nmessage RootConfig {\n  // Name identifier for this configuration.\n  optional string name = 1;\n  // Optional log filename for file-based logging.\n  optional string log_filename = 2;\n  // Data loader configuration.\n  optional DataLoaderConfig data_loader = 3;\n  // Training configuration.\n  optional TrainingConfig training = 4;\n  // Model configuration.\n  optional ModelConfig model = 5;\n  // Metrics configuration.\n  optional MetricsConfig metrics = 6;\n  // Export configuration.\n  optional ExportConfig export = 7;\n}\n"
  },
  {
    "path": "proto/stage_control.proto",
    "content": "syntax = \"proto2\";\n\npackage lczero.training;\n\nmessage ShufflingChunkPoolControlRequest {\n  optional bool reset_chunk_anchor = 1;\n  optional string set_chunk_anchor = 2;\n}\n\nmessage StageControlRequest {\n  optional ShufflingChunkPoolControlRequest chunk_pool_request = 1;\n}\n\nmessage ShufflingChunkPoolControlResponse {\n  optional string chunk_anchor = 1;\n  optional int32 chunks_since_anchor = 2;\n}\n\nmessage StageControlResponse {\n  optional ShufflingChunkPoolControlResponse chunk_pool_response = 1;\n}\n"
  },
  {
    "path": "proto/training_config.proto",
    "content": "syntax = \"proto3\";\n\npackage lczero.training;\n\n// Configuration for training algorithm and parameters.\nmessage TrainingConfig {\n  ScheduleConfig schedule = 1;\n  repeated LrSchedule lr_schedule = 2;\n  CheckpointConfig checkpoint = 3;\n  OptimizerConfig optimizer = 4;\n  LossConfig losses = 5;\n  // Maximum gradient norm; set to 0 or omit to disable clipping.\n  float max_grad_norm = 6;\n  // Stochastic Weight Averaging (SWA) configuration. When absent, SWA is\n  // disabled.\n  SWAConfig swa = 7;\n}\n\nmessage ScheduleConfig {\n  int32 steps_per_network = 1;\n  int32 chunks_per_network = 2;\n}\n\nmessage OptimizerConfig {\n  oneof optimizer_type {\n    NadamwOptimizerConfig nadamw = 1;\n    NadamOptimizerConfig nadam = 2;\n    SgdOptimizerConfig sgd = 3;\n  }\n  // When set, weights selected by this selector are frozen (no gradient\n  // updates).\n  WeightsSelector freeze_selector = 4;\n}\n\nmessage LrSchedule {\n  // Optimizer step when this schedule becomes active.\n  int32 starting_step = 1;\n  // Duration of each interval while this schedule is active. Last entry may be\n  // zero to indicate an open-ended tail.\n  repeated uint32 duration_steps = 2;\n  // Learning rate at the beginning of each interval.\n  repeated float lr = 3;\n  enum Transition {\n    CONSTANT = 0;\n    LINEAR = 1;\n    COSINE = 2;\n  }\n  // Transition type to use for each interval. Missing entries default to\n  // CONSTANT.\n  repeated Transition transition = 4;\n  // When true this schedule loops after finishing the final interval.\n  bool loop = 5;\n}\n\nmessage NadamwOptimizerConfig {\n  float beta_1 = 1;\n  float beta_2 = 2;\n  float epsilon = 3;\n  float weight_decay = 4;\n  // Selector for weight decay. 
True = apply decay to this category.\n  WeightsSelector decay_selector = 5;\n}\n\nmessage NadamOptimizerConfig {\n  float beta_1 = 1;\n  float beta_2 = 2;\n  float epsilon = 3;\n}\n\nmessage SgdOptimizerConfig {\n  float momentum = 1;\n  bool nesterov = 2;\n}\n\nmessage WeightsSelectorRule {\n  // Glob pattern matched against weight path.\n  string match = 1;\n  // True = include; false = exclude.\n  bool include = 2;\n}\n\n// Selector for which weight categories to include.\n// Rules are evaluated in order; the first matching rule applies.\n// If no rule matches, otherwise_include is used.\nmessage WeightsSelector {\n  repeated WeightsSelectorRule rule = 1;\n  bool otherwise_include = 2;\n}\n\nmessage CheckpointConfig {\n  string path = 1;\n  int32 max_to_keep = 2;\n}\n\nmessage LossConfig {\n  repeated PolicyLossConfig policy = 1;\n  repeated ValueLossConfig value = 2;\n  repeated MovesLeftLossConfig movesleft = 3;\n  repeated ValueErrorLossConfig value_error = 4;\n  repeated ValueCategoricalLossConfig value_categorical = 5;\n  repeated RegularizationLossConfig regularization = 6;\n}\n\nmessage RegularizationLossConfig {\n  enum RegularizationType {\n    L2 = 0;\n  }\n  RegularizationType type = 1;\n  string metric_name = 2;\n  float weight = 3;\n  // Selector for which weights to regularize. True = include in regularization.\n  WeightsSelector selector = 4;\n}\n\nenum ValueType {\n  RESULT = 0;\n  BEST = 1;\n  PLAYED = 2;\n  ORIG = 3;\n  ROOT = 4;\n  ST = 5;\n}\n\nmessage OptimisticPolicyConfig {\n  // Name of the value head to use for optimism computation.\n  string value_head_name = 1;\n  // Which value type to use from the training sample.\n  ValueType value_type = 2;\n  // Z-score threshold for optimism (e.g. 2.0).\n  float strength = 3;\n  // Epsilon for numerical stability (e.g. 1e-5).\n  float eps = 4;\n  // Sigmoid scaling factor (e.g. 
3.0).\n  float alpha = 5;\n  // If false, prevent gradients from flowing to value and error heads.\n  bool propagate_value_gradients = 6;\n}\n\nmessage PolicyLossConfig {\n  string head_name = 1;\n  string metric_name = 2;\n  float weight = 3;\n  enum IllegalMoveHandling {\n    TRAIN_TO_ZERO = 0;\n    MASK = 1;\n  }\n  IllegalMoveHandling illegal_moves = 4;\n  enum LossType {\n    LOSS_TYPE_UNSPECIFIED = 0;\n    CROSS_ENTROPY = 1;\n    KL = 2;\n  }\n  // Selects which policy loss implementation to use. Must be specified.\n  LossType type = 5;\n  // Soft policy temperature applied before normalizing targets for KL loss.\n  // Values <= 0 disable the adjustment and keep raw targets.\n  float temperature = 6;\n  OptimisticPolicyConfig optimistic = 7;\n}\n\nmessage ValueLossConfig {\n  string head_name = 1;\n  string metric_name = 2;\n  float weight = 3;\n  ValueType value_type = 4;\n}\n\nmessage MovesLeftLossConfig {\n  string head_name = 1;\n  string metric_name = 2;\n  float weight = 3;\n  ValueType value_type = 4;\n}\n\nmessage ValueErrorLossConfig {\n  string head_name = 1;\n  string metric_name = 2;\n  float weight = 3;\n  ValueType value_type = 4;\n  bool propagate_value_gradients = 5;\n}\n\nmessage ValueCategoricalLossConfig {\n  string head_name = 1;\n  string metric_name = 2;\n  float weight = 3;\n  ValueType value_type = 4;\n}\n\n// Stochastic Weight Averaging configuration.\nmessage SWAConfig {\n  // Number of steps between SWA updates.\n  uint32 period_steps = 1;\n  // Maximum effective number of model snapshots to average.\n  uint32 num_averages = 2;\n}\n"
  },
  {
    "path": "proto/training_metrics.proto",
    "content": "syntax = \"proto2\";\n\npackage lczero;\n\n// Load metric that accumulates seconds of load time.\n// Separate proto to support LoadMetricUpdaterProto functionality.\nmessage LoadMetricProto {\n  optional string name = 1;\n  optional double load_seconds = 2 [default = 0.0];\n  optional double total_seconds = 3 [default = 0.0];\n}\n\n// Statistics metric for integer values.\nmessage StatisticsProtoInt64 {\n  optional int64 min = 1 [default = 9223372036854775807];\n  optional int64 max = 2 [default = -9223372036854775807];\n  optional int64 sum = 3 [default = 0];\n  optional int64 count = 4 [default = 0];\n  optional int64 latest = 5 [default = 0];\n}\n\n// Statistics metric for double values.\nmessage StatisticsProtoDouble {\n  optional string name = 1;\n  optional double min = 2 [default = 1.7976931348623157e+308];\n  optional double max = 3 [default = -1.7976931348623157e+308];\n  optional double sum = 4 [default = 0.0];\n  optional int64 count = 5 [default = 0];\n  optional double latest = 6 [default = 0.0];\n}\n\n// Metrics for queue performance monitoring.\nmessage QueueMetricProto {\n  optional string name = 1;\n  optional uint64 put_count = 2 [default = 0];\n  optional uint64 get_count = 3 [default = 0];\n  optional uint64 drop_count = 4 [default = 0];\n  optional StatisticsProtoInt64 queue_fullness = 5;\n  optional uint64 queue_capacity = 6 [default = 0];\n}\n\nmessage CountMetricProto {\n  optional string name = 1;\n  optional uint64 count = 2 [default = 0];\n}\n\n// Gauge metric for values that represent current state.\nmessage GaugeMetricProto {\n  optional string name = 1;\n  optional uint64 value = 2 [default = 0];\n  optional uint64 capacity = 3;\n}\n\n// Top-level metrics for the DataLoader.\nmessage StageMetricProto {\n  optional string name = 1;\n  repeated LoadMetricProto load_metrics = 3;\n  repeated QueueMetricProto queue_metrics = 4;\n  repeated CountMetricProto count_metrics = 5;\n  repeated GaugeMetricProto gauge_metrics = 11;\n 
 repeated StatisticsProtoDouble statistics_metrics = 12;\n  optional string last_chunk_key = 8;\n  optional string anchor = 9;\n}\n\nmessage DataLoaderMetricsProto {\n  repeated StageMetricProto stage_metrics = 1;\n}\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"lczero-training\"\nversion = \"0.1.0\"\ndescription = \"Training scripts and data loading for Leela Chess Zero\"\nauthors = [{ name = \"Leela Chess Zero Team\" }]\nreadme = \"README.md\"\nrequires-python = \">=3.13\"\ndependencies = [\n    \"pybind11>=2.10.0\",\n    \"numpy>=1.24.0\",\n    \"textual[dev]>=0.47.0\",\n    \"protobuf==6.33.5\",\n    \"mypy>=1.17.1\",\n    \"pytest>=8.4.1\",\n    \"anyio>=4.10.0\",\n    \"jax[cuda12]==0.9.1\",\n    \"flax>=0.11.1\",\n    \"optax>=0.2.5\",\n    \"orbax-checkpoint>=0.11.23\",\n    \"python-dotenv>=1.1.1\",\n    \"requests[socks]>=2.32.5\",\n    \"matplotlib>=3.10.6\",\n    \"jaxlib==0.9.1\",\n    \"onnxruntime>=1.23.1\",\n    \"onnxruntime-gpu>=1.23.0\",\n    \"tensorboardx>=2.6.4\",\n    \"graphviz>=0.21\",\n    \"meson>=1.10.1\",\n    \"ruff>=0.15.4\",\n]\n\n[project.scripts]\nlc0-leela2jax = \"lczero_training.commands.leela2jax:main\"\nlc0-jax2leela = \"lczero_training.commands.jax2leela:main\"\nlc0-describe = \"lczero_training.commands.describe_training:main\"\nlc0-test-dataloader = \"lczero_training.commands.test_dataloader:main\"\nlc0-dataloader-viz = \"lczero_training.commands.dataloader_viz:main\"\nlc0-overfit = \"lczero_training.commands.overfit:main\"\nlc0-init = \"lczero_training.commands.training_init:main\"\nlc0-eval = \"lczero_training.commands.training_eval:main\"\nlc0-tune-lr = \"lczero_training.commands.tune_lr:main\"\nlc0-migrate-checkpoint = \"lczero_training.commands.migrate_checkpoint:main\"\nlc0-backfill-metrics = \"lczero_training.commands.backfill_metrics:main\"\nlc0-train = \"lczero_training.commands.train:main\"\nlc0-daemon = \"lczero_training.commands.daemon:main\"\nlc0-tui = \"lczero_training.commands.tui:main\"\nlc0-weights = \"lczero_training.commands.weights_tool:main\"\n\n[project.optional-dependencies]\ndev = [\n    \"pytest>=7.0.0\",\n    \"mypy>=1.0.0\",\n    \"typing-extensions>=4.0.0\",\n    \"mypy-extensions>=0.4.0\",\n    
\"pathspec>=0.11.0\",\n]\n\n[build-system]\nrequires = [\"setuptools>=64\", \"pybind11>=2.10.0\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n\n[tool.setuptools.packages.find]\nwhere = [\"src\"]\n\n[tool.setuptools.package-data]\n\"*\" = [\"*.so\", \"*.dll\", \"*.dylib\"]\n\n[dependency-groups]\ndev = [\n    \"grpcio-tools>=1.76.0\",\n    \"mypy-protobuf>=3.6.0\",\n    \"textual-dev>=1.7.0\",\n    \"types-protobuf>=6.30.2.20250809\",\n    \"types-requests>=2.32.4.20250913\",\n]\n\n[tool.mypy]\nmypy_path = \"src\"\n# TEMPORARY: Disable misc errors due to protobuf bug in PR #23156\n# (https://github.com/protocolbuffers/protobuf/pull/23156)\n# Revert when protobuf fixes __slots__ = () in generated .pyi files\ndisable_error_code = [\"misc\"]\n\n[[tool.mypy.overrides]]\nmodule = \"orbax.*\"\nignore_missing_imports = true\n\n[[tool.mypy.overrides]]\nmodule = \"optax\"\nignore_missing_imports = true\n\n[[tool.mypy.overrides]]\nmodule = \"onnxruntime\"\nignore_missing_imports = true\n\n[[tool.mypy.overrides]]\nmodule = \"tensorboardX\"\nignore_missing_imports = true\n\n\n[tool.pytest.ini_options]\ntestpaths = [\"src\"]\npython_files = [\"test_*.py\", \"*_test.py\"]\npythonpath = [\"src\"]\naddopts = \"-v\"\n\n[tool.ruff]\nexclude = [\"*_pb2.py*\"]\nline-length = 80\n"
  },
  {
    "path": "scripts/diff.py",
    "content": "#!/usr/bin/env python\n\nimport glob\nimport os\nimport argparse\n\n\ndef get_sorted_chunk_ids(dirs):\n    ids = []\n    for d in dirs:\n        for f in glob.glob(os.path.join(d, \"training.*.gz\")):\n            ids.append(int(os.path.basename(f).split('.')[-2]))\n    ids.sort(reverse=True)\n    return ids\n\n\ndef main(argv):\n    a = get_sorted_chunk_ids([argv.input])\n    b = get_sorted_chunk_ids(argv.dirs)\n    n = min(argv.wsize, len(a))\n    diff = set(a[:n]) - set(b)\n    for i in sorted(diff):\n        print('training.{}.gz'.format(i))\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(description=\\\n            'Print diffset of input dir and output dirs.')\n    argparser.add_argument('-i', '--input', type=str, help='input directory')\n    argparser.add_argument('-w', '--wsize', type=int, help='window size')\n    argparser.add_argument('dirs', nargs='+', help='output directories')\n\n    main(argparser.parse_args())\n"
  },
  {
    "path": "scripts/fixorder.py",
    "content": "#!/usr/bin/env python\n\nimport glob\nimport os\nimport argparse\n\n\ndef get_sorted_chunk_ids(dirs):\n    ids = []\n    for d in dirs:\n        for f in glob.glob(os.path.join(d, \"training.*.gz\")):\n            ids.append(int(os.path.basename(f).split('.')[-2]))\n    ids.sort()\n    return ids\n\n\ndef main(argv):\n    a = get_sorted_chunk_ids([argv.input])\n    for i in a:\n        os.utime(os.path.join(argv.input, \"training.{}.gz\".format(i)), None)\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(description=\\\n            'Change modification time on training files to match their numeric order.')\n    argparser.add_argument('-i', '--input', type=str, help='input directory')\n\n    main(argparser.parse_args())\n"
  },
  {
    "path": "scripts/init.sh",
    "content": "#!/usr/bin/env bash\n\nset -e\n\nWINDOWSIZE=80000\nROOT=\"/work/lc0\"\n\necho \"Cleaning up data directory\"\nrm -rf $ROOT/data\nmkdir -v $ROOT/data\n\necho \"Hard link $WINDOWSIZE seed files in $ROOT/data\"\ni=1\nfor file in $(find $ROOT/seed/data-* -name '*.gz' | shuf)\ndo\n  ln $file $ROOT/data/training.$i.gz\n  if [[ $i = $WINDOWSIZE ]]\n  then\n    break\n  fi\n  let i=\"i + 1\"\ndone\n\necho \"Set $HOME/.lc0.dat to $WINDOWSIZE\"\necho $WINDOWSIZE > $HOME/.lc0.dat\n\nrm -rf $ROOT/split\nmkdir -vp $ROOT/split/{test,train}\nlet testsize=\"$WINDOWSIZE / 10\"\nlet trainsize=\"$WINDOWSIZE - $testsize\"\necho \"Create $ROOT/split/test ($testsize) and $ROOT/split/train ($trainsize)\"\nls -1 -U $ROOT/data | head -n $trainsize | xargs -I{} ln $ROOT/data/{} $ROOT/split/train/{}\nls -1 -U $ROOT/data | tail -n $testsize | xargs -I{} ln $ROOT/data/{} $ROOT/split/test/{}\n"
  },
  {
    "path": "scripts/initsplit.py",
    "content": "#!/usr/bin/env python\n\nimport glob\nimport os\nimport argparse\n\n\ndef get_sorted_chunk_ids(dirs):\n    ids = []\n    for d in dirs:\n        for f in glob.glob(os.path.join(d, \"training.*.gz\")):\n            ids.append(int(os.path.basename(f).split('.')[-2]))\n    ids.sort(reverse=True)\n    return ids\n\n\ndef main(argv):\n    a = get_sorted_chunk_ids([argv.input])\n    n = min(argv.wsize, len(a))\n    for i in sorted(a[:n]):\n        if i % 100 >= 90:\n            os.link(os.path.join(argv.input, \"training.{}.gz\".format(i)),\n                    os.path.join(argv.output, \"test/training.{}.gz\".format(i)))\n        else:\n            os.link(\n                os.path.join(argv.input, \"training.{}.gz\".format(i)),\n                os.path.join(argv.output, \"train/training.{}.gz\".format(i)))\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(description=\\\n            'Link input to test/train subdirectories of output in 10:90 ratio.')\n    argparser.add_argument('-i', '--input', type=str, help='input directory')\n    argparser.add_argument(\n        '-w',\n        '--wsize',\n        type=int,\n        help=\n        'window size - should be padded a bit to ensure both sides of split exceed fraction of target'\n    )\n    argparser.add_argument('-o', '--output', type=str, help='output directory')\n\n    main(argparser.parse_args())\n"
  },
  {
    "path": "scripts/inittrainingname.py",
    "content": "#!/usr/bin/env python\n\nimport glob\nimport os\nimport argparse\n\n\ndef get_sorted_chunk_ids(dirs):\n    ids = []\n    for d in dirs:\n        for f in glob.glob(os.path.join(d, \"game_*.gz\")):\n            if os.path.basename(f) == \"game_000000.gz\":\n                ids.append((0, f))\n            else:\n                ids.append((int(\n                    os.path.basename(f).split('.')[-2].split('_')[-1].lstrip(\n                        \"0\")), f))\n    ids.sort()\n    return ids\n\n\ndef main(argv):\n    a = get_sorted_chunk_ids([argv.input])\n    for (i, f) in a:\n        os.rename(\n            f, os.path.join(argv.input,\n                            \"training.{}.gz\".format(i + argv.base)))\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(description=\\\n            'Rename files generated by lc0 to have training names.')\n    argparser.add_argument('-i', '--input', type=str, help='input directory')\n    argparser.add_argument('-b',\n                           '--base',\n                           type=int,\n                           default=0,\n                           help='base value for names')\n\n    main(argparser.parse_args())\n"
  },
  {
    "path": "scripts/pack.py",
    "content": "#!/usr/bin/env python3\n\nimport glob\nimport os\nimport argparse\nimport gzip\nimport bz2\nimport struct\nimport numpy as np\nfrom multiprocessing import Pool\n\nRECORD_SIZE = 8276\n\n\ndef get_uncompressed_size(filename):\n    with open(filename, 'rb') as f:\n        f.seek(-4, 2)\n        return struct.unpack('I', f.read(4))[0]\n\n\ndef get_sorted_chunk_ids(dirs):\n    ids = []\n    for d in dirs:\n        for f in glob.glob(os.path.join(d, \"training.*.gz\")):\n            ids.append(int(os.path.basename(f).split('.')[-2]))\n    ids.sort()\n    return ids\n\n\ndef pack(ids):\n    plies = []\n    fout_name = os.path.join(argv.output, '{}-{}.bz2'.format(ids[0], ids[-1]))\n    with bz2.open(fout_name, 'xb') as fout:\n        for tid in ids:\n            fin_name = os.path.join(argv.input, 'training.{}.gz'.format(tid))\n            plies.append(get_uncompressed_size(fin_name) // RECORD_SIZE)\n            with gzip.open(fin_name, 'rb') as fin:\n                fout.write(fin.read())\n            if argv.remove:\n                os.remove(fin_name)\n\n        plylist = np.array(plies, dtype=np.int16)\n        size = struct.pack('I', len(plylist) * 2)\n        fout.write(plylist.tobytes())\n        fout.write(size)\n\n    print(\"Written '{}' {} records\".format(fout_name, np.sum(plies)))\n\n\ndef main():\n    if not os.path.exists(argv.output):\n        os.makedirs(argv.output)\n        print(\"Created directory '{}'\".format(argv.output))\n\n    ids = get_sorted_chunk_ids([argv.input])\n    n = len(ids) // argv.number\n    m = argv.number\n    print(\"Processing {} ids, {} - {} ({}x{})\".format(len(ids), ids[0],\n                                                      ids[-1], n, m))\n    packs = [ids[i * m:i * m + m] for i in range(n)]\n\n    # add remaining ids to last pack\n    packs[-1] += ids[n * m + m:]\n\n    with Pool() as pool:\n        pool.map(pack, packs)\n\n\nif __name__ == \"__main__\":\n    argparser = 
argparse.ArgumentParser(description=\\\n            'Repack training.*.gz files in batches of bz2 format.')\n    argparser.add_argument('-i', '--input', type=str, help='input directory')\n    argparser.add_argument('-o', '--output', type=str, help='output directory')\n    argparser.add_argument('-r',\n                           '--remove',\n                           action='store_true',\n                           help='remove input files while processing')\n    argparser.add_argument('-n',\n                           '--number',\n                           type=int,\n                           default=1000,\n                           help='number of games to repack per bz2 package')\n    argv = argparser.parse_args()\n\n    main()\n"
  },
  {
    "path": "scripts/purge.py",
    "content": "#!/usr/bin/env python\n\nimport glob\nimport os\nimport argparse\n\n\ndef get_sorted_chunk_ids(dirs):\n    ids = []\n    for d in dirs:\n        for f in glob.glob(os.path.join(d, \"training.*.gz\")):\n            ids.append(int(os.path.basename(f).split('.')[-2]))\n    ids.sort(reverse=True)\n    return ids\n\n\ndef main(argv):\n    a = get_sorted_chunk_ids([argv.input])\n    n = min(argv.wsize, len(a))\n    for i in a[n:]:\n        os.remove(os.path.join(argv.input, \"training.{}.gz\".format(i)))\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(description=\\\n            'Delete from input not in window.')\n    argparser.add_argument('-i', '--input', type=str, help='input directory')\n    argparser.add_argument('-w', '--wsize', type=int, help='window size')\n\n    main(argparser.parse_args())\n"
  },
  {
    "path": "scripts/rescore.sh",
    "content": "#!/usr/bin/env bash\n\nset -e\n\nROOT=\"/work/lc0/dev2\"\nRESCORER=\"$HOME/bin/rescorer\"\n\nfunction usage()\n{\n  echo \"Rescores stuff\"\n  echo \"\"\n  echo \"./rescore.sh\"\n  echo \"  -h --help\"\n  echo \"\"\n  echo \"Example: ./rescore.sh\"\n  echo \"\"\n}\n\nwhile [ \"$1\" != \"\" ]\ndo\n  PARAM=`echo $1 | awk -F= '{print $1}'`\n  VALUE=`echo $1 | awk -F= '{print $2}'`\n  case $PARAM in\n    -h | --help)\n      usage\n      exit\n      ;;\n    *)\n      echo \"ERROR: unknown parameter \\\"$PARAM\\\"\"\n      usage\n      exit 1\n      ;;\n  esac\n  shift\ndone\n\nrescore() {\n  unbuffer $RESCORER rescore --threads=4 --syzygy-paths=/work/lc0/syzygy/:/wdl/syzygy/wdl/:/wdl/syzygy/dtz/ --input=\"$ROOT/data-staged\" --output=\"$ROOT/data-rescored\" 2>&1 | tee \"$ROOT/rescore-logs/$(date +%Y%m%d-%H%M%S).log\"\n}\n\nwhile true\ndo\n  rescore\n  echo -n \".\"\n  sleep 10\ndone\n"
  },
  {
    "path": "scripts/shuffle.py",
    "content": "#!/usr/bin/python3\nimport gzip\nimport sys\nimport glob\nimport os\nimport random\nfrom multiprocessing import Pool\nimport tqdm\n\nmerge_files = 100\nprocesses = 8\nshuffle = True\nrecord_length = 8292\n\n\ndef split(a, n):\n    k, m = divmod(len(a), n)\n    return [a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)]\n\n\ndef positions(chunk):\n    pos = []\n    i = 0\n    while record_length * i < len(chunk):\n        pos.append(chunk[record_length * i:record_length * (i + 1)])\n        i += 1\n    return pos\n\n\ndef shuffle(files):\n    data = []\n    for filename in files:\n        with gzip.open(filename, 'rb') as f:\n            data.extend(positions(f.read()))\n    if shuffle:\n        random.shuffle(data)\n    for d in data:\n        if d[0] != 0x04:\n            print(files)\n            raise ValueError('Wrong training data format, not V4')\n    new_file = list(os.path.splitext(files[0]))\n    new_file[0] += '_shuffled'\n    new_file = ''.join(new_file)\n    new_file_temp = new_file + '.temp'\n    with gzip.open(new_file_temp, 'wb', compresslevel=9) as f:\n        for d in data:\n            f.write(d)\n    # For interrupt safety, make sure not to write partial chunks.\n    os.rename(new_file_temp, new_file)\n    for filename in files:\n        os.remove(filename)\n\n\nif __name__ == \"__main__\":\n    if len(sys.argv) != 2:\n        print('Expected one argument, got {}'.format(len(sys.argv) - 1))\n        exit(1)\n    s = sys.argv[1]\n    files = glob.glob(os.path.join(sys.argv[1], '*.gz'))\n    files = [f for f in files if '_shuffled' not in f]\n    print('Found {} files'.format(len(files)))\n    if len(files) == 0:\n        exit(1)\n    files = split(files, len(files) // merge_files)\n    pool = Pool(processes)\n\n    for _ in tqdm.tqdm(pool.imap_unordered(shuffle, files), total=len(files)):\n        pass\n"
  },
  {
    "path": "scripts/split.sh",
    "content": "#!/usr/bin/env bash\n\nRECORDSIZE=8276 # size in bytes of a record (s, pi, v)\n\nfunction usage()\n{\n  echo \"Watches a directory and copies data to train/test set\"\n  echo \"\"\n  echo \"./split.sh\"\n  echo \"  -h --help\"\n  echo \"  -i --input   The directory where chunks arrive\"\n  echo \"  -o --output  The output directory\"\n  echo \"  -n --window  window size of test + train\"\n  echo \"  -t --train   The training percentage in {1,...,100}\"\n  echo \"  -l --latest  The file to store the largest training file number.\"\n  echo \"\"\n  echo \"Example: ./split.sh -i=/tmp -o=/out -n=2000 -t=95\"\n  echo \"\"\n}\n\nwhile [ \"$1\" != \"\" ]\ndo\n  PARAM=`echo $1 | awk -F= '{print $1}'`\n  VALUE=`echo $1 | awk -F= '{print $2}'`\n  case $PARAM in\n    -h | --help)\n      usage\n      exit\n      ;;\n    -i | --input)\n      INPUTDIR=$VALUE\n      ;;\n    -o | --output)\n      TESTDIR=\"$VALUE/test\"\n      TRAINDIR=\"$VALUE/train\"\n      ;;\n    -n | --window)\n      WINSIZE=$VALUE\n      ;;\n    -t | --train)\n      TRAINPCT=$VALUE\n      ;;\n    -l | --latest)\n      LATESTFILE=$VALUE\n      ;;\n    *)\n      echo \"ERROR: unknown parameter \\\"$PARAM\\\"\"\n      usage\n      exit 1\n      ;;\n  esac\n  shift\ndone\n\nif [ -z \"$LC0LOCKFILE\" ]\nthen\n  echo \"env var LC0LOCKFILE not set\"\n  exit 1\nfi\n\n# clear test and train split dirs\nif [ ! -d \"$TESTDIR\" ] || [ ! -d \"$TRAINDIR\" ]\nthen\n  rm -rf \"$TESTDIR\" \"$TRAINDIR\"\n  mkdir -vp \"$TESTDIR\" \"$TRAINDIR\"\nfi\n\nlet n_test=\"$(ls $TESTDIR | wc -l)\"\nlet n_train=\"$(ls $TRAINDIR | wc -l)\"\nlet n=\"$n_test + $n_train\"\nlet overhead=\"$WINSIZE / 10\"\nlet max=\"$WINSIZE + $overhead + 200\"\nmax_train=$(echo \"scale=1;($TRAINPCT / 100) * $max\" | bc | cut -d'.' -f1)\nmax_test=$(echo \"scale=1;(1 - $TRAINPCT / 100) * $max\" | bc | cut -d'.' -f1)\noverhead_train=$(echo \"scale=1;($TRAINPCT / 100) * $overhead\" | bc | cut -d'.' 
-f1)\noverhead_test=$(echo \"scale=1;(1 - $TRAINPCT / 100) * $overhead\" | bc | cut -d'.' -f1)\n\nlatest=0\nif [ -f $LATESTFILE ]\nthen\n  latest=$(cat $LATESTFILE)\nfi\n\necho \"\"\necho \"start splitter, found $n games, $n_test test, $n_train train\"\necho \"  max chunks: $max\"\necho \"  max test:   $max_test, trim_by: $overhead_test\"\necho \"  max train:  $max_train, trim_by: $overhead_train\"\necho \"\"\n\n\nprocess() {\n  local dir=$1\n  local file=$2\n\n  if [[ $file = training.*.gz ]]\n  then\n    # compute basic file integrity check\n    size=$(zcat $dir/$file | wc -c)\n    let rem=\"size % $RECORDSIZE\"\n    \n    if [[ $size -eq 0 ]] || [[ $rem -ne 0 ]]\n    then\n      echo -n \"X\"\n      return\n    fi\n\n    # new file, put hard link in correct directory\n    let \"n++\"\n\n    id=$(echo $file | cut -d'.' -f 2)\n    let hash_index=\"$id % 100 + 1\"\n\n    if [ $id -gt $latest ]\n    then\n      latest=$id\n      if [ -f $LATESTFILE ]\n      then\n        echo $latest > $LATESTFILE\n      fi\n    fi\n\n    if [ $hash_index -gt $TRAINPCT ]\n    then\n      let \"n_test++\"\n      target=$TESTDIR/$file\n      echo -n \"T\"\n    else\n      let \"n_train++\"\n      target=$TRAINDIR/$file\n      echo -n \"*\"\n    fi\n\n    ln $dir/$file $target\n\n    # exceeding max buffer size for either, lock and remove overhead as appropriate\n    if [ $n_test -gt $max_test ] || [ $n_train -gt $max_train ]\n    then\n      (\n      flock -e 200\n      if [ $n_test -gt $max_test ]\n      then\n        ls -rt $TESTDIR | head -n $overhead_test | xargs -I{} rm -f $TESTDIR/{}\n        echo -n \"-\"\n      fi\n      if [ $n_train -gt $max_train ]\n      then\n        ls -rt $TRAINDIR | head -n $overhead_train | xargs -I{} rm -f $TRAINDIR/{}\n        echo -n \"_\"\n      fi\n      ) 200>$LC0LOCKFILE\n      if [ $n_test -gt $max_test ]\n      then\n        let \"n -= $overhead_test\"\n        let \"n_test -= $overhead_test\"\n        echo -n \"-\"\n      fi\n      if [ 
$n_train -gt $max_train ]\n      then\n        let \"n -= $overhead_train\"\n        let \"n_train -= $overhead_train\"\n        echo -n \"_\"\n      fi\n\n    fi\n  fi\n}\n\n\necho \"processing '$INPUTDIR'\"\nfor file in $(./diff.py -i $INPUTDIR -w $WINSIZE $TRAINDIR $TESTDIR)\ndo\n  process $INPUTDIR $file\ndone\n\necho -e \"\\nmonitoring '$INPUTDIR'\"\ninotifywait -q -m -e moved_to -e close_write $INPUTDIR | mbuffer -m 10M |\n  while read dir event file\n  do\n    if [ -f \"$TESTDIR/$file\" ] || [ -f \"$TRAINDIR/$file\" ]\n    then\n      continue\n    fi\n\n    process $INPUTDIR $file\n  done\n"
  },
  {
    "path": "scripts/stage.sh",
    "content": "#!/usr/bin/env bash\n\nset -e\n\nfunction usage()\n{\n  echo \"Moves arriving data to a directory so rescorer can assume all files are complete\"\n  echo \"\"\n  echo \"./stage.sh\"\n  echo \"  -h --help\"\n  echo \"  -i --input   The monitoring directory\"\n  echo \"  -o --output   The directory where output should go\"\n  echo \"\"\n  echo \"Example: ./stage.sh -i data -o data-staged\"\n  echo \"\"\n}\n\nwhile [ \"$1\" != \"\" ]\ndo\n  PARAM=`echo $1 | awk -F= '{print $1}'`\n  VALUE=`echo $1 | awk -F= '{print $2}'`\n  case $PARAM in\n    -h | --help)\n      usage\n      exit\n      ;;\n    -i | --input)\n      INPUTDIR=$VALUE\n      ;;\n    -o | --output)\n      OUTPUTDIR=$VALUE\n      ;;\n    *)\n      echo \"ERROR: unknown parameter \\\"$PARAM\\\"\"\n      usage\n      exit 1\n      ;;\n  esac\n  shift\ndone\n\n\necho \"start data monitor for $INPUTDIR\"\n\ninotifywait -m -e moved_to -e close_write $INPUTDIR | mbuffer -m 10M |\n  while read dir events file\n  do\n    if [[ $file = *.gz ]]\n    then\n      echo -n \".\"\n      mv \"$INPUTDIR/$file\" \"$OUTPUTDIR/\"\n    #else\n      #echo \"ignoring ${file} ($events)\"\n    fi\n  done\n"
  },
  {
    "path": "scripts/unpack.py",
    "content": "#!/usr/bin/env python3\n\nimport os\nimport argparse\nimport gzip\nimport bz2\nimport numpy as np\nimport pickle\nimport struct\nfrom pack import RECORD_SIZE\n\n\ndef unpack(filepath):\n    front, back = os.path.basename(filepath).split('-')\n    back = back.split('.')[0]\n    first = int(front)\n    last = int(back)\n    num_chunks = last - first + 1\n\n    buf = bz2.BZ2File(filepath, 'rb').read()\n    size = struct.unpack('I', buf[-4:])[0]\n    plylist = np.frombuffer(buf[-4 - size:-4], dtype=np.int16)\n    data = np.frombuffer(buf[:-4 - size],\n                         dtype=np.int8).reshape(-1, RECORD_SIZE)\n    assert (num_chunks == len(plylist))\n\n    begin = 0\n    for i, plies in enumerate(plylist):\n        end = begin + plies\n        filename = os.path.join(argv.output,\n                                \"training.{}.gz\".format(i + first))\n\n        with gzip.open(filename, 'wb') as f:\n            for row in data[begin:end]:\n                f.write(row)\n\n        begin = end\n\n    print(\"Written {} chunks\".format(num_chunks))\n\n\ndef main():\n    if not os.path.exists(argv.output):\n        os.makedirs(argv.output)\n        print(\"Created directory '{}'\".format(argv.output))\n\n    unpack(argv.input)\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(description=\\\n            'Unpack *-*.bz2 file into gz chunks.')\n    argparser.add_argument('-i', '--input', type=str, help='input file')\n    argparser.add_argument('-o', '--output', type=str, help='output directory')\n    argv = argparser.parse_args()\n\n    main()\n"
  },
  {
    "path": "scripts/upload.sh",
    "content": "#!/usr/bin/env bash\n\nset -e\n\nfunction usage()\n{\n  echo \"Uploads a network with NxM prefix, where N=filters and M=blocks\"\n  echo \"\"\n  echo \"./upload.sh\"\n  echo \"  -h --help\"\n  echo \"  -u --upload   The upload url\"\n  echo \"  -d --netdir   The directory where new networks arrive\"\n  echo \"  -f --filters  Number of filters\"\n  echo \"  -b --blocks   Number of blocks\"\n  echo \"\"\n  echo \"Example: ./upload.sh -d=/tmp -u=http://upload.me -f=64 -b=6\"\n  echo \"\"\n}\n\nwhile [ \"$1\" != \"\" ]\ndo\n  PARAM=`echo $1 | awk -F= '{print $1}'`\n  VALUE=`echo $1 | awk -F= '{print $2}'`\n  case $PARAM in\n    -h | --help)\n      usage\n      exit\n      ;;\n    -u | --upload)\n      UPLOADURL=$VALUE\n      ;;\n    -d | --netdir)\n      NETDIR=$VALUE\n      ;;\n    -f | --filters)\n      FILTERS=$VALUE\n      ;;\n    -b | --blocks)\n      BLOCKS=$VALUE\n      ;;\n    *)\n      echo \"ERROR: unknown parameter \\\"$PARAM\\\"\"\n      usage\n      exit 1\n      ;;\n  esac\n  shift\ndone\n\nnetarch=\"${FILTERS}x${BLOCKS}\"\n\necho \"start upload monitor for $netarch*.gz\"\n\ninotifywait -m -e moved_to -e close_write $NETDIR |\n  while read dir events file\n  do\n    if [[ $file = ${netarch}*.gz ]]\n    then\n      echo \"uploading ${file} ($events)\"\n      curl -s -F \"file=@${dir}/${file}\" -F \"training_id=1\" -F \"layers=${BLOCKS}\" -F \"filters=${FILTERS}\" $UPLOADURL &\n    else\n      echo \"ignoring ${file} ($events)\"\n    fi\n  done\n"
  },
  {
    "path": "src/lczero_training/__init__.py",
    "content": "\"\"\"Leela Chess Zero training package.\"\"\"\n"
  },
  {
    "path": "src/lczero_training/_lczero_training.pyi",
    "content": "# ABOUTME: Type stubs for C++ DataLoader PyBind11 bindings.\n# ABOUTME: Provides type annotations for _lczero_training compiled module.\n\nfrom typing import List, Optional, Tuple\n\nimport numpy as np\n\nfrom proto.data_loader_config_pb2 import DataLoaderConfig\nfrom proto.stage_control_pb2 import StageControlRequest, StageControlResponse\n\nclass TensorBase:\n    def shape(self) -> List[int]: ...\n    def strides(self) -> List[int]: ...\n    def element_size(self) -> int: ...\n    def py_format(self) -> str: ...\n\nclass DataLoader:\n    def __init__(self, config: DataLoaderConfig | bytes) -> None: ...\n    def add_stages(self, config: DataLoaderConfig | bytes) -> None: ...\n    def send_control_message(\n        self, request: StageControlRequest | bytes\n    ) -> List[Tuple[str, StageControlResponse]]: ...\n    def start(self) -> None: ...\n    def get_next(self, alias: str = \"\") -> Tuple[np.ndarray, ...]: ...\n    def maybe_get_next(\n        self, alias: str = \"\"\n    ) -> Optional[Tuple[np.ndarray, ...]]: ...\n    def stop(self) -> None: ...\n    def get_bucket_metrics(\n        self, time_period: int, include_pending: bool\n    ) -> Tuple[bytes, float]: ...\n    def get_aggregate_ending_now(\n        self, duration_seconds: float, include_pending: bool\n    ) -> Tuple[bytes, float]: ...\n"
  },
  {
    "path": "src/lczero_training/commands/__init__.py",
    "content": "\"\"\"Command entrypoint scaffolding and shared CLI helpers.\n\nThis package will host thin wrappers for individual tools (convert,\ntraining, daemon, tui) as they are extracted from nested module\n__main__ dispatchers in subsequent phases.\n\nPhase 1 provides common helpers for consistent logging and argument\nhandling across commands without changing behaviour.\n\"\"\"\n\nfrom .common import (\n    add_logging_arguments,\n    configure_root_logging,\n    parse_log_level,\n)\n\n__all__ = [\n    \"configure_root_logging\",\n    \"add_logging_arguments\",\n    \"parse_log_level\",\n]\n"
  },
  {
    "path": "src/lczero_training/commands/backfill_metrics.py",
    "content": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\nfrom lczero_training.training.backfill_metrics import backfill_metrics\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=\"Backfill metrics for existing checkpoints.\"\n    )\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the RootConfig textproto config.\",\n    )\n    parser.add_argument(\n        \"--metrics\",\n        nargs=\"+\",\n        required=True,\n        help=\"Names of metrics to backfill (must be NPZ metrics).\",\n    )\n    parser.add_argument(\n        \"--min-step\",\n        type=int,\n        help=\"Minimum checkpoint step (inclusive) to process.\",\n    )\n    parser.add_argument(\n        \"--max-step\",\n        type=int,\n        help=\"Maximum checkpoint step (inclusive) to process.\",\n    )\n    parser.add_argument(\n        \"--migration-config\",\n        help=(\n            \"Path to a CheckpointMigrationConfig textproto file. \"\n            \"If provided, checkpoints will be migrated before evaluation.\"\n        ),\n    )\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    # Lazy import to avoid heavy deps unless executing the command.\n\n    backfill_metrics(\n        config_path=args.config,\n        metric_names=args.metrics,\n        min_step=args.min_step,\n        max_step=args.max_step,\n        migration_config_path=args.migration_config,\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/common.py",
    "content": "import argparse\nimport logging\nimport os\nimport sys\n\n_DEFAULT_FORMAT = (\n    \"%(levelname).1s%(asctime)s.%(msecs)03d %(name)s \"\n    \"%(filename)s:%(lineno)d] %(message)s\"\n)\n_DEFAULT_DATEFMT = \"%m%d %H:%M:%S\"\n\n\ndef configure_root_logging(level: int | str = logging.INFO) -> None:\n    \"\"\"Configure root logging with a consistent, terse format.\n\n    - Matches existing project format used in module __main__ files.\n    - Respects explicit level passed as int or name (e.g. \"INFO\").\n    - Uses stderr by default.\n    - Forces reconfiguration to avoid duplicate handlers during nested runs.\n    \"\"\"\n\n    resolved_level = parse_log_level(level)\n    logging.basicConfig(\n        level=resolved_level,\n        format=_DEFAULT_FORMAT,\n        datefmt=_DEFAULT_DATEFMT,\n        stream=sys.stderr,\n        force=True,\n    )\n\n\ndef parse_log_level(level: int | str) -> int:\n    \"\"\"Parse log level from int or string, with sane defaults.\n\n    Accepts numeric levels or case-insensitive names like \"DEBUG\".\n    Falls back to INFO on invalid input.\n    \"\"\"\n    if isinstance(level, int):\n        return level\n    if isinstance(level, str):\n        name = level.strip().upper()\n        return getattr(logging, name, logging.INFO)\n    return logging.INFO\n\n\ndef add_logging_arguments(parser: argparse.ArgumentParser) -> None:\n    \"\"\"Add common logging CLI arguments to a parser.\n\n    Does not enable the flags by default; commands may opt-in and then call\n    configure_root_logging(parse_log_level(args.log_level)) if present.\n    \"\"\"\n    parser.add_argument(\n        \"--log-level\",\n        default=os.environ.get(\"LCZERO_LOG_LEVEL\", \"INFO\"),\n        help=\"Logging level (DEBUG, INFO, WARNING, ERROR).\",\n    )\n"
  },
  {
    "path": "src/lczero_training/commands/daemon.py",
    "content": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\nfrom lczero_training.daemon.daemon import TrainingDaemon\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(description=\"Run the training daemon.\")\n    parser.add_argument(\"--memory-profile-dir\", default=None)\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    daemon = TrainingDaemon(memory_profile_dir=args.memory_profile_dir)\n    daemon.run()\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/dataloader_viz.py",
    "content": "import argparse\nimport sys\n\nfrom google.protobuf import text_format\nfrom graphviz import Digraph  # type: ignore\n\nfrom lczero_training.commands import configure_root_logging\nfrom proto.root_config_pb2 import RootConfig\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=\"Visualize the data loader pipeline as a graph.\"\n    )\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the training config file.\",\n    )\n    parser.add_argument(\n        \"--output\",\n        type=str,\n        required=True,\n        help=\"Output path for the visualization (.svg or .png).\",\n    )\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging()\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    config = RootConfig()\n    with open(args.config, \"r\") as f:\n        text_format.Parse(f.read(), config)\n\n    dot = Digraph(comment=\"DataLoader Pipeline\")\n    dot.attr(rankdir=\"TB\")\n    dot.attr(\n        \"node\",\n        shape=\"box\",\n        fontname=\"monospace\",\n        fontsize=\"10\",\n        labeljust=\"l\",\n    )\n\n    stage_names = set()\n    for stage in config.data_loader.stage:\n        stage_names.add(stage.name)\n        stage_text = text_format.MessageToString(stage, as_one_line=False)\n        br_tag = '<br align=\"left\"/>'\n        escaped_text = (\n            stage_text.replace(\"&\", \"&amp;\")\n            .replace(\"<\", \"&lt;\")\n            .replace(\">\", \"&gt;\")\n            .replace(\"\\n\", br_tag)\n        )\n        label = f\"<{escaped_text}>\"\n        dot.node(stage.name, label=label, shape=\"box\")\n\n        for input_spec in stage.input:\n            parts = input_spec.split(\".\", 1)\n            source_stage = parts[0]\n            if len(parts) == 2:\n                dot.edge(source_stage, stage.name, 
label=parts[1])\n            else:\n                dot.edge(source_stage, stage.name)\n\n    for output_spec in config.data_loader.output:\n        parts = output_spec.split(\":\", 1)\n        if len(parts) == 2:\n            alias, source = parts\n        else:\n            alias = output_spec\n            source = output_spec\n\n        source_parts = source.split(\".\", 1)\n        source_stage = source_parts[0]\n        dot.node(\n            f\"output_{alias}\",\n            label=f\"Output: {alias}\",\n            shape=\"ellipse\",\n            style=\"filled\",\n            fillcolor=\"lightblue\",\n        )\n        if len(source_parts) == 2:\n            dot.edge(source_stage, f\"output_{alias}\", label=source_parts[1])\n        else:\n            dot.edge(source_stage, f\"output_{alias}\")\n\n    output_format = args.output.rsplit(\".\", 1)[-1].lower()\n    if output_format not in (\"svg\", \"png\"):\n        output_format = \"svg\"\n\n    dot.render(\n        outfile=args.output, format=output_format, cleanup=True, engine=\"dot\"\n    )\n\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/describe_training.py",
    "content": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(description=\"Describe a trained model.\")\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the training config file.\",\n    )\n    parser.add_argument(\n        \"--shapes\",\n        action=\"store_true\",\n        help=\"Dump model shapes.\",\n    )\n    parser.add_argument(\n        \"--values\",\n        action=\"store_true\",\n        help=\"Dump model values.\",\n    )\n    parser.add_argument(\n        \"--weight_paths\",\n        action=\"store_true\",\n        help=\"List all weight paths.\",\n    )\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    # Import on demand to avoid importing heavy deps on --help.\n    from lczero_training.training.describe import describe\n\n    describe(\n        config_filename=args.config,\n        shapes=args.shapes,\n        values=args.values,\n        weight_paths=args.weight_paths,\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/jax2leela.py",
    "content": "import argparse\nimport gzip\nimport logging\nimport os\nimport sys\n\nimport orbax.checkpoint as ocp\nfrom google.protobuf import text_format\n\nfrom lczero_training.commands import configure_root_logging\nfrom lczero_training.convert.jax_to_leela import (\n    LeelaExportOptions,\n    jax_to_leela,\n)\nfrom lczero_training.training.state import TrainingState\nfrom proto.root_config_pb2 import RootConfig\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=\"Export JAX checkpoint to Leela format.\"\n    )\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the training config file.\",\n    )\n    parser.add_argument(\n        \"--output\",\n        type=str,\n        required=True,\n        help=\"Output path for the Leela network (.pb.gz).\",\n    )\n    parser.add_argument(\n        \"--export-swa\",\n        action=\"store_true\",\n        help=\"Export SWA model instead of regular model_state.\",\n    )\n    parser.add_argument(\n        \"--min-version\",\n        type=str,\n        default=\"0.31\",\n        help=\"Minimum lc0 version for exported network (default: 0.31).\",\n    )\n    return parser\n\n\ndef jax2leela(\n    config_filename: str,\n    output_path: str,\n    export_swa: bool,\n    min_version: str,\n) -> None:\n    config = RootConfig()\n    logging.info(\"Reading configuration from %s\", config_filename)\n    with open(config_filename, \"r\") as f:\n        text_format.Parse(f.read(), config)\n\n    if not config.training.checkpoint.path:\n        logging.error(\"Checkpoint path must be set in the configuration.\")\n        sys.exit(1)\n\n    logging.info(\"Loading checkpoint from %s\", config.training.checkpoint.path)\n    checkpoint_mgr = ocp.CheckpointManager(\n        config.training.checkpoint.path,\n        options=ocp.CheckpointManagerOptions(create=True),\n    )\n\n    empty_state = 
TrainingState.new_from_config(\n        model_config=config.model,\n        training_config=config.training,\n    )\n\n    latest_step = checkpoint_mgr.latest_step()\n    if latest_step is None:\n        logging.error(\n            \"No checkpoint found at %s\", config.training.checkpoint.path\n        )\n        sys.exit(1)\n    restored_state = checkpoint_mgr.restore(\n        latest_step, args=ocp.args.PyTreeRestore(empty_state)\n    )\n    assert isinstance(restored_state, TrainingState)\n    logging.info(\n        \"Restored checkpoint at step %d\", restored_state.jit_state.step\n    )\n\n    if export_swa:\n        if restored_state.jit_state.swa_state is None:\n            logging.error(\n                \"SWA export requested but SWA state is None in checkpoint.\"\n            )\n            sys.exit(1)\n        export_state = restored_state.jit_state.swa_state\n        logging.info(\"Exporting SWA model\")\n    else:\n        export_state = restored_state.jit_state.model_state\n        logging.info(\"Exporting regular model\")\n\n    options = LeelaExportOptions(\n        min_version=min_version,\n        num_heads=restored_state.num_heads,\n        license=None,\n        training_steps=restored_state.jit_state.step,\n    )\n\n    logging.info(\"Converting to Leela format\")\n    net = jax_to_leela(jax_weights=export_state, export_options=options)\n\n    logging.info(\"Serializing network\")\n    network_bytes = gzip.compress(net.SerializeToString())\n\n    # Fall back to \".\" when the output path has no directory component.\n    os.makedirs(os.path.dirname(output_path) or \".\", exist_ok=True)\n    logging.info(\"Writing network to %s\", output_path)\n    with open(output_path, \"wb\") as f:\n        f.write(network_bytes)\n\n    logging.info(\"Export complete\")\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    jax2leela(\n        config_filename=args.config,\n        output_path=args.output,\n        export_swa=args.export_swa,\n        min_version=args.min_version,\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/leela2jax.py",
    "content": "import argparse\nimport sys\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=\"Convert Leela Zero weights to JAX format.\"\n    )\n    parser.add_argument(\n        \"input\", type=str, help=\"Path to the input Lc0 weights file.\"\n    )\n    parser.add_argument(\n        \"--output-model-config\",\n        type=str,\n        help=\"Output path to the ModelConfig textproto.\",\n    )\n    parser.add_argument(\n        \"--weights-dtype\",\n        default=\"F32\",\n        type=str,\n        help=\"The data type of the weights.\",\n    )\n    parser.add_argument(\n        \"--compute-dtype\",\n        default=\"F32\",\n        type=str,\n        help=\"The data type for computation.\",\n    )\n    parser.add_argument(\n        \"--print-model-config\",\n        action=\"store_true\",\n        help=\"Print the ModelConfig textproto to stdout.\",\n    )\n    parser.add_argument(\n        \"--output-serialized-jax\",\n        type=str,\n        help=\"Path to save the output JAX serialized state.\",\n    )\n    parser.add_argument(\n        \"--output-leela-verification\",\n        type=str,\n        help=(\n            \"Path to save the round-trip converted Leela network (.pb.gz) for \"\n            \"verification.\"\n        ),\n    )\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    # Import on demand to avoid importing heavy deps on --help.\n    from lczero_training.convert.leela_to_jax import (\n        leela_to_jax_files,\n    )\n\n    leela_to_jax_files(\n        input_path=args.input,\n        weights_dtype=args.weights_dtype,\n        compute_dtype=args.compute_dtype,\n        output_modelconfig=args.output_model_config,\n        output_serialized_jax=args.output_serialized_jax,\n        output_leela_verification=args.output_leela_verification,\n        
print_modelconfig=args.print_model_config,\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/migrate_checkpoint.py",
    "content": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=\"Migrate a checkpoint to a new training state.\"\n    )\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the RootConfig textproto config.\",\n    )\n    parser.add_argument(\n        \"--new_checkpoint\",\n        help=(\n            \"Path to save the new checkpoint to. If not set, the tool only \"\n            \"checks whether the migration rules fully cover the differences.\"\n        ),\n    )\n    parser.add_argument(\n        \"--overwrite\",\n        action=\"store_true\",\n        help=\"If set, allows overwriting existing checkpoint.\",\n    )\n    parser.add_argument(\n        \"--rules_file\",\n        help=(\n            \"Path to a CheckpointMigrationConfig textproto file containing \"\n            \"the migration rules.\"\n        ),\n    )\n    parser.add_argument(\n        \"--serialized-model\",\n        action=\"store_true\",\n        default=False,\n        help=\"Use serialized state for a model.\",\n    )\n    parser.add_argument(\n        \"--checkpoint_step\",\n        type=int,\n        help=(\n            \"If set, use this step when loading from old checkpoint instead \"\n            \"of the latest.\"\n        ),\n    )\n    parser.add_argument(\n        \"--new_checkpoint_step\",\n        type=int,\n        help=(\n            \"If set, use this step when saving the new checkpoint instead of \"\n            \"copying the old step.\"\n        ),\n    )\n    parser.add_argument(\"--dump_source_paths\", action=\"store_true\")\n    parser.add_argument(\"--dump_destination_paths\", action=\"store_true\")\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = 
_build_parser()\n    args = parser.parse_args(argv)\n\n    # Lazy import to avoid heavy deps unless executing the command.\n    from lczero_training.training.migrate_checkpoint import (\n        migrate_checkpoint,\n    )\n\n    migrate_checkpoint(\n        config=args.config,\n        new_checkpoint=args.new_checkpoint,\n        overwrite=args.overwrite,\n        rules_file=args.rules_file,\n        serialized_model=args.serialized_model,\n        checkpoint_step=args.checkpoint_step,\n        new_checkpoint_step=args.new_checkpoint_step,\n        dump_source_paths=args.dump_source_paths,\n        dump_destination_paths=args.dump_destination_paths,\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/overfit.py",
    "content": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=\"Run an overfitting test on a single batch.\"\n    )\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the training config file.\",\n    )\n    parser.add_argument(\n        \"--num-steps\",\n        type=int,\n        required=True,\n        help=\"Number of training steps to run on the fixed batch.\",\n    )\n    parser.add_argument(\n        \"--coin-flip\",\n        action=\"store_true\",\n        help=(\n            \"Train on two batches: first train batch A while evaluating batch B, then vice versa.\"\n        ),\n    )\n    parser.add_argument(\n        \"--csv-file\",\n        type=str,\n        help=\"Optional path to write step-by-step overfit results.\",\n    )\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    # Import on demand to avoid importing heavy deps on --help.\n    from lczero_training.training.overfit import overfit\n\n    overfit(\n        config_filename=args.config,\n        num_steps=args.num_steps,\n        coin_flip=args.coin_flip,\n        csv_file=args.csv_file,\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/test_dataloader.py",
    "content": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=(\n            \"Fetch batches from the data loader to measure latency and throughput.\"\n        )\n    )\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the training config file.\",\n    )\n    parser.add_argument(\n        \"--num-batches\",\n        type=int,\n        default=10,\n        help=\"Number of batches to fetch from the data loader.\",\n    )\n    parser.add_argument(\n        \"--npz-output\",\n        type=str,\n        help=\"Optional path to store fetched batches as an .npz archive.\",\n    )\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    # Import on demand to avoid importing heavy deps on --help.\n    from lczero_training.training.dataloader_probe import probe_dataloader\n\n    probe_dataloader(\n        config_filename=args.config,\n        num_batches=args.num_batches,\n        npz_output=args.npz_output,\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/train.py",
    "content": "import argparse\nimport datetime\nimport gzip\nimport logging\nimport os\nimport sys\n\nimport orbax.checkpoint as ocp\nfrom flax import nnx\nfrom google.protobuf import text_format\n\nfrom lczero_training.commands import configure_root_logging\nfrom lczero_training.convert.jax_to_leela import (\n    LeelaExportOptions,\n    jax_to_leela,\n)\nfrom lczero_training.dataloader import make_dataloader\nfrom lczero_training.model.loss_function import LczeroLoss\nfrom lczero_training.model.model import LczeroModel\nfrom lczero_training.training.lr_schedule import make_lr_schedule\nfrom lczero_training.training.optimizer import make_gradient_transformation\nfrom lczero_training.training.state import TrainingState\nfrom lczero_training.training.training import Training, from_dataloader\nfrom proto.root_config_pb2 import RootConfig\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(description=\"Start a training run.\")\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the training config file.\",\n    )\n    return parser\n\n\ndef train(config_filename: str) -> None:\n    config = RootConfig()\n    logging.info(\"Reading configuration from proto file\")\n    with open(config_filename, \"r\") as f:\n        text_format.Parse(f.read(), config)\n\n    if config.training.checkpoint.path is None:\n        logging.error(\"Checkpoint path must be set in the configuration.\")\n        sys.exit(1)\n\n    checkpoint_mgr = ocp.CheckpointManager(\n        config.training.checkpoint.path,\n        options=ocp.CheckpointManagerOptions(\n            create=True,\n        ),\n    )\n\n    logging.info(\"Creating state from configuration\")\n    empty_state = TrainingState.new_from_config(\n        model_config=config.model,\n        training_config=config.training,\n    )\n    logging.info(\"Restoring checkpoint\")\n    training_state = checkpoint_mgr.restore(\n        
None, args=ocp.args.PyTreeRestore(empty_state)\n    )\n    logging.info(\"Restored checkpoint\")\n\n    model, _ = nnx.split(\n        LczeroModel(config=config.model, rngs=nnx.Rngs(params=42))\n    )\n\n    assert isinstance(training_state, TrainingState)\n\n    jit_state = training_state.jit_state\n    lr_sched = make_lr_schedule(config.training.lr_schedule)\n    optimizer_tx = make_gradient_transformation(\n        config.training.optimizer,\n        max_grad_norm=getattr(config.training, \"max_grad_norm\", 0.0),\n        lr_schedule=lr_sched,\n    )\n    training = Training(\n        optimizer_tx=optimizer_tx,\n        graphdef=model,\n        loss_fn=LczeroLoss(config=config.training.losses),\n        swa_config=(\n            config.training.swa if config.training.HasField(\"swa\") else None\n        ),\n    )\n    new_state = training.run(\n        jit_state,\n        from_dataloader(make_dataloader(config.data_loader)),\n        config.training.schedule.steps_per_network,\n    )\n\n    if config.export.destination_filename:\n        date_str = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n\n        logging.info(\"Exporting network\")\n\n        options = LeelaExportOptions(\n            min_version=\"0.28\",\n            num_heads=training_state.num_heads,\n            license=None,\n            training_steps=new_state.step,\n        )\n        export_state = (\n            new_state.swa_state\n            if config.export.export_swa_model\n            else new_state.model_state\n        )\n        assert isinstance(export_state, nnx.State)\n        net = jax_to_leela(jax_weights=export_state, export_options=options)\n        network_bytes = gzip.compress(net.SerializeToString())\n\n        for destination_template in config.export.destination_filename:\n            destination = destination_template.format(\n                datetime=date_str, step=new_state.step\n            )\n            logging.info(f\"Writing network to {destination}\")\n       
     # Fall back to \".\" when the destination has no directory component.\n            os.makedirs(os.path.dirname(destination) or \".\", exist_ok=True)\n            with open(destination, \"wb\") as f:\n                f.write(network_bytes)\n            logging.info(f\"Finished writing network to {destination}\")\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    train(config_filename=args.config)\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/training_eval.py",
    "content": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(description=\"Evaluate a trained model.\")\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the training config file.\",\n    )\n    parser.add_argument(\n        \"--num-samples\",\n        type=int,\n        help=\"Number of samples to evaluate.\",\n    )\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        help=\"Override batch size from data loader config.\",\n    )\n    parser.add_argument(\n        \"--dump-stdout\",\n        action=\"store_true\",\n        help=\"Dump input/output tensors to stdout.\",\n    )\n    parser.add_argument(\n        \"--dump-file\",\n        type=str,\n        help=\"Dump input/output tensors to specified file.\",\n    )\n    parser.add_argument(\n        \"--dump-shelve\",\n        type=str,\n        help=\"Dump input/output tensors to specified shelve database.\",\n    )\n    parser.add_argument(\n        \"--dump-json\",\n        type=str,\n        help=\"Dump input/output tensors to specified JSON file.\",\n    )\n    parser.add_argument(\n        \"--onnx-model\",\n        type=str,\n        help=\"Path to an ONNX model to compare against JAX outputs.\",\n    )\n    parser.add_argument(\n        \"--no-softmax-jax-wdl\",\n        action=\"store_true\",\n        help=\"Disable softmaxing the JAX WDL head before comparison.\",\n    )\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    # Lazy import to keep --help responsive and avoid heavy deps unless needed.\n    from lczero_training.training.eval import eval as eval_fn\n\n    eval_fn(\n        config_filename=args.config,\n        
num_samples=args.num_samples,\n        batch_size_override=args.batch_size,\n        dump_to_stdout=args.dump_stdout,\n        dump_to_file=args.dump_file,\n        dump_to_shelve=args.dump_shelve,\n        dump_to_json=args.dump_json,\n        onnx_model=args.onnx_model,\n        softmax_jax_wdl=not args.no_softmax_jax_wdl,\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/training_init.py",
    "content": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=\"Initialize a new training run from a config file.\"\n    )\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the training config file.\",\n    )\n    parser.add_argument(\n        \"--lczero_model\",\n        type=str,\n        help=\"Path to an existing lczero model to start from.\",\n    )\n    parser.add_argument(\n        \"--seed\",\n        type=int,\n        default=42,\n        help=\"Seed for initializing model parameters.\",\n    )\n    parser.add_argument(\n        \"--dry-run\",\n        action=\"store_true\",\n        help=\"Skip checkpoint creation.\",\n    )\n    parser.add_argument(\n        \"--swa_initial_nets\",\n        type=int,\n        default=0,\n        help=\"Initial value for num_averages in SWA state.\",\n    )\n    parser.add_argument(\n        \"--override_training_steps\",\n        type=int,\n        help=\"Override training step number.\",\n    )\n    parser.add_argument(\n        \"--overwrite\",\n        action=\"store_true\",\n        help=\"Allow overwriting existing checkpoint.\",\n    )\n    parser.add_argument(\n        \"--no-copy-swa\",\n        action=\"store_true\",\n        help=\"Don't copy model weights to SWA state.\",\n    )\n    parser.add_argument(\n        \"--ignore-config-mismatch\",\n        action=\"store_true\",\n        help=\"Ignore lczero model config mismatch.\",\n    )\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    # Lazy import to keep --help responsive and avoid heavy deps unless needed.\n    from lczero_training.training.init import init\n\n    init(\n        
config_filename=args.config,\n        lczero_model=args.lczero_model,\n        seed=args.seed,\n        dry_run=args.dry_run,\n        swa_initial_nets=args.swa_initial_nets,\n        override_training_steps=args.override_training_steps,\n        overwrite=args.overwrite,\n        no_copy_swa=args.no_copy_swa,\n        ignore_config_mismatch=args.ignore_config_mismatch,\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/tui.py",
    "content": "import argparse\nimport logging\nimport sys\n\nimport anyio\n\nfrom lczero_training.commands import configure_root_logging\nfrom lczero_training.tui.app import TrainingTuiApp\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(description=\"Training TUI runner\")\n    TrainingTuiApp.add_arguments(parser)\n    return parser\n\n\nasync def _amain(args: argparse.Namespace) -> None:\n    app = TrainingTuiApp(args=args)\n    await app.run_async()\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n    anyio.run(_amain, args)\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/tune_lr.py",
    "content": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=\"Run a learning rate tuning sweep.\"\n    )\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        required=True,\n        help=\"Path to the training config file.\",\n    )\n    parser.add_argument(\n        \"--start-lr\",\n        type=float,\n        required=True,\n        help=\"Starting learning rate for the sweep.\",\n    )\n    parser.add_argument(\n        \"--num-steps\",\n        type=int,\n        required=True,\n        help=\"Number of training steps to evaluate.\",\n    )\n    parser.add_argument(\n        \"--multiplier\",\n        type=float,\n        default=1.01,\n        help=\"Multiplier applied to the learning rate after each step.\",\n    )\n    parser.add_argument(\n        \"--warmup-steps\",\n        type=int,\n        default=0,\n        help=(\n            \"Optional number of warmup steps to run at a fixed learning rate before \"\n            \"the exponential sweep.\"\n        ),\n    )\n    parser.add_argument(\n        \"--warmup-lr\",\n        type=float,\n        help=(\n            \"Learning rate to use during warmup steps. Required when --warmup-steps > 0.\"\n        ),\n    )\n    parser.add_argument(\n        \"--csv-output\",\n        type=str,\n        help=(\n            \"Optional path to write CSV results. Columns: lr, train_loss[, val_loss].\"\n        ),\n    )\n    parser.add_argument(\n        \"--plot-output\",\n        type=str,\n        help=\"Optional path to save a matplotlib plot of the sweep.\",\n    )\n    parser.add_argument(\n        \"--num-test-batches\",\n        type=int,\n        default=0,\n        help=(\n            \"When > 0, also compute and report validation loss on this many fixed batches \"\n            \"(averaged each step). 
Default 0 (training loss only).\"\n        ),\n    )\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging(logging.INFO)\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    # Lazy import to keep --help responsive and avoid heavy deps unless needed.\n    from lczero_training.training.tune_lr import tune_lr\n\n    tune_lr(\n        config_filename=args.config,\n        start_lr=args.start_lr,\n        num_steps=args.num_steps,\n        multiplier=args.multiplier,\n        warmup_steps=args.warmup_steps,\n        warmup_lr=args.warmup_lr,\n        csv_output=args.csv_output,\n        plot_output=args.plot_output,\n        num_test_batches=args.num_test_batches,\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/commands/weights_tool.py",
    "content": "\"\"\"CLI command for manipulating Lc0 neural network weights.\"\"\"\n\nimport argparse\nimport sys\n\nimport numpy as np\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(\n        description=\"Manipulate Lc0 neural network weights.\"\n    )\n    parser.add_argument(\n        \"--expr\",\n        type=str,\n        help=(\n            \"Python expression to execute \"\n            \"(e.g., \\\"output = weights('A.pb') * 0.5\\\")\"\n        ),\n    )\n    parser.add_argument(\n        \"script\",\n        nargs=\"?\",\n        help=\"Path to Python script file (if --expr not provided)\",\n    )\n    parser.add_argument(\n        \"--input\",\n        type=str,\n        action=\"append\",\n        help=\"Pre-load input as NAME=PATH (e.g., --input A=net_A.pb.gz)\",\n    )\n    parser.add_argument(\n        \"--output\",\n        type=str,\n        help='Default output path if \"output\" variable is set',\n    )\n    parser.add_argument(\n        \"--encoding\",\n        type=str,\n        default=\"FLOAT16\",\n        choices=[\"LINEAR16\", \"FLOAT16\", \"BFLOAT16\"],\n        help=\"Output encoding format (default: FLOAT16)\",\n    )\n    return parser\n\n\ndef main(argv: list[str] | None = None) -> int:\n    configure_root_logging()\n\n    parser = _build_parser()\n    args = parser.parse_args(argv)\n\n    # Lazy import to avoid heavy dependencies on --help.\n    from lczero_training.tools.weights_tool import load_weights, save_weights\n    from proto import net_pb2\n\n    # Build execution environment.\n    env = {\n        \"np\": np,\n        \"weights\": load_weights,\n        \"save\": save_weights,\n        \"lc0\": net_pb2,\n    }\n\n    # Pre-load inputs.\n    if args.input:\n        for input_spec in args.input:\n            if \"=\" not in input_spec:\n                print(\n                    f\"Error: Invalid input spec 
'{input_spec}'. \"\n                    \"Expected format: NAME=PATH\",\n                    file=sys.stderr,\n                )\n                return 1\n            name, path = input_spec.split(\"=\", 1)\n            env[name] = load_weights(path)\n\n    # Determine script source: --expr, file, or stdin.\n    if args.expr:\n        script = args.expr\n    elif args.script:\n        with open(args.script) as f:\n            script = f.read()\n    else:\n        if sys.stdin.isatty():\n            print(\n                \"Error: No script provided. Use --expr, provide script file, \"\n                \"or pipe to stdin.\",\n                file=sys.stderr,\n            )\n            return 1\n        script = sys.stdin.read()\n\n    # Execute script.\n    try:\n        exec(script, env)\n    except Exception as e:\n        print(f\"Error executing script: {e}\", file=sys.stderr)\n        return 1\n\n    # Auto-save if 'output' variable is set.\n    if \"output\" in env and args.output:\n        from lczero_training.tools.weight_wrappers import NetWrapper\n\n        output = env[\"output\"]\n        if isinstance(output, NetWrapper):\n            save_weights(output, args.output, args.encoding)\n            print(f\"Saved result to {args.output}\")\n\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n"
  },
  {
    "path": "src/lczero_training/convert/__init__.py",
    "content": "\"\"\"Convert package for Leela Chess Zero training.\"\"\"\n"
  },
  {
    "path": "src/lczero_training/convert/jax_to_leela.py",
    "content": "import dataclasses\nimport logging\nfrom typing import Optional, cast\n\nimport numpy as np\nfrom flax import nnx\n\nfrom lczero_training.convert.leela_pytree_visitor import (\n    LeelaPytreeWeightsVisitor,\n)\nfrom proto import net_pb2\n\nlogger = logging.getLogger(__name__)\n\n_EMBEDDING_PLANE_TO_SCALE = 109\n_EMBEDDING_SCALE = 99.0\n\n\nclass JaxToLeela(LeelaPytreeWeightsVisitor):\n    def embedding_block(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights\n    ) -> None:\n        embedding_kernel = cast(nnx.Param, nnx_dict[\"embedding\"][\"kernel\"])\n        original_values = embedding_kernel.value\n        arr = np.asarray(original_values).copy()\n        arr[_EMBEDDING_PLANE_TO_SCALE] /= _EMBEDDING_SCALE\n        embedding_kernel.value = arr\n        try:\n            super().embedding_block(nnx_dict=nnx_dict, weights=weights)\n        finally:\n            embedding_kernel.value = original_values\n\n    def tensor(\n        self,\n        param: nnx.Param,\n        leela: net_pb2.Weights.Layer,\n    ) -> None:\n        weights = np.asarray(param.value, dtype=np.float32).T.flatten()\n        min_val, max_val = np.min(weights), np.max(weights)\n        range_val = max_val - min_val\n\n        # Normalize to [0, 1], handling the case where all weights are equal.\n        normalized = np.where(\n            range_val > 1e-8, (weights - min_val) / range_val, 0.5\n        )\n\n        # Scale to uint16 and convert to bytes.\n        quantized = np.round(normalized * 65535.0).astype(np.uint16)\n        leela.params = quantized.tobytes()\n        leela.min_val = float(min_val)\n        leela.max_val = float(max_val)\n\n        assert len(leela.params) // 2 == weights.size\n\n    def encoder_tower(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights\n    ) -> None:\n        for i in range(len(nnx_dict[\"encoders\"][\"layers\"])):\n            weights.encoder.append(weights.EncoderLayer())\n        return 
super().encoder_tower(nnx_dict=nnx_dict, weights=weights)\n\n\n@dataclasses.dataclass\nclass LeelaExportOptions:\n    min_version: str\n    num_heads: int\n    license: Optional[str]\n    training_steps: Optional[int] = None\n\n\ndef jax_to_leela(\n    jax_weights: nnx.State, export_options: LeelaExportOptions\n) -> net_pb2.Net:\n    lc0_weights = net_pb2.Net()\n    lc0_weights.magic = 0x1C0\n    if export_options.license:\n        lc0_weights.license = export_options.license\n    (\n        lc0_weights.min_version.major,\n        lc0_weights.min_version.minor,\n        lc0_weights.min_version.patch,\n    ) = _split_version(export_options.min_version)\n    lc0_weights.format.CopyFrom(_make_format())\n    if export_options.training_steps is not None:\n        lc0_weights.training_params.training_steps = (\n            export_options.training_steps\n        )\n\n    visitor = JaxToLeela(jax_weights, lc0_weights)\n    lc0_weights.weights.headcount = export_options.num_heads\n    visitor.run()\n\n    return lc0_weights\n\n\ndef _split_version(version_str: str) -> tuple[int, int, int]:\n    \"\"\"Splits a version string like \"v12.34.56\" into (12, 34, 56).\"\"\"\n    parts = (version_str.lstrip(\"v\").split(\".\") + [\"0\", \"0\"])[:3]\n    return cast(tuple[int, int, int], tuple(map(int, parts)))\n\n\ndef _make_format() -> net_pb2.Format:\n    fmt = net_pb2.Format()\n    fmt.weights_encoding = fmt.LINEAR16\n    netfmt = fmt.network_format\n    netfmt.input = netfmt.INPUT_CLASSICAL_112_PLANE\n    netfmt.output = netfmt.OUTPUT_WDL\n    netfmt.network = netfmt.NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT\n    netfmt.policy = netfmt.POLICY_ATTENTION\n    netfmt.value = netfmt.VALUE_WDL\n    netfmt.moves_left = netfmt.MOVES_LEFT_V1\n    netfmt.default_activation = netfmt.DEFAULT_ACTIVATION_MISH\n    netfmt.smolgen_activation = netfmt.ACTIVATION_SWISH\n    netfmt.ffn_activation = netfmt.ACTIVATION_DEFAULT\n    netfmt.input_embedding = netfmt.INPUT_EMBEDDING_PE_DENSE\n\n    
return fmt\n"
  },
  {
    "path": "src/lczero_training/convert/leela_pytree_visitor.py",
    "content": "import math\nfrom typing import Any, Optional\n\nfrom flax import nnx\n\nfrom proto import net_pb2\n\n\nclass LeelaPytreeWeightsVisitor:\n    def __init__(self, nnx_state: nnx.State, leela_net: net_pb2.Net) -> None:\n        self.leela_net = leela_net\n        self.nnx_state = nnx_state\n\n    def run(self) -> None:\n        state = self.nnx_state\n        weights = self.leela_net.weights\n        self.embedding_block(state[\"embedding\"], weights)\n        self.encoder_tower(state[\"encoders\"], weights)\n        self.policy_heads(state, weights.policy_heads)\n        for head_name in [\"winner\", \"q\", \"st\"]:\n            if head_name in state[\"value_heads\"]:\n                self.value_head(\n                    state[\"value_heads\"][head_name],\n                    getattr(weights.value_heads, head_name),\n                )\n        for head_name in [\"main\"]:\n            assert head_name in state[\"movesleft_heads\"], (\n                f\"movesleft head {head_name} missing in state\"\n            )\n            self.movesleft_head(state[\"movesleft_heads\"][head_name], weights)\n\n    def embedding_block(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights\n    ) -> None:\n        self.matmul(\n            nnx_dict[\"preprocess\"],\n            weights.ip_emb_preproc_w,\n            weights.ip_emb_preproc_b,\n        )\n        self.matmul(\n            nnx_dict[\"embedding\"],\n            weights.ip_emb_w,\n            weights.ip_emb_b,\n        )\n        self.layernorm(\n            nnx_dict[\"norm\"],\n            weights.ip_emb_ln_gammas,\n            weights.ip_emb_ln_betas,\n        )\n        self.tensor(\n            nnx_dict[\"ma_gating\"][\"mult_gate\"][\"gate\"], weights.ip_mult_gate\n        )\n        self.tensor(\n            nnx_dict[\"ma_gating\"][\"add_gate\"][\"gate\"], weights.ip_add_gate\n        )\n        self.ffn(nnx_dict[\"ffn\"], weights.ip_emb_ffn)\n        self.layernorm(\n            
nnx_dict[\"out_norm\"],\n            weights.ip_emb_ffn_ln_gammas,\n            weights.ip_emb_ffn_ln_betas,\n        )\n\n    def encoder_tower(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights\n    ) -> None:\n        # Shared layer is stored at the point of the first usage.\n        self.matmul(\n            nnx_dict[\"encoders\"][\"layers\"][0][\"mha\"][\"smolgen\"][\n                \"weight_gen_dense\"\n            ],\n            weights.smolgen_w,\n            None,\n        )\n\n        # assert len(nnx_dict[\"encoders\"][\"layers\"]) == len(weights.encoder)\n        for i in range(len(nnx_dict[\"encoders\"][\"layers\"])):\n            self.encoder_block(\n                nnx_dict[\"encoders\"][\"layers\"][i], weights.encoder[i]\n            )\n\n    def encoder_block(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights.EncoderLayer\n    ) -> None:\n        self.mha(nnx_dict[\"mha\"], weights.mha)\n        self.layernorm(nnx_dict[\"ln1\"], weights.ln1_gammas, weights.ln1_betas)\n        self.ffn(nnx_dict[\"ffn\"], weights.ffn)\n        self.layernorm(nnx_dict[\"ln2\"], weights.ln2_gammas, weights.ln2_betas)\n\n    def mha(self, nnx_dict: nnx.State, weights: net_pb2.Weights.MHA) -> None:\n        self.matmul(nnx_dict[\"q\"], weights.q_w, weights.q_b)\n        self.matmul(nnx_dict[\"k\"], weights.k_w, weights.k_b)\n        self.matmul(nnx_dict[\"v\"], weights.v_w, weights.v_b)\n        self.smolgen(nnx_dict[\"smolgen\"], weights.smolgen)\n        self.matmul(nnx_dict[\"output_dense\"], weights.dense_w, weights.dense_b)\n\n    def smolgen(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights.Smolgen\n    ) -> None:\n        self.matmul(nnx_dict[\"compress\"], weights.compress, None)\n        self.matmul(nnx_dict[\"dense1\"], weights.dense1_w, weights.dense1_b)\n        self.layernorm(nnx_dict[\"ln1\"], weights.ln1_gammas, weights.ln1_betas)\n        self.matmul(nnx_dict[\"dense2\"], weights.dense2_w, weights.dense2_b)\n        
self.layernorm(nnx_dict[\"ln2\"], weights.ln2_gammas, weights.ln2_betas)\n\n    def layernorm(\n        self,\n        nnx_dict: nnx.State,\n        scales: net_pb2.Weights.Layer,\n        biases: net_pb2.Weights.Layer,\n    ) -> None:\n        self.tensor(nnx_dict[\"scale\"], scales)\n        self.tensor(nnx_dict[\"bias\"], biases)\n\n    def policy_heads(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights.PolicyHeads\n    ) -> None:\n        if \"policy_embedding_shared\" in nnx_dict:\n            self.matmul(\n                nnx_dict[\"policy_embedding_shared\"],\n                weights.ip_pol_w,\n                weights.ip_pol_b,\n            )\n        policy_heads_dict = nnx_dict[\"policy_heads\"]\n        for head_name in [\"vanilla\", \"optimistic_st\", \"soft\", \"opponent\"]:\n            if head_name in policy_heads_dict:\n                self.policy_head(\n                    policy_heads_dict[head_name], getattr(weights, head_name)\n                )\n\n    def policy_head(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights.PolicyHead\n    ) -> None:\n        if \"tokens\" in nnx_dict:\n            self.matmul(nnx_dict[\"tokens\"], weights.ip_pol_w, weights.ip_pol_b)\n        self.matmul(nnx_dict[\"q\"], weights.ip2_pol_w, weights.ip2_pol_b)\n        self.matmul(nnx_dict[\"k\"], weights.ip3_pol_w, weights.ip3_pol_b)\n        self.matmul(nnx_dict[\"promotion_dense\"], weights.ip4_pol_w, None)\n\n    def value_head(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights.ValueHead\n    ) -> None:\n        self.matmul(nnx_dict[\"embed\"], weights.ip_val_w, weights.ip_val_b)\n        self.matmul(nnx_dict[\"dense1\"], weights.ip1_val_w, weights.ip1_val_b)\n        self.matmul(nnx_dict[\"wdl\"], weights.ip2_val_w, weights.ip2_val_b)\n        if \"error\" in nnx_dict:\n            self.matmul(\n                nnx_dict[\"error\"], weights.ip_val_err_w, weights.ip_val_err_b\n            )\n        if \"categorical\" in nnx_dict:\n    
        self.matmul(\n                nnx_dict[\"categorical\"],\n                weights.ip_val_cat_w,\n                weights.ip_val_cat_b,\n            )\n\n    def movesleft_head(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights\n    ) -> None:\n        self.matmul(nnx_dict[\"embed\"], weights.ip_mov_w, weights.ip_mov_b)\n        self.matmul(nnx_dict[\"dense1\"], weights.ip1_mov_w, weights.ip1_mov_b)\n        self.matmul(nnx_dict[\"out\"], weights.ip2_mov_w, weights.ip2_mov_b)\n\n    def ffn(self, nnx_dict: nnx.State, ffn: net_pb2.Weights.FFN) -> None:\n        self.matmul(nnx_dict[\"linear1\"], ffn.dense1_w, ffn.dense1_b)\n        self.matmul(nnx_dict[\"linear2\"], ffn.dense2_w, ffn.dense2_b)\n\n    def matmul(\n        self,\n        nnx_dict: nnx.State,\n        weights: net_pb2.Weights.Layer,\n        biases: Optional[net_pb2.Weights.Layer],\n    ) -> None:\n        self.tensor(nnx_dict[\"kernel\"], weights)\n        # Protobuf messages are always truthy, so check against None\n        # explicitly.\n        if biases is not None:\n            self.tensor(nnx_dict[\"bias\"], biases)\n        else:\n            assert \"bias\" not in nnx_dict\n\n    def tensor(\n        self,\n        param: Any,\n        leela: net_pb2.Weights.Layer,\n    ) -> None:\n        # The base implementation only validates that the Leela layer size\n        # matches the parameter shape; subclasses copy the actual data.\n        assert len(leela.params) // 2 == math.prod(param.shape)\n        assert len(leela.params) != 0\n"
  },
  {
    "path": "src/lczero_training/convert/leela_to_jax.py",
    "content": "import dataclasses\nimport gzip\nimport logging\nimport math\nfrom typing import Optional, cast\n\nimport jax.numpy as jnp\nfrom flax import nnx, serialization\n\nfrom lczero_training.model.model import LczeroModel\nfrom proto import hlo_pb2, net_pb2\n\nfrom .jax_to_leela import LeelaExportOptions, jax_to_leela\nfrom .leela_pytree_visitor import LeelaPytreeWeightsVisitor\nfrom .leela_to_modelconfig import leela_to_modelconfig\n\nlogger = logging.getLogger(__name__)\n\n\n_EMBEDDING_PLANE_TO_SCALE = 109\n_EMBEDDING_SCALE = 99.0\n\n\n@dataclasses.dataclass\nclass LeelaImportOptions:\n    weights_dtype: hlo_pb2.XlaShapeProto.Type\n    compute_dtype: hlo_pb2.XlaShapeProto.Type\n\n\ndef fix_older_weights_file(file: net_pb2.Net) -> None:\n    nf = net_pb2.NetworkFormat\n    has_network_format = file.format.HasField(\"network_format\")\n    network_format = (\n        file.format.network_format.network if has_network_format else None\n    )\n\n    net = file.format.network_format\n\n    if not has_network_format:\n        # Older protobufs don't have format definition.\n        net.input = nf.INPUT_CLASSICAL_112_PLANE\n        net.output = nf.OUTPUT_CLASSICAL\n        net.network = nf.NETWORK_CLASSICAL_WITH_HEADFORMAT\n        net.value = nf.VALUE_CLASSICAL\n        net.policy = nf.POLICY_CLASSICAL\n    elif network_format == nf.NETWORK_CLASSICAL:\n        # Populate policyFormat and valueFormat fields in old protobufs\n        # without these fields.\n        net.network = nf.NETWORK_CLASSICAL_WITH_HEADFORMAT\n        net.value = nf.VALUE_CLASSICAL\n        net.policy = nf.POLICY_CLASSICAL\n    elif network_format == nf.NETWORK_SE:\n        net.network = nf.NETWORK_SE_WITH_HEADFORMAT\n        net.value = nf.VALUE_CLASSICAL\n        net.policy = nf.POLICY_CLASSICAL\n    elif (\n        network_format == nf.NETWORK_SE_WITH_HEADFORMAT\n        and len(file.weights.encoder) > 0\n    ):\n        # Attention body network made with old protobuf.\n        
net.network = nf.NETWORK_ATTENTIONBODY_WITH_HEADFORMAT\n        if file.weights.HasField(\"smolgen_w\"):\n            # Need to override activation defaults for smolgen.\n            net.ffn_activation = nf.ACTIVATION_RELU_2\n            net.smolgen_activation = nf.ACTIVATION_SWISH\n    elif network_format == nf.NETWORK_AB_LEGACY_WITH_MULTIHEADFORMAT:\n        net.network = nf.NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT\n\n    if (\n        file.format.network_format.network\n        == nf.NETWORK_ATTENTIONBODY_WITH_HEADFORMAT\n    ):\n        weights = file.weights\n        if weights.HasField(\"policy_heads\") and weights.HasField(\"value_heads\"):\n            logger.info(\n                \"Weights file has multihead format, updating format flag\"\n            )\n            net.network = nf.NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT\n            net.input_embedding = nf.INPUT_EMBEDDING_PE_DENSE\n        if not file.format.network_format.HasField(\"input_embedding\"):\n            net.input_embedding = nf.INPUT_EMBEDDING_PE_MAP\n\n\nclass LeelaToJax(LeelaPytreeWeightsVisitor):\n    def embedding_block(\n        self, nnx_dict: nnx.State, weights: net_pb2.Weights\n    ) -> None:\n        super().embedding_block(nnx_dict=nnx_dict, weights=weights)\n        embedding_kernel = cast(nnx.Param, nnx_dict[\"embedding\"][\"kernel\"])\n        values = embedding_kernel.value\n        scaled_values = values.at[_EMBEDDING_PLANE_TO_SCALE].set(\n            values[_EMBEDDING_PLANE_TO_SCALE] * _EMBEDDING_SCALE\n        )\n        embedding_kernel.value = scaled_values\n\n    def tensor(\n        self,\n        param: nnx.Param,\n        leela: net_pb2.Weights.Layer,\n    ) -> None:\n        assert len(leela.params) // 2 == math.prod(param.shape)\n        assert len(leela.params) != 0\n\n        values = jnp.frombuffer(leela.params, dtype=jnp.uint16)\n        values = values.astype(jnp.float32)\n        alpha = values / 65535.0\n        values = alpha * leela.max_val + (1.0 - 
alpha) * leela.min_val\n        values = values.astype(param.dtype)\n        values = values.reshape(param.shape[::-1]).transpose()\n        param.value = values\n\n\ndef leela_to_jax(\n    leela_net: net_pb2.Net, import_options: LeelaImportOptions\n) -> nnx.State:\n    config = leela_to_modelconfig(\n        leela_net,\n        import_options.weights_dtype,\n        import_options.compute_dtype,\n    )\n\n    model = LczeroModel(config=config, rngs=nnx.Rngs(params=42))\n    state = nnx.state(model)\n    visitor = LeelaToJax(state, leela_net)\n    visitor.run()\n\n    return state\n\n\ndef leela_to_jax_files(\n    input_path: str,\n    weights_dtype: str,\n    compute_dtype: str,\n    output_modelconfig: Optional[str],\n    output_serialized_jax: Optional[str],\n    output_leela_verification: Optional[str],\n    print_modelconfig: bool = False,\n) -> None:\n    lc0_weights = net_pb2.Net()\n    with gzip.open(input_path, \"rb\") as f:\n        contents = f.read()\n        assert isinstance(contents, bytes)\n        lc0_weights.ParseFromString(contents)\n\n    fix_older_weights_file(lc0_weights)\n\n    import_options = LeelaImportOptions(\n        weights_dtype=getattr(hlo_pb2.XlaShapeProto, weights_dtype),\n        compute_dtype=getattr(hlo_pb2.XlaShapeProto, compute_dtype),\n    )\n\n    config = leela_to_modelconfig(\n        lc0_weights,\n        import_options.weights_dtype,\n        import_options.compute_dtype,\n    )\n\n    if print_modelconfig:\n        print(config)\n\n    if output_modelconfig:\n        with open(output_modelconfig, \"w\") as f:\n            f.write(str(config))\n\n    if output_serialized_jax is None and output_leela_verification is None:\n        return\n\n    state = leela_to_jax(lc0_weights, import_options)\n\n    if output_serialized_jax:\n        with open(output_serialized_jax, \"wb\") as f:\n            f.write(serialization.to_bytes(state))\n\n    if output_leela_verification:\n        min_version = (\n            
f\"v{lc0_weights.min_version.major}.\"\n            f\"{lc0_weights.min_version.minor}.\"\n            f\"{lc0_weights.min_version.patch}\"\n        )\n        license_str = (\n            lc0_weights.license if lc0_weights.HasField(\"license\") else None\n        )\n        export_options = LeelaExportOptions(\n            min_version=min_version,\n            num_heads=lc0_weights.weights.headcount,\n            license=license_str,\n            training_steps=lc0_weights.training_params.training_steps,\n        )\n        verification_net = jax_to_leela(state, export_options)\n        with gzip.open(output_leela_verification, \"wb\") as f:\n            f.write(verification_net.SerializeToString())\n"
  },
  {
    "path": "src/lczero_training/convert/leela_to_modelconfig.py",
    "content": "from proto import hlo_pb2, model_config_pb2, net_pb2\n\n\ndef _defaultactivation_to_activation(\n    activation: net_pb2.NetworkFormat.DefaultActivation,\n) -> net_pb2.NetworkFormat.ActivationFunction:\n    return {\n        net_pb2.NetworkFormat.DEFAULT_ACTIVATION_RELU: net_pb2.NetworkFormat.ACTIVATION_RELU,\n        net_pb2.NetworkFormat.DEFAULT_ACTIVATION_MISH: net_pb2.NetworkFormat.ACTIVATION_MISH,\n    }[activation]\n\n\ndef leela_to_modelconfig(\n    leela_net: net_pb2.Net,\n    weights_dtype: hlo_pb2.XlaShapeProto.Type,\n    compute_dtype: hlo_pb2.XlaShapeProto.Type,\n) -> model_config_pb2.ModelConfig:\n    assert weights_dtype == hlo_pb2.XlaShapeProto.F32, (\n        \"Only float32 weights are supported.\"\n    )\n    assert leela_net.format.weights_encoding == net_pb2.Format.LINEAR16\n    leela_net_format = leela_net.format.network_format\n    model_config = model_config_pb2.ModelConfig()\n\n    model_config.defaults.compute_dtype = compute_dtype\n    model_config.defaults.activation = _defaultactivation_to_activation(\n        leela_net_format.default_activation\n    )\n    model_config.defaults.ffn_activation = (\n        leela_net_format.ffn_activation or model_config.defaults.activation\n    )\n    assert (\n        leela_net_format.input_embedding\n        == net_pb2.NetworkFormat.INPUT_EMBEDDING_PE_DENSE\n    ), \"Only dense positional embedding is supported, got {}\".format(\n        net_pb2.NetworkFormat.InputEmbeddingFormat.Name(\n            leela_net_format.input_embedding\n        )\n    )\n    assert leela_net_format.policy == net_pb2.NetworkFormat.POLICY_ATTENTION, (\n        \"Only attention policy is supported, got {}\".format(\n            net_pb2.NetworkFormat.PolicyFormat.Name(leela_net_format.policy)\n        )\n    )\n    assert leela_net_format.value == net_pb2.NetworkFormat.VALUE_WDL, (\n        \"Only WDL value is supported, got {}\".format(\n            
net_pb2.NetworkFormat.ValueFormat.Name(leela_net_format.value)\n        )\n    )\n    assert leela_net_format.moves_left == net_pb2.NetworkFormat.MOVES_LEFT_V1, (\n        \"Only V1 moves left format is supported, got {}\".format(\n            net_pb2.NetworkFormat.MovesLeftFormat.Name(\n                leela_net_format.moves_left\n            )\n        )\n    )\n\n    def size(x: net_pb2.Weights.Layer) -> int:\n        return len(x.params) // 2\n\n    assert (\n        leela_net_format.network\n        == net_pb2.NetworkFormat.NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT\n    )\n    weights = leela_net.weights\n    model_config.embedding.dense_size = size(weights.ip_emb_preproc_b) // 64\n    model_config.embedding.embedding_size = size(weights.ip_emb_b)\n    assert size(weights.ip_mult_gate) > 0\n    assert size(weights.ip_add_gate) > 0\n    model_config.embedding.dff = size(weights.ip_emb_ffn.dense1_b)\n\n    model_config.encoder.num_blocks = len(weights.encoder)\n    assert model_config.encoder.num_blocks > 0\n    encoder = weights.encoder[0]\n    model_config.encoder.d_model = size(encoder.mha.q_b)\n    model_config.encoder.heads = weights.headcount\n    model_config.encoder.dff = size(encoder.ffn.dense1_b)\n\n    if weights.HasField(\"smolgen_w\"):\n        model_config.encoder.smolgen.activation = (\n            leela_net_format.smolgen_activation\n            or model_config.defaults.activation\n        )\n        model_config.encoder.smolgen.hidden_channels = (\n            size(encoder.mha.smolgen.compress)\n            // model_config.embedding.embedding_size\n        )\n        model_config.encoder.smolgen.gen_size = (\n            size(encoder.mha.smolgen.dense2_b) // weights.headcount\n        )\n        model_config.encoder.smolgen.hidden_size = size(\n            encoder.mha.smolgen.dense1_b\n        )\n\n    if weights.policy_heads.HasField(\"ip_pol_w\"):\n        model_config.shared_policy_embedding_size = size(\n            
weights.policy_heads.ip_pol_b\n        )\n\n    for head_name in [\"vanilla\", \"optimistic_st\", \"soft\", \"opponent\"]:\n        if weights.policy_heads.HasField(head_name):\n            head = getattr(weights.policy_heads, head_name)\n            assert size(head.ip2_pol_b) > 0\n            assert not head.HasField(\"ip_pol_w\")\n            policy_head = model_config.policy_head.add()\n            policy_head.name = head_name\n            if not model_config.HasField(\"shared_policy_embedding_size\"):\n                policy_head.embedding_size = size(head.ip_pol_b)\n            policy_head.d_model = size(head.ip2_pol_b)\n\n    for head_name in [\"winner\", \"q\", \"st\"]:\n        if weights.value_heads.HasField(head_name):\n            head = getattr(weights.value_heads, head_name)\n            assert size(head.ip_val_b) > 0\n            value_head = model_config.value_head.add()\n            value_head.name = head_name\n            value_head.num_channels = size(head.ip_val_b)\n            if head.HasField(\"ip_val_err_w\"):\n                value_head.has_error_output = True\n            if head.HasField(\"ip_val_cat_b\"):\n                value_head.num_categorical_buckets = size(head.ip_val_cat_b)\n\n    movesleft_head = model_config.movesleft_head.add()\n    movesleft_head.name = \"main\"\n    movesleft_head.num_channels = size(weights.ip_mov_b)\n\n    return model_config\n"
  },
  {
    "path": "src/lczero_training/daemon/__init__.py",
    "content": "# ABOUTME: Daemon package for training subprocess communication.\n# ABOUTME: Provides TrainingDaemon class for IPC via stdin/stdout.\n"
  },
  {
    "path": "src/lczero_training/daemon/daemon.py",
    "content": "import logging\nimport signal\nimport sys\nimport threading\nimport time\n\nimport anyio\n\nimport proto.training_metrics_pb2 as training_metrics_pb2\n\nfrom .pipeline import TrainingPipeline\nfrom .protocol.communicator import Communicator\nfrom .protocol.messages import (\n    StartTrainingImmediatelyPayload,\n    StartTrainingPayload,\n    TrainingStatusPayload,\n)\n\n\nclass TrainingDaemon:\n    _training_pipeline: TrainingPipeline | None = None\n    _config_filepath: str | None = None\n    _daemon_start_time: float\n\n    def __init__(self, memory_profile_dir: str | None = None) -> None:\n        self._memory_profile_dir = memory_profile_dir\n        self._daemon_start_time = time.time()\n        self._setup_logging()\n        self._setup_signal_handling()\n        self._communicator = Communicator(self, sys.stdin, sys.stdout)\n        self._communicator_thread = threading.Thread(\n            target=lambda: self._communicator.run(), daemon=True\n        )\n        self._communicator_thread.start()\n        self._async_thread = threading.Thread(\n            target=lambda: anyio.run(self._metrics_main), daemon=True\n        )\n        self._async_thread.start()\n        self._signal_thread = threading.Thread(\n            target=self._signal_handler_thread, daemon=True\n        )\n        self._signal_thread.start()\n\n    def _setup_logging(self) -> None:\n        logging.basicConfig(\n            level=logging.INFO,\n            format=(\n                \"%(levelname).1s%(asctime)s.%(msecs)03d %(name)s \"\n                \"%(filename)s:%(lineno)d] %(message)s\"\n            ),\n            datefmt=\"%m%d %H:%M:%S\",\n            stream=sys.stderr,\n        )\n        logging.info(\"TrainingDaemon starting up\")\n\n    def _setup_signal_handling(self) -> None:\n        # Block SIGINT and SIGTERM on all threads\n        signal.pthread_sigmask(\n            signal.SIG_BLOCK, {signal.SIGINT, signal.SIGTERM}\n        )\n\n    def 
_signal_handler_thread(self) -> None:\n        # Wait for SIGINT or SIGTERM\n        signum = signal.sigwait({signal.SIGINT, signal.SIGTERM})\n        self._shutdown(signum)\n\n    def _shutdown(self, signum: int) -> None:\n        logging.info(f\"Received signal {signum}, shutting down...\")\n        if self._training_pipeline:\n            self._training_pipeline.stop()\n\n    async def _metrics_main(self) -> None:\n        async with anyio.create_task_group() as tg:\n            tg.start_soon(self._metrics_task)\n\n    async def _metrics_task(self) -> None:\n        while True:\n            await anyio.sleep(1.1)\n\n            dataloader_1_second = None\n            dataloader_total = None\n            dataloader_update_secs = None\n            training_schedule_data = None\n\n            data_loader = None\n            if self._training_pipeline:\n                data_loader = self._training_pipeline.get_data_loader()\n                training_schedule_data = (\n                    self._training_pipeline.get_training_schedule_data(\n                        self._daemon_start_time\n                    )\n                )\n\n            if data_loader is not None:\n                stats_1_second_bytes, _ = data_loader.get_bucket_metrics(\n                    0, False\n                )  # k1Second = 0\n                stats_total_bytes, dataloader_update_secs = (\n                    data_loader.get_aggregate_ending_now(float(\"inf\"), False)\n                )\n                dataloader_1_second = (\n                    training_metrics_pb2.DataLoaderMetricsProto()\n                )\n                dataloader_1_second.ParseFromString(stats_1_second_bytes)\n                dataloader_total = training_metrics_pb2.DataLoaderMetricsProto()\n                dataloader_total.ParseFromString(stats_total_bytes)\n\n            payload = TrainingStatusPayload(\n                dataloader_update_secs=dataloader_update_secs,\n                
dataloader_1_second=dataloader_1_second,\n                dataloader_total=dataloader_total,\n                training_schedule=training_schedule_data,\n            )\n            self._communicator.send(payload)\n\n    def run(self) -> None:\n        while self._config_filepath is None:\n            logging.info(\"Waiting for training config...\")\n            time.sleep(1)\n\n        logging.info(\"Config received. Starting training pipeline.\")\n        self._training_pipeline = TrainingPipeline(\n            self._config_filepath,\n            memory_profile_dir=self._memory_profile_dir,\n        )\n        self._training_pipeline.run()\n\n    def on_start_training(self, payload: StartTrainingPayload) -> None:\n        self._config_filepath = payload.config_filepath\n\n    def on_start_training_immediately(\n        self, payload: StartTrainingImmediatelyPayload\n    ) -> None:\n        if not self._training_pipeline:\n            logging.warning(\n                \"Received immediate training request before pipeline initialization.\"\n            )\n            return\n\n        self._training_pipeline.start_training_immediately()\n"
  },
  {
    "path": "src/lczero_training/daemon/metrics.py",
    "content": "\"\"\"Metrics collection and logging for training daemon.\"\"\"\n\nimport logging\nimport os\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Callable, Dict, Optional\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax import nnx\n\nfrom lczero_training._lczero_training import DataLoader\nfrom lczero_training.daemon.metrics_base import _Metric\nfrom lczero_training.daemon.rms_metrics import _RmsMetric\nfrom lczero_training.model.loss_function import LczeroLoss\nfrom lczero_training.training.state import JitTrainingState, TrainingSample\nfrom lczero_training.training.tensorboard import TensorboardLogger\nfrom lczero_training.training.training import StepHookData\nfrom proto.metrics_config_pb2 import MetricConfig, MetricsConfig\n\nlogger = logging.getLogger(__name__)\n\n# Type alias for batch tuple returned by DataLoader\n# Using ... to allow variable length for compatibility with maybe_get_next\nBatchTuple = tuple[np.ndarray, ...]\n\n\n@dataclass\nclass CachedBatch:\n    \"\"\"Cached batch data with the global step when it was last updated.\"\"\"\n\n    batch: BatchTuple\n    global_step: int\n\n\ndef load_batch_from_npz(npz_filename: str) -> BatchTuple:\n    \"\"\"Load a batch from an NPZ file.\n\n    Args:\n        npz_filename: Path to the NPZ file.\n\n    Returns:\n        BatchTuple (tuple of inputs, probabilities, values arrays).\n\n    Raises:\n        ValueError: If the NPZ file doesn't contain exactly one batch.\n    \"\"\"\n    with np.load(npz_filename, allow_pickle=True) as npz_file:\n        batches = npz_file[\"batches\"]\n        if batches.size != 1:\n            raise ValueError(\n                f\"Expected 1 batch in npz '{npz_filename}', got {batches.size}\"\n            )\n        return batches[0]\n\n\nclass _TrainingBatchMetric(_Metric):\n    \"\"\"Metric that logs training batch data.\"\"\"\n\n    def __init__(self, config: MetricConfig, logger: TensorboardLogger):\n 
       super().__init__(config, logger)\n        if config.use_swa_model:\n            raise ValueError(\n                f\"Metric '{config.name}': Cannot use SWA model for \"\n                \"training_batch metrics\"\n            )\n\n    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:\n        self.logger.log(hook_data.global_step, hook_data.metrics)\n\n\nclass _EvaluatingMetric(_Metric, ABC):\n    \"\"\"Base class for metrics that evaluate loss on data.\"\"\"\n\n    def __init__(\n        self,\n        config: MetricConfig,\n        logger: TensorboardLogger,\n        loss_fn: Optional[LczeroLoss],\n    ):\n        super().__init__(config, logger)\n        if not loss_fn:\n            raise ValueError(f\"Metric '{config.name}': Loss function required\")\n        self.loss_fn = loss_fn\n\n    @abstractmethod\n    def get_batch(self) -> BatchTuple:\n        \"\"\"Get the batch data to evaluate.\"\"\"\n\n    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:\n        batch = self.get_batch()\n        metrics = self._evaluate(batch, hook_data.jit_state, graphdef)\n        self.logger.log(hook_data.global_step, metrics)\n\n    def _evaluate(\n        self,\n        batch: BatchTuple,\n        jit_state: JitTrainingState,\n        graphdef: nnx.GraphDef,\n    ) -> Dict[str, jax.Array]:\n        model_state = (\n            jit_state.swa_state\n            if self.config.use_swa_model\n            else jit_state.model_state\n        )\n        if model_state is None:\n            raise RuntimeError(\"SWA state not available\")\n        batch_sample = TrainingSample(\n            inputs=jnp.asarray(batch[0]),\n            probabilities=jnp.asarray(batch[1]),\n            values=jnp.asarray(batch[2]),\n        )\n        return _make_eval_jit(graphdef, self.loss_fn)(model_state, batch_sample)\n\n\n_EvalJit = Callable[[nnx.State, TrainingSample], Dict[str, jax.Array]]\n_eval_jit_cache: dict[tuple[int, int], _EvalJit] = 
{}\n\n\ndef _make_eval_jit(graphdef: nnx.GraphDef, loss_fn: LczeroLoss) -> _EvalJit:\n    key = (id(graphdef), id(loss_fn))\n    if key not in _eval_jit_cache:\n\n        @jax.jit\n        def _eval(\n            model_state: nnx.State, batch_sample: TrainingSample\n        ) -> Dict[str, jax.Array]:\n            model = nnx.merge(graphdef, model_state)\n            loss_vfn = jax.vmap(loss_fn, in_axes=(None, 0), out_axes=0)\n            per_sample_loss, unweighted = loss_vfn(model, batch_sample)\n            return {\n                \"loss\": jnp.mean(per_sample_loss),\n                \"unweighted_losses\": jax.tree_util.tree_map(\n                    jnp.mean, unweighted\n                ),\n            }\n\n        _eval_jit_cache[key] = _eval\n    return _eval_jit_cache[key]\n\n\ndef evaluate_batch(\n    batch: BatchTuple,\n    jit_state: JitTrainingState,\n    graphdef: nnx.GraphDef,\n    loss_fn: LczeroLoss,\n    use_swa_model: bool = False,\n) -> Dict[str, jax.Array]:\n    \"\"\"Evaluate loss function on a batch of data.\n\n    Args:\n        batch: BatchTuple (inputs, probabilities, values).\n        jit_state: JIT training state containing model and optimizer state.\n        graphdef: Graph definition of the model.\n        loss_fn: Loss function to evaluate.\n        use_swa_model: If True, use SWA model state instead of regular model.\n\n    Returns:\n        Dictionary of metrics with loss and unweighted losses.\n    \"\"\"\n    model_state = (\n        jit_state.swa_state if use_swa_model else jit_state.model_state\n    )\n    if model_state is None:\n        raise RuntimeError(\"SWA state not available\")\n    batch_sample = TrainingSample(\n        inputs=jnp.asarray(batch[0]),\n        probabilities=jnp.asarray(batch[1]),\n        values=jnp.asarray(batch[2]),\n    )\n    return _make_eval_jit(graphdef, loss_fn)(model_state, batch_sample)\n\n\nclass _DataLoaderMetric(_EvaluatingMetric):\n    \"\"\"Metric that evaluates loss on dataloader 
output.\"\"\"\n\n    def __init__(\n        self,\n        config: MetricConfig,\n        logger: TensorboardLogger,\n        loss_fn: Optional[LczeroLoss],\n        data_loader: Optional[DataLoader],\n        dataloader_name: str,\n        cached_batches: Dict[str, CachedBatch],\n    ):\n        super().__init__(config, logger, loss_fn)\n        if not data_loader:\n            raise ValueError(f\"Metric '{config.name}': DataLoader required\")\n        self.data_loader = data_loader\n        self.dataloader_name = dataloader_name\n        self.cached_batches = cached_batches\n\n    def get_batch(self) -> BatchTuple:\n        return self.cached_batches[self.dataloader_name].batch\n\n    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:\n        \"\"\"Update cache if needed and log the metric.\"\"\"\n        cached = self.cached_batches.get(self.dataloader_name)\n        if cached is None or cached.global_step != hook_data.global_step:\n            batch = self.data_loader.maybe_get_next(self.dataloader_name)\n            if batch is not None:\n                self.cached_batches[self.dataloader_name] = CachedBatch(\n                    batch, hook_data.global_step\n                )\n            elif cached is None:\n                raise RuntimeError(\n                    f\"No data for metric '{self.config.name}' \"\n                    f\"from dataloader '{self.dataloader_name}'\"\n                )\n\n        super().log(hook_data, graphdef)\n\n\nclass _NpzMetric(_EvaluatingMetric):\n    \"\"\"Metric that evaluates loss on pre-loaded NPZ data.\"\"\"\n\n    def __init__(\n        self,\n        config: MetricConfig,\n        logger: TensorboardLogger,\n        loss_fn: Optional[LczeroLoss],\n        npz_filename: str,\n        cached_batches: Dict[str, CachedBatch],\n    ):\n        super().__init__(config, logger, loss_fn)\n        self.npz_filename = npz_filename\n        self.cached_batches = cached_batches\n\n    def get_batch(self) -> 
BatchTuple:\n        return self.cached_batches[self.npz_filename].batch\n\n    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:\n        \"\"\"Load NPZ data if needed and log the metric.\"\"\"\n        if self.npz_filename not in self.cached_batches:\n            batch = load_batch_from_npz(self.npz_filename)\n            self.cached_batches[self.npz_filename] = CachedBatch(\n                batch, hook_data.global_step\n            )\n\n        super().log(hook_data, graphdef)\n\n\nclass Metrics:\n    \"\"\"Manages metrics collection and logging for training.\"\"\"\n\n    def __init__(\n        self,\n        config: MetricsConfig,\n        loss_fn: Optional[LczeroLoss] = None,\n        data_loader: Optional[DataLoader] = None,\n    ):\n        self._metrics: Dict[str, _Metric] = {}\n        self._cached_batches: Dict[str, CachedBatch] = {}\n\n        for mc in config.metric:\n            tb_logger = TensorboardLogger(\n                os.path.join(config.tensorboard_path, mc.name)\n            )\n\n            metric: _Metric\n            if mc.HasField(\"training_batch\"):\n                metric = _TrainingBatchMetric(mc, tb_logger)\n            elif mc.HasField(\"dataloader_output\"):\n                metric = _DataLoaderMetric(\n                    mc,\n                    tb_logger,\n                    loss_fn,\n                    data_loader,\n                    mc.dataloader_output,\n                    self._cached_batches,\n                )\n            elif mc.HasField(\"npz_filename\"):\n                metric = _NpzMetric(\n                    mc,\n                    tb_logger,\n                    loss_fn,\n                    mc.npz_filename,\n                    self._cached_batches,\n                )\n            elif mc.HasField(\"weights\"):\n                if mc.weights.rms:\n                    metric = _RmsMetric(mc, tb_logger)\n                else:\n                    raise ValueError(\n                       
 f\"Metric '{mc.name}': No weight metric type specified\"\n                    )\n            else:\n                raise ValueError(f\"Metric '{mc.name}' has no sample source\")\n\n            self._metrics[mc.name] = metric\n\n    def on_step(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:\n        \"\"\"Process metrics for the current step.\"\"\"\n        for metric in self._metrics.values():\n            if metric.should_log(\n                hook_data.global_step,\n                hook_data.local_step,\n                hook_data.steps_per_epoch,\n            ):\n                metric.log(hook_data, graphdef)\n\n    def close(self) -> None:\n        \"\"\"Close all TensorBoard loggers.\"\"\"\n        for metric in self._metrics.values():\n            metric.logger.close()\n"
  },
  {
    "path": "src/lczero_training/daemon/metrics_base.py",
    "content": "\"\"\"Base classes for metrics.\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nfrom flax import nnx\n\nfrom lczero_training.training.tensorboard import TensorboardLogger\nfrom lczero_training.training.training import StepHookData\nfrom proto.metrics_config_pb2 import MetricConfig\n\n\nclass _Metric(ABC):\n    \"\"\"Base class for individual metric tracking.\"\"\"\n\n    def __init__(self, config: MetricConfig, logger: TensorboardLogger):\n        self.config = config\n        self.logger = logger\n\n    def should_log(\n        self, global_step: int, local_step: int, steps_per_epoch: int\n    ) -> bool:\n        \"\"\"Check if it's time to log this metric.\"\"\"\n        if self.config.after_epoch and local_step + 1 == steps_per_epoch:\n            return True\n        step = global_step if self.config.use_global_steps else local_step\n        return (step + 1) % self.config.period == 0\n\n    @abstractmethod\n    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:\n        \"\"\"Log the metric for the current step.\"\"\"\n"
  },
  {
    "path": "src/lczero_training/daemon/pipeline.py",
    "content": "import dataclasses\nimport datetime\nimport gzip\nimport logging\nimport os\nimport threading\nimport time\nfrom pathlib import Path\nfrom typing import cast\n\nimport jax\nimport orbax.checkpoint as ocp\nimport requests\nfrom dotenv import load_dotenv\nfrom flax import nnx\nfrom google.protobuf import text_format\n\nfrom lczero_training._lczero_training import DataLoader\nfrom lczero_training.convert.jax_to_leela import (\n    LeelaExportOptions,\n    jax_to_leela,\n)\nfrom lczero_training.model.loss_function import LczeroLoss\nfrom lczero_training.model.model import LczeroModel\nfrom lczero_training.training.lr_schedule import make_lr_schedule\nfrom lczero_training.training.optimizer import make_gradient_transformation\nfrom lczero_training.training.state import JitTrainingState, TrainingState\nfrom lczero_training.training.training import (\n    StepHookData,\n    Training,\n    from_dataloader,\n)\nfrom proto.data_loader_config_pb2 import DataLoaderConfig\nfrom proto.root_config_pb2 import RootConfig\nfrom proto.stage_control_pb2 import StageControlRequest, StageControlResponse\nfrom proto.training_config_pb2 import ScheduleConfig\n\nfrom .metrics import Metrics\nfrom .protocol.messages import TrainingScheduleData, TrainingStage\n\nlogger = logging.getLogger(__name__)\n\n\ndef _read_config_file(config_filepath: str) -> RootConfig:\n    config_path = Path(config_filepath)\n    config_text = config_path.read_text()\n\n    root_config = RootConfig()\n    text_format.Parse(config_text, root_config)\n    return root_config\n\n\ndef _make_dataloader(config: DataLoaderConfig) -> DataLoader:\n    config_bytes = config.SerializeToString()\n    return DataLoader(config_bytes)\n\n\ndef _configure_file_logging(config: RootConfig) -> None:\n    \"\"\"Configure file logging if log_filename is specified in config.\"\"\"\n    if config.HasField(\"log_filename\"):\n        file_handler = logging.FileHandler(config.log_filename)\n        
file_handler.setFormatter(\n            logging.Formatter(\n                \"%(levelname).1s%(asctime)s.%(msecs)03d %(name)s \"\n                \"%(filename)s:%(lineno)d] %(message)s\",\n                datefmt=\"%m%d %H:%M:%S\",\n            )\n        )\n        logging.getLogger().addHandler(file_handler)\n        logger.info(f\"Added file logging to {config.log_filename}\")\n\n\ndef _log_jax_system_info() -> None:\n    \"\"\"Log JAX system information including devices and backend details.\"\"\"\n    devices = jax.devices()\n    local_devices = jax.local_devices()\n    device_counts: dict[str, int] = {}\n    for device in devices:\n        device_type = device.device_kind\n        device_counts[device_type] = device_counts.get(device_type, 0) + 1\n\n    logger.info(f\"JAX Backend: {jax.default_backend()}\")\n    logger.info(\n        f\"JAX Devices: {len(devices)} total, {len(local_devices)} local\"\n    )\n    for device_type, count in device_counts.items():\n        logger.info(f\"  {device_type}: {count}\")\n    for i, device in enumerate(local_devices):\n        logger.info(f\"  Local device {i}: {device}\")\n\n\n@dataclasses.dataclass\nclass _TrainingCycleState:\n    start_time: float = dataclasses.field(default_factory=time.time)\n    current_stage: TrainingStage = TrainingStage.WAITING_FOR_DATA\n    completed_epochs: int = 0\n    current_cycle_start_time: float = dataclasses.field(\n        default_factory=time.time\n    )\n    current_training_start_time: float | None = None\n    previous_training_duration: float = 0.0\n    previous_cycle_duration: float = 0.0\n    chunks_at_training_start: int = 0\n\n\nclass TrainingPipeline:\n    _data_loader: DataLoader\n    _schedule: ScheduleConfig\n    _chunks_to_wait: int\n    _model: LczeroModel\n    _checkpoint_mgr: ocp.CheckpointManager\n    _training_state: TrainingState\n    _cycle_state: _TrainingCycleState\n    _metrics: Metrics | None\n\n    def __init__(\n        self,\n        config_filepath: str,\n  
      memory_profile_dir: str | None = None,\n    ) -> None:\n        self._memory_profile_dir = memory_profile_dir\n        logger.info(f\"Loading config from {config_filepath}\")\n        self._config = self._load_config(config_filepath)\n        _configure_file_logging(self._config)\n        self._schedule = self._config.training.schedule\n        self._chunks_per_network = self._schedule.chunks_per_network\n        self._num_steps_per_epoch = self._schedule.steps_per_network\n        self._chunks_to_wait = self._chunks_per_network\n        self._cycle_state = _TrainingCycleState()\n        self._force_training_event = threading.Event()\n        self._metrics = None\n        logger.info(\"Creating empty model\")\n        self._model = LczeroModel(self._config.model, rngs=nnx.Rngs(params=42))\n        self._graphdef = nnx.graphdef(self._model)\n        logger.info(\n            f\"Creating checkpoint manager at {self._config.training.checkpoint.path}\"\n        )\n        self._checkpoint_mgr = ocp.CheckpointManager(\n            self._config.training.checkpoint.path,\n            options=ocp.CheckpointManagerOptions(\n                max_to_keep=self._config.training.checkpoint.max_to_keep\n                or None,\n            ),\n        )\n\n        logger.info(\"Restoring checkpoint\")\n        optimizer_config = self._config.training.optimizer\n        max_grad_norm = getattr(self._config.training, \"max_grad_norm\", 0.0)\n        self._lr_schedule = make_lr_schedule(self._config.training.lr_schedule)\n        optimizer_tx = make_gradient_transformation(\n            optimizer_config,\n            max_grad_norm=max_grad_norm,\n            lr_schedule=self._lr_schedule,\n        )\n        model_state = nnx.state(self._model)\n        jit_state = JitTrainingState(\n            step=0,\n            model_state=model_state,\n            opt_state=optimizer_tx.init(model_state),\n            swa_state=model_state,\n            num_averages=0.0,\n        )\n     
   empty_state = TrainingState(\n            jit_state=jit_state,\n            num_heads=self._config.model.encoder.heads,\n        )\n        self._training_state = cast(\n            TrainingState,\n            self._checkpoint_mgr.restore(\n                step=None,\n                args=ocp.args.PyTreeRestore(\n                    item=empty_state,\n                ),\n            ),\n        )\n\n        logger.info(\"Creating training session\")\n        loss_fn = LczeroLoss(config=self._config.training.losses)\n        self._training = Training(\n            optimizer_tx=make_gradient_transformation(\n                self._config.training.optimizer,\n                max_grad_norm=max_grad_norm,\n                lr_schedule=self._lr_schedule,\n            ),\n            graphdef=nnx.graphdef(self._model),\n            loss_fn=loss_fn,\n            swa_config=(\n                self._config.training.swa\n                if self._config.training.HasField(\"swa\")\n                else None\n            ),\n        )\n\n        logger.info(\"Creating data loader\")\n        self._data_loader = _make_dataloader(self._config.data_loader)\n        self._set_chunk_anchor(self._training_state.last_chunk_source)\n\n        # Create metrics if configured.\n        if self._config.HasField(\"metrics\"):\n            logger.info(\"Creating metrics\")\n            self._metrics = Metrics(\n                config=self._config.metrics,\n                loss_fn=loss_fn,\n                data_loader=self._data_loader,\n            )\n        else:\n            logger.info(\"No metrics configured\")\n\n        _log_jax_system_info()\n\n    def start_training_immediately(self) -> None:\n        \"\"\"Request the next training cycle to start without waiting for chunks.\"\"\"\n\n        logger.info(\"Received request to start training immediately.\")\n        self._force_training_event.set()\n\n    def run(self) -> None:\n        logging.info(\"Starting DataLoader\")\n        
self._data_loader.start()\n\n        while True:\n            self._wait_for_chunks()\n            new_anchor, used_chunks = self._reset_chunk_anchor()\n            logging.info(f\"{new_anchor=} {used_chunks=}\")\n            self._training_state = self._training_state.replace(\n                last_chunk_source=new_anchor\n            )\n            self._chunks_to_wait = max(\n                self._chunks_to_wait + self._chunks_per_network - used_chunks,\n                self._chunks_per_network // 2,\n            )\n            self._train_one_network()\n            self._save_checkpoint()\n            network_bytes = self._export_network()\n            if network_bytes:\n                self._save_network(network_bytes)\n                self._upload_network(network_bytes)\n\n    def _export_network(self) -> bytes | None:\n        if (\n            not self._config.export.destination_filename\n            and not self._config.export.HasField(\"upload_training_run\")\n        ):\n            return None\n\n        logging.info(\"Exporting network\")\n\n        options = LeelaExportOptions(\n            min_version=\"0.31\",\n            num_heads=self._training_state.num_heads,\n            license=None,\n            training_steps=self._training_state.jit_state.step,\n        )\n        export_state = (\n            self._training_state.jit_state.swa_state\n            if self._config.export.export_swa_model\n            else self._training_state.jit_state.model_state\n        )\n        assert isinstance(export_state, nnx.State)\n        net = jax_to_leela(jax_weights=export_state, export_options=options)\n        return gzip.compress(net.SerializeToString())\n\n    def _save_network(self, network_bytes: bytes) -> None:\n        date_str = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n        step = self._training_state.jit_state.step\n\n        for destination_template in self._config.export.destination_filename:\n            destination = 
destination_template.format(\n                datetime=date_str, step=step\n            )\n            logging.info(f\"Writing network to {destination}\")\n            os.makedirs(os.path.dirname(destination), exist_ok=True)\n            with open(destination, \"wb\") as f:\n                f.write(network_bytes)\n            logging.info(f\"Finished writing network to {destination}\")\n\n    def _upload_network(self, network_bytes: bytes) -> None:\n        if not self._config.export.HasField(\"upload_training_run\"):\n            return\n\n        load_dotenv()\n        upload_pwd = os.getenv(\"UPLOAD_PWD\")\n        if not upload_pwd:\n            logging.error(\n                \"UPLOAD_PWD not found in environment variables, skipping upload.\"\n            )\n            return\n\n        try:\n            state = cast(nnx.State, nnx.state(self._model))\n            layers = len(state[\"encoders\"][\"encoders\"][\"layers\"])\n            filters = state[\"embedding\"][\"embedding\"][\"bias\"].shape[0]\n            training_id = self._config.export.upload_training_run\n\n            logging.info(\n                f\"Uploading network to training website (ID: {training_id}, \"\n                f\"layers: {layers}, filters: {filters})\"\n            )\n\n            data = {\n                \"pwd\": upload_pwd,\n                \"training_id\": training_id,\n                \"layers\": layers,\n                \"filters\": filters,\n            }\n            response = requests.post(\n                \"http://api.lczero.org/upload_network\",\n                files={\"file\": network_bytes},\n                data=data,\n            )\n            response.raise_for_status()\n\n            logging.info(f\"Successfully uploaded network: {response.text}\")\n        except requests.exceptions.RequestException as e:\n            logging.error(f\"Failed to upload network: {e}\")\n        except (KeyError, AttributeError, IndexError) as e:\n            
logging.error(f\"Failed to extract model metadata for upload: {e}\")\n\n    def _step_hook(self, hook_data: StepHookData) -> None:\n        # Append current learning rate from schedule to metrics.\n        hook_data.metrics[\"lr\"] = self._lr_schedule(hook_data.global_step)\n        if self._metrics is not None:\n            self._metrics.on_step(hook_data, self._graphdef)\n\n    def _train_one_network(self) -> None:\n        logging.info(\"Training one network!\")\n\n        # Record training start\n        self._cycle_state.current_training_start_time = time.time()\n        self._cycle_state.current_stage = TrainingStage.TRAINING\n        self._cycle_state.chunks_at_training_start = self._chunks_since_anchor()\n\n        new_jit_state = self._training.run(\n            jit_state=self._training_state.jit_state,\n            datagen=from_dataloader(self._data_loader),\n            num_steps=self._schedule.steps_per_network,\n            step_hook=self._step_hook,\n            memory_profile_dir=self._memory_profile_dir,\n        )\n        self._training_state = self._training_state.replace(\n            jit_state=new_jit_state\n        )\n\n        # Record training end\n        current_time = time.time()\n        if self._cycle_state.current_training_start_time:\n            self._cycle_state.previous_training_duration = (\n                current_time - self._cycle_state.current_training_start_time\n            )\n        self._cycle_state.previous_cycle_duration = (\n            current_time - self._cycle_state.current_cycle_start_time\n        )\n        self._cycle_state.completed_epochs += 1\n        self._cycle_state.current_training_start_time = None\n        self._cycle_state.current_stage = TrainingStage.WAITING_FOR_DATA\n        self._cycle_state.current_cycle_start_time = current_time\n\n        logging.info(\"Done training\")\n\n    def _save_checkpoint(self) -> None:\n        logging.info(\"Saving checkpoint\")\n        self._checkpoint_mgr.save(\n   
         step=self._training_state.jit_state.step,\n            args=ocp.args.PyTreeSave(item=self._training_state),\n        )\n        logging.info(\"Checkpoint saved\")\n\n    def stop(self) -> None:\n        self._data_loader.stop()\n        if self._metrics is not None:\n            self._metrics.close()\n\n    def get_data_loader(self) -> DataLoader:\n        return self._data_loader\n\n    def _wait_for_chunks(self) -> None:\n        current_chunks = self._chunks_since_anchor()\n        logger.info(\n            f\"Waiting for {self._chunks_to_wait} chunks. \"\n            f\"got {current_chunks} so far\"\n        )\n        while True:\n            if self._force_training_event.is_set():\n                logger.info(\n                    \"Force start requested; skipping remaining chunk wait.\"\n                )\n                self._force_training_event.clear()\n                self._chunks_to_wait = self._chunks_since_anchor()\n                return\n\n            if self._chunks_since_anchor() >= self._chunks_to_wait:\n                logger.info(\"Done waiting for enough chunks\")\n                return\n\n            time.sleep(1)\n\n    def get_training_schedule_data(\n        self, daemon_start_time: float\n    ) -> TrainingScheduleData:\n        \"\"\"Return current training schedule data for TUI display.\"\"\"\n        current_time = time.time()\n\n        # Calculate current training time if currently training\n        current_training_time = 0.0\n        if self._cycle_state.current_training_start_time is not None:\n            current_training_time = (\n                current_time - self._cycle_state.current_training_start_time\n            )\n\n        # Calculate current cycle time\n        current_cycle_time = (\n            current_time - self._cycle_state.current_cycle_start_time\n        )\n\n        # Calculate new chunks since training start\n        new_chunks_since_training_start = max(\n            0,\n            
self._chunks_since_anchor()\n            - self._cycle_state.chunks_at_training_start,\n        )\n\n        return TrainingScheduleData(\n            current_stage=self._cycle_state.current_stage,\n            completed_epochs_since_start=self._cycle_state.completed_epochs,\n            new_chunks_since_training_start=new_chunks_since_training_start,\n            chunks_to_wait=self._chunks_to_wait,\n            total_uptime_seconds=current_time - daemon_start_time,\n            current_training_time_seconds=current_training_time,\n            previous_training_time_seconds=self._cycle_state.previous_training_duration,\n            current_cycle_time_seconds=current_cycle_time,\n            previous_cycle_time_seconds=self._cycle_state.previous_cycle_duration,\n        )\n\n    def _send_chunk_pool_control(\n        self, request: StageControlRequest\n    ) -> StageControlResponse | None:\n        responses = self._data_loader.send_control_message(request)\n        for _, response in responses:\n            if response.HasField(\"chunk_pool_response\"):\n                return response\n        return None\n\n    def _reset_chunk_anchor(self) -> tuple[str, int]:\n        request = StageControlRequest()\n        request.chunk_pool_request.reset_chunk_anchor = True\n        response = self._send_chunk_pool_control(request)\n        if not response or not response.HasField(\"chunk_pool_response\"):\n            return \"\", 0\n        chunk_response = response.chunk_pool_response\n        return chunk_response.chunk_anchor, chunk_response.chunks_since_anchor\n\n    def _chunks_since_anchor(self) -> int:\n        request = StageControlRequest()\n        request.chunk_pool_request.SetInParent()\n        response = self._send_chunk_pool_control(request)\n        if not response or not response.HasField(\"chunk_pool_response\"):\n            return 0\n        return response.chunk_pool_response.chunks_since_anchor\n\n    def _set_chunk_anchor(self, anchor: str) -> 
None:\n        request = StageControlRequest()\n        request.chunk_pool_request.set_chunk_anchor = anchor or \"\"\n        self._send_chunk_pool_control(request)\n\n    def _load_config(self, config_filepath: str) -> RootConfig:\n        config_path = Path(config_filepath)\n        config_text = config_path.read_text()\n\n        root_config = RootConfig()\n        text_format.Parse(config_text, root_config)\n        return root_config\n"
  },
  {
    "path": "src/lczero_training/daemon/protocol/__init__.py",
    "content": "# ABOUTME: Protocol package for JSONL IPC communication between processes.\n# ABOUTME: Contains registry system, message definitions, and communicator class.\n"
  },
  {
    "path": "src/lczero_training/daemon/protocol/communicator.py",
    "content": "# ABOUTME: Core Communicator class for JSONL IPC between processes.\n# ABOUTME: Handles serialization/deserialization and message dispatch via stdin/stdout.\n\nimport json\nimport types\nfrom dataclasses import is_dataclass\nfrom enum import Enum\nfrom typing import Any, Optional, TextIO, Union, get_args, get_origin\n\nfrom anyio.streams.text import TextReceiveStream, TextSendStream\nfrom google.protobuf.json_format import MessageToDict, ParseDict\nfrom google.protobuf.message import Message\n\nfrom .registry import CLASS_TO_TYPE_MAP, TYPE_TO_CLASS_MAP\n\n\ndef _to_serializable(obj: Any) -> Any:\n    \"\"\"Convert dataclass/protobuf objects to JSON-serializable dicts.\"\"\"\n    if isinstance(obj, Message):\n        return MessageToDict(\n            obj, preserving_proto_field_name=True, use_integers_for_enums=True\n        )\n    elif isinstance(obj, Enum):\n        return obj.value\n    elif is_dataclass(obj):\n        return {\n            f.name: _to_serializable(getattr(obj, f.name))\n            for f in obj.__dataclass_fields__.values()\n            if getattr(obj, f.name) is not None\n        }\n    elif isinstance(obj, (list, tuple)):\n        return [_to_serializable(item) for item in obj]\n    elif isinstance(obj, dict):\n        return {k: _to_serializable(v) for k, v in obj.items()}\n    else:\n        return obj\n\n\ndef _unwrap_optional(t: Any) -> Any:\n    \"\"\"Extract T from T | None or Union[T, None].\"\"\"\n    if isinstance(t, types.UnionType) or get_origin(t) is Union:\n        args = [a for a in get_args(t) if a is not type(None)]\n        return args[0] if len(args) == 1 else t\n    return t\n\n\ndef _is_protobuf(cls: type) -> bool:\n    \"\"\"Check if cls is a protobuf Message class.\"\"\"\n    try:\n        return isinstance(cls, type) and issubclass(cls, Message)\n    except TypeError:\n        return False\n\n\ndef _from_serializable(cls: type, data: Any) -> Any:\n    \"\"\"Reconstruct dataclass/protobuf from 
dict.\"\"\"\n    if _is_protobuf(cls):\n        instance = cls()\n        ParseDict(data, instance)\n        return instance\n\n    if not is_dataclass(cls):\n        return data\n\n    args = {}\n    for field in cls.__dataclass_fields__.values():\n        if field.name not in data:\n            continue\n\n        value = data[field.name]\n        field_type = _unwrap_optional(field.type)\n\n        if get_origin(field_type) is list:\n            item_type = get_args(field_type)[0]\n            if is_dataclass(item_type) or _is_protobuf(item_type):\n                value = [_from_serializable(item_type, item) for item in value]\n        elif is_dataclass(field_type) or _is_protobuf(field_type):\n            value = _from_serializable(field_type, value)\n        elif isinstance(field_type, type) and issubclass(field_type, Enum):\n            # Convert string value back to enum\n            value = field_type(value)\n\n        args[field.name] = value\n\n    return cls(**args)\n\n\nclass Communicator:\n    def __init__(\n        self, handler: Any, input_stream: TextIO, output_stream: TextIO\n    ) -> None:\n        \"\"\"\n        Initializes the Communicator.\n\n        Args:\n            handler: An object with `on_<event_type>` methods.\n            input_stream: A file-like object to read incoming messages from (e.g., sys.stdin).\n            output_stream: A file-like object to write outgoing messages to (e.g., sys.stdout).\n        \"\"\"\n        self.handler = handler\n        self.input = input_stream\n        self.output = output_stream\n\n    def send(self, payload_instance: Any) -> None:\n        \"\"\"\n        Serializes and sends a payload object as a notification.\n        The event type is automatically looked up from the registry.\n        \"\"\"\n        payload_cls = type(payload_instance)\n        event_type = CLASS_TO_TYPE_MAP.get(payload_cls)\n\n        if event_type is None:\n            raise TypeError(\n                f\"Object of type 
{payload_cls.__name__} is not a registered payload.\"\n            )\n\n        payload_dict = _to_serializable(payload_instance)\n        message = {\"type\": event_type, \"payload\": payload_dict}\n\n        self.output.write(json.dumps(message) + \"\\n\")\n        self.output.flush()\n\n    def _dispatch(self, line: str) -> None:\n        line = line.strip()\n        if not line:\n            return\n\n        data = json.loads(line)\n        event_type = data[\"type\"]\n        payload_dict = data[\"payload\"]\n\n        payload_cls = TYPE_TO_CLASS_MAP[event_type]\n        payload_instance = _from_serializable(payload_cls, payload_dict)\n\n        handler_method_name = f\"on_{event_type}\"\n        handler_method = getattr(self.handler, handler_method_name)\n\n        handler_method(payload_instance)\n\n    def run(self) -> None:\n        \"\"\"\n        Starts the blocking listener loop.\n        Reads from the input stream line-by-line, deserializes notifications,\n        and dispatches them to the appropriate handler method.\n\n        This method blocks until the input stream is closed.\n        \"\"\"\n        for line in self.input:\n            self._dispatch(line)\n\n\nclass AsyncCommunicator:\n    def __init__(\n        self,\n        handler: Any,\n        input_stream: TextReceiveStream,\n        output_stream: TextSendStream,\n        io_dump: Optional[TextIO] = None,\n    ) -> None:\n        \"\"\"\n        Initializes the AsyncCommunicator.\n\n        Args:\n            handler: An object with async `on_<event_type>` methods.\n            input_stream: A TextReceiveStream to read incoming messages from.\n            output_stream: A TextSendStream to write outgoing messages to.\n            io_dump: Optional file to dump raw IO for debugging.\n        \"\"\"\n        self.handler = handler\n        self.input_stream = input_stream\n        self.output_stream = output_stream\n        self._io_dump = io_dump\n        self._buffer = \"\"\n\n    
async def send(self, payload_instance: Any) -> None:\n        \"\"\"\n        Serializes and sends a payload object as a notification.\n        The event type is automatically looked up from the registry.\n        \"\"\"\n        payload_cls = type(payload_instance)\n        event_type = CLASS_TO_TYPE_MAP.get(payload_cls)\n\n        if event_type is None:\n            raise TypeError(\n                f\"Object of type {payload_cls.__name__} is not a registered payload.\"\n            )\n\n        payload_dict = _to_serializable(payload_instance)\n        message = {\"type\": event_type, \"payload\": payload_dict}\n\n        message_line = json.dumps(message) + \"\\n\"\n        if self._io_dump:\n            self._io_dump.write(f\"> {message_line}\")\n        await self.output_stream.send(message_line)\n\n    async def _dispatch(self, line: str) -> None:\n        line = line.strip()\n        if not line:\n            return\n\n        data = json.loads(line)\n        event_type = data[\"type\"]\n        payload_dict = data[\"payload\"]\n\n        payload_cls = TYPE_TO_CLASS_MAP[event_type]\n        payload_instance = _from_serializable(payload_cls, payload_dict)\n\n        handler_method_name = f\"on_{event_type}\"\n        handler_method = getattr(self.handler, handler_method_name)\n\n        await handler_method(payload_instance)\n\n    async def run(self) -> None:\n        \"\"\"\n        Starts the async listener loop.\n        Reads from the input stream line-by-line, deserializes notifications,\n        and dispatches them to the appropriate async handler method.\n\n        This method runs until the input stream is closed.\n        \"\"\"\n        async for chunk in self.input_stream:\n            self._buffer += chunk\n            while \"\\n\" in self._buffer:\n                line, self._buffer = self._buffer.split(\"\\n\", 1)\n                if self._io_dump and line:\n                    self._io_dump.write(f\"< {line}\\n\")\n                await 
self._dispatch(line)\n\n        if self._buffer:\n            if self._io_dump:\n                self._io_dump.write(f\"< {self._buffer}\\n\")\n            await self._dispatch(self._buffer)\n"
  },
  {
    "path": "src/lczero_training/daemon/protocol/messages.py",
    "content": "# ABOUTME: Payload dataclass definitions for JSONL IPC protocol messages.\n# ABOUTME: Defines minimal event types for training daemon communication.\n\nfrom dataclasses import dataclass\nfrom enum import Enum\nfrom typing import Optional\n\nimport proto.training_metrics_pb2 as training_metrics_pb2\n\nfrom .registry import register\n\n\nclass TrainingStage(Enum):\n    WAITING_FOR_DATA = \"WAITING FOR DATA\"\n    TRAINING = \"TRAINING\"\n\n\n@dataclass\nclass TrainingScheduleData:\n    current_stage: TrainingStage\n    completed_epochs_since_start: int\n    new_chunks_since_training_start: int\n    chunks_to_wait: int\n    total_uptime_seconds: float\n    current_training_time_seconds: float\n    previous_training_time_seconds: float\n    current_cycle_time_seconds: float\n    previous_cycle_time_seconds: float\n\n\n# --- Notifications from UI (Parent) to Trainer (Child) ---\n\n\n@register(\"start_training\")\n@dataclass\nclass StartTrainingPayload:\n    config_filepath: str\n\n\n@register(\"start_training_immediately\")\n@dataclass\nclass StartTrainingImmediatelyPayload:\n    pass\n\n\n# --- Notifications from Trainer (Child) to UI (Parent) ---\n\n\n@register(\"training_status\")\n@dataclass\nclass TrainingStatusPayload:\n    dataloader_update_secs: Optional[float] = None\n    dataloader_1_second: Optional[\n        training_metrics_pb2.DataLoaderMetricsProto\n    ] = None\n    dataloader_total: Optional[training_metrics_pb2.DataLoaderMetricsProto] = (\n        None\n    )\n    training_schedule: Optional[TrainingScheduleData] = None\n"
  },
  {
    "path": "src/lczero_training/daemon/protocol/registry.py",
    "content": "# ABOUTME: Registry system for mapping event type strings to payload dataclasses.\n# ABOUTME: Provides @register decorator and maintains bidirectional mapping dicts.\n\nimport inspect\nfrom typing import Callable\n\n# These maps will be populated by the @register decorator\nTYPE_TO_CLASS_MAP = {}\nCLASS_TO_TYPE_MAP = {}\n\n\ndef register(event_type: str) -> Callable[[type], type]:\n    \"\"\"A decorator to register a payload dataclass with its event type string.\"\"\"\n\n    def decorator(cls: type) -> type:\n        if not inspect.isclass(cls):\n            raise TypeError(\n                \"The @register decorator can only be used on classes.\"\n            )\n\n        if event_type in TYPE_TO_CLASS_MAP:\n            raise ValueError(\n                f\"Event type '{event_type}' is already registered.\"\n            )\n\n        if cls in CLASS_TO_TYPE_MAP:\n            raise ValueError(f\"Class '{cls.__name__}' is already registered.\")\n\n        TYPE_TO_CLASS_MAP[event_type] = cls\n        CLASS_TO_TYPE_MAP[cls] = event_type\n        return cls\n\n    return decorator\n"
  },
  {
    "path": "src/lczero_training/daemon/rms_metrics.py",
    "content": "\"\"\"RMS metrics for model parameters.\"\"\"\n\nfrom typing import Any, cast\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import nnx\n\nfrom lczero_training.daemon.metrics_base import _Metric\nfrom lczero_training.model.encoder import EncoderBlock\nfrom lczero_training.model.model import LczeroModel\nfrom lczero_training.training.tensorboard import TensorboardLogger\nfrom lczero_training.training.training import StepHookData\nfrom proto.metrics_config_pb2 import MetricConfig\n\n\n@jax.jit\ndef compute_rms(state_subtree: nnx.State) -> jax.Array:\n    \"\"\"Compute RMS of all parameters in a state subtree.\"\"\"\n    leaves = jax.tree_util.tree_leaves(state_subtree)\n    total_sq = sum(jnp.sum(jnp.square(p)) for p in leaves)\n    total_n = sum(p.size for p in leaves)\n    return jnp.sqrt(total_sq / total_n)\n\n\ndef extract_attention_components(model: LczeroModel) -> dict[str, Any]:\n    \"\"\"Extract Q, K, V, output_dense, smolgen from all encoder layers.\n\n    Args:\n        model: LczeroModel instance.\n\n    Returns:\n        Dict with keys 'q', 'k', 'v', 'output_dense', optionally 'smolgen'.\n    \"\"\"\n    components: dict[str, Any] = {\n        \"q\": {},\n        \"k\": {},\n        \"v\": {},\n        \"output_dense\": {},\n    }\n\n    encoder_layers = cast(list[EncoderBlock], model.encoders.encoders.layers)\n    for i, encoder_block in enumerate(encoder_layers):\n        mha = encoder_block.mha\n        components[\"q\"][f\"layer_{i}\"] = nnx.state(mha.q)\n        components[\"k\"][f\"layer_{i}\"] = nnx.state(mha.k)\n        components[\"v\"][f\"layer_{i}\"] = nnx.state(mha.v)\n        components[\"output_dense\"][f\"layer_{i}\"] = nnx.state(mha.output_dense)\n\n        if mha.smolgen is not None:\n            if \"smolgen\" not in components:\n                components[\"smolgen\"] = {}\n            components[\"smolgen\"][f\"layer_{i}\"] = nnx.state(mha.smolgen)\n\n    return components\n\n\ndef collect_rms_metrics(model: 
LczeroModel) -> dict[str, Any]:\n    \"\"\"Collect all RMS metrics for the model.\n\n    Args:\n        model: LczeroModel instance.\n\n    Returns:\n        Nested dict with RMS values for different model components.\n    \"\"\"\n    model_state = nnx.state(model)\n\n    metrics: dict[str, Any] = {\n        \"all_params\": compute_rms(model_state),\n        \"embedding\": compute_rms(nnx.state(model.embedding)),\n        \"encoder_body\": compute_rms(nnx.state(model.encoders)),\n    }\n\n    # Attention components\n    attn_components = extract_attention_components(model)\n    metrics[\"attention\"] = {\n        name: compute_rms(component)\n        for name, component in attn_components.items()\n    }\n\n    # Policy heads\n    metrics[\"policy_heads\"] = {\n        name: compute_rms(nnx.state(head))\n        for name, head in model.policy_heads.items()\n    }\n\n    # Value heads\n    metrics[\"value_heads\"] = {\n        name: compute_rms(nnx.state(head))\n        for name, head in model.value_heads.items()\n    }\n\n    # Movesleft heads\n    metrics[\"movesleft_heads\"] = {\n        name: compute_rms(nnx.state(head))\n        for name, head in model.movesleft_heads.items()\n    }\n\n    return metrics\n\n\nclass _RmsMetric(_Metric):\n    \"\"\"Metric that computes RMS of model parameters.\"\"\"\n\n    def __init__(self, config: MetricConfig, logger: TensorboardLogger):\n        super().__init__(config, logger)\n\n    def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:\n        model_state = (\n            hook_data.jit_state.swa_state\n            if self.config.use_swa_model\n            else hook_data.jit_state.model_state\n        )\n        model = nnx.merge(graphdef, model_state)\n        metrics = collect_rms_metrics(model)\n        self.logger.log(hook_data.global_step, metrics)\n"
  },
  {
    "path": "src/lczero_training/dataloader/__init__.py",
    "content": "from lczero_training._lczero_training import (\n    DataLoader,\n    TensorBase,\n)\nfrom proto.data_loader_config_pb2 import DataLoaderConfig\n\n__all__ = [\"DataLoader\", \"make_dataloader\", \"TensorBase\"]\n\n\ndef make_dataloader(config: DataLoaderConfig) -> DataLoader:\n    loader = DataLoader(config)\n    loader.start()\n    return loader\n"
  },
  {
    "path": "src/lczero_training/model/__init__.py",
    "content": ""
  },
  {
    "path": "src/lczero_training/model/embedding.py",
    "content": "import jax\nimport jax.numpy as jnp\nfrom flax import nnx\n\nfrom proto import model_config_pb2\n\nfrom .shared import Ffn\nfrom .utils import get_activation\n\n\nclass Embedding(nnx.Module):\n    \"\"\"Computes embeddings for the input features.\"\"\"\n\n    def __init__(\n        self,\n        *,\n        input_channels: int,\n        config: model_config_pb2.EmbeddingConfig,\n        defaults: model_config_pb2.DefaultsConfig,\n        deepnorm_alpha: float,\n        deepnorm_beta: float,\n        rngs: nnx.Rngs,\n    ):\n        self._input_channels = input_channels\n        dense_size = config.dense_size\n        embedding_size = config.embedding_size\n        self.activation = defaults.activation\n\n        assert dense_size > 0\n        self.preprocess = nnx.Linear(\n            in_features=64 * 12,\n            out_features=64 * dense_size,\n            rngs=rngs,\n        )\n\n        assert embedding_size > 0\n        self.embedding = nnx.Linear(\n            in_features=input_channels + dense_size,\n            out_features=embedding_size,\n            rngs=rngs,\n        )\n        self.norm = nnx.LayerNorm(embedding_size, epsilon=1e-3, rngs=rngs)\n        self.ma_gating = MaGating(feature_shape=(64, embedding_size), rngs=rngs)\n        self.deepnorm_alpha = deepnorm_alpha\n        self.ffn = Ffn(\n            in_features=embedding_size,\n            hidden_features=config.dff,\n            hidden_activation=defaults.ffn_activation,\n            deepnorm_beta=deepnorm_beta,\n            rngs=rngs,\n        )\n        self.out_norm = nnx.LayerNorm(embedding_size, epsilon=1e-3, rngs=rngs)\n\n    def __call__(self, x: jax.Array) -> jax.Array:\n        # Preprocess positional info and concatenate to input.\n        pos_info = self.preprocess(x[..., :12].flatten()).reshape((64, -1))\n        x = jnp.concatenate([x, pos_info], axis=1)\n\n        # Square embedding.\n        x = self.embedding(x)\n        x = 
get_activation(self.activation)(x)\n        x = self.norm(x)\n        x = self.ma_gating(x)\n        # FFN block with residual connection and layer norm.\n        x = x + self.ffn(x) * self.deepnorm_alpha\n        x = self.out_norm(x)\n        return x\n\n\nclass MaGating(nnx.Module):\n    \"\"\"Applies multiplicative and additive gating.\"\"\"\n\n    def __init__(self, feature_shape: tuple[int, ...], *, rngs: nnx.Rngs):\n        self.mult_gate = Gating(\n            feature_shape=feature_shape, additive=False, rngs=rngs\n        )\n        self.add_gate = Gating(\n            feature_shape=feature_shape, additive=True, rngs=rngs\n        )\n\n    def __call__(self, x: jax.Array) -> jax.Array:\n        return self.add_gate(self.mult_gate(x))\n\n\nclass Gating(nnx.Module):\n    def __init__(\n        self,\n        feature_shape: tuple[int, ...],\n        additive: bool = True,\n        *,\n        rngs: nnx.Rngs,\n    ):\n        self.additive = additive\n        init_val = 0.0 if self.additive else 1.0\n        self.gate = nnx.Param(\n            jnp.full(feature_shape, init_val, dtype=jnp.float32)\n        )\n\n    def __call__(self, inputs: jax.Array) -> jax.Array:\n        if self.additive:\n            return inputs + self.gate.value\n\n        effective_gate = jax.nn.relu(self.gate.value)\n        return inputs * effective_gate\n"
  },
  {
    "path": "src/lczero_training/model/encoder.py",
    "content": "import math\nfrom typing import Optional\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import nnx\nfrom flax.linen import initializers as flax_initializers\n\nfrom proto import model_config_pb2\n\nfrom .shared import Ffn\nfrom .utils import get_activation\n\n\nclass EncoderTower(nnx.Module):\n    def __init__(\n        self,\n        *,\n        in_features: int,\n        config: model_config_pb2.EncoderConfig,\n        defaults: model_config_pb2.DefaultsConfig,\n        deepnorm_beta: float,\n        rngs: nnx.Rngs,\n    ):\n        smolgen_shared_gen_dense = None\n        assert config.HasField(\"smolgen\")\n        if config.HasField(\"smolgen\"):\n            smolgen_shared_gen_dense = nnx.Linear(\n                in_features=config.smolgen.gen_size,\n                out_features=64 * 64,\n                use_bias=False,\n                rngs=rngs,\n            )\n\n        self.encoders = nnx.Sequential(\n            *[\n                EncoderBlock(\n                    in_features=in_features,\n                    config=config,\n                    defaults=defaults,\n                    smol_gen_dense=smolgen_shared_gen_dense,\n                    deepnorm_beta=deepnorm_beta,\n                    rngs=rngs,\n                )\n                for _ in range(config.num_blocks)\n            ]\n        )\n\n    def __call__(self, x: jax.Array) -> jax.Array:\n        return self.encoders(x)\n\n\nclass EncoderBlock(nnx.Module):\n    \"\"\"A single block of the transformer encoder.\"\"\"\n\n    def __init__(\n        self,\n        *,\n        in_features: int,\n        config: model_config_pb2.EncoderConfig,\n        defaults: model_config_pb2.DefaultsConfig,\n        smol_gen_dense: Optional[nnx.Linear],\n        deepnorm_beta: float,\n        rngs: nnx.Rngs,\n    ):\n        assert (smol_gen_dense is not None) == config.HasField(\"smolgen\")\n        self.mha = MultiHeadAttention(\n            in_features=in_features,\n            
config=config,\n            defaults=defaults,\n            smol_gen_dense=smol_gen_dense,\n            deepnorm_beta=deepnorm_beta,\n            rngs=rngs,\n        )\n\n        self.alpha = math.pow(2.0 * config.num_blocks, -0.25)\n        self.ln1 = nnx.LayerNorm(in_features, epsilon=1e-3, rngs=rngs)\n        self.ffn = Ffn(\n            in_features=in_features,\n            hidden_features=config.dff,\n            hidden_activation=defaults.ffn_activation,\n            deepnorm_beta=deepnorm_beta,\n            rngs=rngs,\n        )\n        self.ln2 = nnx.LayerNorm(in_features, epsilon=1e-3, rngs=rngs)\n\n    def __call__(self, x: jax.Array) -> jax.Array:\n        x = x + self.mha(x) * self.alpha\n        out1 = self.ln1(x)\n        ffn_out = self.ffn(out1)\n        return self.ln2(out1 + ffn_out * self.alpha)\n\n\nclass MultiHeadAttention(nnx.Module):\n    \"\"\"Multi-head attention module.\"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        config: model_config_pb2.EncoderConfig,\n        defaults: model_config_pb2.DefaultsConfig,\n        smol_gen_dense: Optional[nnx.Linear],\n        deepnorm_beta: float,\n        *,\n        rngs: nnx.Rngs,\n    ):\n        depth = config.d_model\n        assert depth % config.heads == 0, (\n            \"Model depth must be divisible by the number of heads.\"\n        )\n        self.activation = defaults.activation\n        self.depth = depth\n        self.num_heads = config.heads\n        self.q = nnx.Linear(\n            in_features=in_features, out_features=depth, rngs=rngs\n        )\n        self.k = nnx.Linear(\n            in_features=in_features, out_features=depth, rngs=rngs\n        )\n        deepnorm_init = flax_initializers.variance_scaling(\n            scale=deepnorm_beta,\n            mode=\"fan_avg\",\n            distribution=\"truncated_normal\",\n        )\n\n        self.v = nnx.Linear(\n            in_features=in_features,\n            out_features=depth,\n            
kernel_init=deepnorm_init,\n            rngs=rngs,\n        )\n        self.output_dense = nnx.Linear(\n            in_features=depth,\n            out_features=in_features,\n            kernel_init=deepnorm_init,\n            rngs=rngs,\n        )\n\n        assert (smol_gen_dense is not None) == config.HasField(\"smolgen\")\n        self.smolgen: Optional[Smolgen]\n        if smol_gen_dense is not None:\n            self.smolgen = Smolgen(\n                in_features=in_features,\n                config=config.smolgen,\n                defaults=defaults,\n                heads=config.heads,\n                weight_gen_dense=smol_gen_dense,\n                rngs=rngs,\n            )\n        else:\n            self.smolgen = None\n\n    def __call__(self, x: jax.Array) -> jax.Array:\n        q, k, v = self.q(x), self.k(x), self.v(x)\n\n        head_depth = self.depth // self.num_heads\n        # Reshape for multi-head attention.\n        q, k, v = (\n            t.reshape((-1, self.num_heads, head_depth)).transpose((1, 0, 2))\n            for t in (q, k, v)\n        )\n\n        # Scaled dot-product attention.\n        logits = jnp.einsum(\"...qd,...kd->...qk\", q, k)\n        logits /= jnp.sqrt(k.shape[-1]).astype(k.dtype)\n\n        if self.smolgen is not None:\n            logits += self.smolgen(x)\n\n        attention_weights = nnx.softmax(logits, axis=-1)\n        scaled_attention = jnp.matmul(attention_weights, v)\n\n        # Reshape back to original dimensions.\n        scaled_attention = scaled_attention.transpose((1, 0, 2)).reshape(\n            (-1, self.depth)\n        )\n        return self.output_dense(scaled_attention)\n\n\nclass Smolgen(nnx.Module):\n    \"\"\"Smolgen module for generating attention biases.\"\"\"\n\n    def __init__(\n        self,\n        in_features: int,\n        config: model_config_pb2.SmolgenConfig,\n        defaults: model_config_pb2.DefaultsConfig,\n        heads: int,\n        weight_gen_dense: nnx.Linear,\n        *,\n  
      rngs: nnx.Rngs,\n    ):\n        self.heads = heads\n        self.compress = nnx.Linear(\n            in_features=in_features,\n            out_features=config.hidden_channels,\n            use_bias=False,\n            rngs=rngs,\n        )\n        self.dense1 = nnx.Linear(\n            in_features=config.hidden_channels * 64,\n            out_features=config.hidden_size,\n            rngs=rngs,\n        )\n        self.ln1 = nnx.LayerNorm(config.hidden_size, epsilon=1e-3, rngs=rngs)\n\n        self.dense2 = nnx.Linear(\n            in_features=config.hidden_size,\n            out_features=config.gen_size * heads,\n            rngs=rngs,\n        )\n        self.ln2 = nnx.LayerNorm(\n            config.gen_size * heads, epsilon=1e-3, rngs=rngs\n        )\n        self.weight_gen_dense = weight_gen_dense\n        self.activation = config.activation or defaults.activation\n\n    def __call__(self, x: jax.Array) -> jax.Array:\n        compressed = self.compress(x).flatten()\n        hidden = self.dense1(compressed)\n        hidden = get_activation(self.activation)(hidden)\n        hidden = self.ln1(hidden)\n\n        gen_from = self.dense2(hidden)\n        gen_from = get_activation(self.activation)(gen_from)\n        gen_from = self.ln2(gen_from)\n        gen_from = gen_from.reshape((self.heads, -1))\n\n        out = self.weight_gen_dense(gen_from)\n        return out.reshape((self.heads, 64, 64))\n"
  },
  {
    "path": "src/lczero_training/model/loss_function.py",
    "content": "from typing import Dict, List, Optional, Sequence, Tuple, Union, cast\n\nimport jax\nimport jax.numpy as jnp\nimport optax\nfrom flax import nnx\nfrom jax.scipy.special import xlogy\n\nfrom lczero_training.training.state import TrainingSample\nfrom lczero_training.training.utils import make_weights_mask\nfrom proto.training_config_pb2 import (\n    LossConfig,\n    MovesLeftLossConfig,\n    PolicyLossConfig,\n    RegularizationLossConfig,\n    ValueCategoricalLossConfig,\n    ValueErrorLossConfig,\n    ValueLossConfig,\n)\n\nfrom .model import LczeroModel, ModelPrediction\n\n\ndef _compute_q_from_wdl(wdl_logits: jax.Array) -> jax.Array:\n    \"\"\"Compute Q value from WDL logits.\"\"\"\n    wdl_probs = jax.nn.softmax(wdl_logits)\n    q_weights = jnp.array([1.0, 0.0, -1.0])\n    return jnp.dot(wdl_probs, q_weights)\n\n\nclass LossBase:\n    def __init__(\n        self,\n        config: Union[\n            PolicyLossConfig,\n            ValueLossConfig,\n            MovesLeftLossConfig,\n            ValueErrorLossConfig,\n            ValueCategoricalLossConfig,\n        ],\n    ) -> None:\n        self.head_name = config.head_name\n        self.metric_name = config.metric_name or config.head_name\n        self.weight = config.weight\n\n    def __call__(\n        self,\n        predictions: ModelPrediction,\n        sample: TrainingSample,\n    ) -> jax.Array:\n        raise NotImplementedError(\"Subclasses must implement __call__\")\n\n\nclass RegularizationLoss:\n    \"\"\"Computes regularization loss on model parameters.\"\"\"\n\n    def __init__(self, config: RegularizationLossConfig) -> None:\n        self.metric_name = config.metric_name or \"l2\"\n        self.weight = config.weight\n        self._selector = config.selector\n\n    def __call__(self, model: LczeroModel) -> jax.Array:\n        params = nnx.state(model, nnx.Param)\n        mask = make_weights_mask(self._selector, params)\n        masked_params = jax.tree.map(\n            lambda p, 
m: p.value if m else jnp.zeros_like(p.value),\n            params,\n            mask,\n            is_leaf=lambda x: isinstance(x, nnx.Variable),\n        )\n        leaves = jax.tree.leaves(masked_params)\n        return sum(\n            (jnp.sum(jnp.square(leaf)) for leaf in leaves), jnp.array(0.0)\n        )\n\n\nclass LczeroLoss:\n    policy_losses: List[\"PolicyLoss\"]\n    value_losses: List[\"ValueLoss\"]\n    movesleft_losses: List[\"MovesLeftLoss\"]\n    value_error_losses: List[\"ValueErrorLoss\"]\n    value_categorical_losses: List[\"ValueCategoricalLoss\"]\n    regularization_losses: List[\"RegularizationLoss\"]\n\n    def __init__(self, config: LossConfig) -> None:\n        self.config = config\n        self.policy_losses = [\n            PolicyLoss(loss_config) for loss_config in config.policy\n        ]\n        self.value_losses = [\n            ValueLoss(loss_config) for loss_config in config.value\n        ]\n        self.movesleft_losses = [\n            MovesLeftLoss(loss_config) for loss_config in config.movesleft\n        ]\n        self.value_error_losses = [\n            ValueErrorLoss(loss_config) for loss_config in config.value_error\n        ]\n        self.value_categorical_losses = [\n            ValueCategoricalLoss(loss_config)\n            for loss_config in config.value_categorical\n        ]\n        self.regularization_losses = [\n            RegularizationLoss(loss_config)\n            for loss_config in config.regularization\n        ]\n\n        def _validate_no_duplicate_metrics(\n            loss_type_name: str,\n            losses: Sequence[Union[LossBase, RegularizationLoss]],\n        ) -> None:\n            seen = set()\n            for name in (loss.metric_name for loss in losses):\n                if name in seen:\n                    raise ValueError(\n                        f\"Duplicate metric name: {loss_type_name}/{name}\"\n                    )\n                seen.add(name)\n\n        
_validate_no_duplicate_metrics(\"policy\", self.policy_losses)\n        _validate_no_duplicate_metrics(\"value\", self.value_losses)\n        _validate_no_duplicate_metrics(\"movesleft\", self.movesleft_losses)\n        _validate_no_duplicate_metrics(\"value_error\", self.value_error_losses)\n        _validate_no_duplicate_metrics(\n            \"value_categorical\", self.value_categorical_losses\n        )\n        _validate_no_duplicate_metrics(\n            \"regularization\", self.regularization_losses\n        )\n\n    def __call__(\n        self,\n        model: LczeroModel,\n        sample: TrainingSample,\n    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:\n        # Run model forward pass.\n        predictions = model(sample.inputs)\n\n        unweighted_losses: Dict[str, jax.Array] = {}\n        weighted_losses: List[jax.Array] = []\n\n        for policy_loss in self.policy_losses:\n            loss = policy_loss(predictions, sample)\n            unweighted_losses[f\"policy/{policy_loss.metric_name}\"] = loss\n            weighted_losses.append(loss * policy_loss.weight)\n\n        for value_loss in self.value_losses:\n            loss = value_loss(predictions, sample)\n            unweighted_losses[f\"value/{value_loss.metric_name}\"] = loss\n            weighted_losses.append(loss * value_loss.weight)\n\n        for movesleft_loss in self.movesleft_losses:\n            loss = movesleft_loss(predictions, sample)\n            unweighted_losses[f\"movesleft/{movesleft_loss.metric_name}\"] = loss\n            weighted_losses.append(loss * movesleft_loss.weight)\n\n        for value_error_loss in self.value_error_losses:\n            loss = value_error_loss(predictions, sample)\n            unweighted_losses[f\"value_error/{value_error_loss.metric_name}\"] = (\n                loss\n            )\n            weighted_losses.append(loss * value_error_loss.weight)\n\n        for value_categorical_loss in self.value_categorical_losses:\n            loss = 
value_categorical_loss(predictions, sample)\n            unweighted_losses[\n                f\"value_categorical/{value_categorical_loss.metric_name}\"\n            ] = loss\n            weighted_losses.append(loss * value_categorical_loss.weight)\n\n        for reg_loss in self.regularization_losses:\n            loss = reg_loss(model)\n            unweighted_losses[f\"regularization/{reg_loss.metric_name}\"] = loss\n            weighted_losses.append(loss * reg_loss.weight)\n\n        data_loss = jnp.sum(jnp.array(weighted_losses))\n\n        return data_loss, unweighted_losses\n\n\nclass ValueLoss(LossBase):\n    def __init__(self, config: ValueLossConfig) -> None:\n        super().__init__(config)\n        self.value_type = config.value_type\n\n    def __call__(\n        self,\n        predictions: ModelPrediction,\n        sample: TrainingSample,\n    ) -> jax.Array:\n        value_pred = predictions.value[self.head_name]\n        value_logits = value_pred[0]\n        # Extract raw q/d from sample and compute WDL.\n        value_q = sample.values[self.value_type, 0]\n        value_d = sample.values[self.value_type, 1]\n        # Compute WDL: w = (1 + q - d) / 2, l = (1 - q - d) / 2\n        value_w = (1.0 + value_q - value_d) / 2.0\n        value_l = (1.0 - value_q - value_d) / 2.0\n        value_wdl = jnp.stack([value_w, value_d, value_l], axis=-1)\n\n        # The cross-entropy between the predicted value and the target value.\n        value_cross_entropy = optax.softmax_cross_entropy(\n            logits=value_logits, labels=jax.lax.stop_gradient(value_wdl)\n        )\n        assert isinstance(value_cross_entropy, jax.Array)\n        return value_cross_entropy\n\n\nclass PolicyLoss(LossBase):\n    def __init__(self, config: PolicyLossConfig):\n        super().__init__(config)\n        self.config = config\n        if config.type == PolicyLossConfig.LOSS_TYPE_UNSPECIFIED:\n            raise ValueError(\n                f\"Policy loss type must be specified 
for head '{config.head_name}'.\"\n            )\n        self._loss_type = config.type\n        temperature = config.temperature\n        if temperature <= 0:\n            temperature = 1.0\n        self._temperature = temperature\n\n        # Store optimistic config if present.\n        if config.HasField(\"optimistic\"):\n            opt = config.optimistic\n            self.opt_value_head: Optional[str] = opt.value_head_name\n            self.opt_value_type = opt.value_type\n            self.opt_strength = opt.strength\n            self.opt_eps = opt.eps\n            self.opt_alpha = opt.alpha\n            self.opt_propagate_gradients = opt.propagate_value_gradients\n        else:\n            self.opt_value_head = None\n\n    def _apply_temperature_and_normalize(\n        self, policy_targets: jax.Array\n    ) -> jax.Array:\n        if self._temperature == 1.0:\n            return policy_targets\n\n        # Apply temperature scaling.\n        policy_targets = jnp.power(policy_targets, 1.0 / self._temperature)\n\n        # Renormalize after temperature scaling.\n        target_sum = jnp.sum(policy_targets, axis=-1, keepdims=True)\n        safe_sum = jnp.where(\n            target_sum > 0, target_sum, jnp.ones_like(target_sum)\n        )\n        return policy_targets / safe_sum\n\n    def _compute_optimistic_weight(\n        self,\n        value_pred: Tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]],\n        target_q: jax.Array,\n    ) -> jax.Array:\n        \"\"\"Compute optimistic policy weight from value head predictions.\"\"\"\n        wdl_logits = value_pred[0]\n        error_pred = value_pred[1]\n        assert error_pred is not None, (\n            \"Error prediction required for optimistic weighting\"\n        )\n\n        # Optionally block gradients to value and error heads.\n        if not self.opt_propagate_gradients:\n            wdl_logits = jax.lax.stop_gradient(wdl_logits)\n            error_pred = 
jax.lax.stop_gradient(error_pred)\n\n        # Compute predicted Q from WDL.\n        q_pred = _compute_q_from_wdl(wdl_logits)\n\n        # Compute sigma and z-score.\n        sigma = jnp.sqrt(error_pred.squeeze())\n        z = (target_q - q_pred) / (sigma + self.opt_eps)\n\n        # Compute weight.\n        return jax.nn.sigmoid((z - self.opt_strength) * self.opt_alpha)\n\n    def __call__(\n        self,\n        predictions: ModelPrediction,\n        sample: TrainingSample,\n    ) -> jax.Array:\n        policy_pred = predictions.policy[self.head_name]\n        # Extract probabilities from sample.\n        policy_targets = jnp.asarray(\n            sample.probabilities, dtype=policy_pred.dtype\n        )\n        if self.config.illegal_moves == PolicyLossConfig.MASK:\n            policy_pred = jnp.where(policy_targets >= 0, policy_pred, -jnp.inf)\n\n        # Zero out negative targets for illegal moves.\n        policy_targets = jax.nn.relu(policy_targets)\n\n        # Apply temperature scaling and renormalization if needed.\n        policy_targets = self._apply_temperature_and_normalize(policy_targets)\n\n        cross_entropy = cast(\n            jax.Array,\n            optax.safe_softmax_cross_entropy(\n                logits=policy_pred, labels=policy_targets\n            ),\n        )\n        if self._loss_type == PolicyLossConfig.CROSS_ENTROPY:\n            loss = cross_entropy\n        elif self._loss_type == PolicyLossConfig.KL:\n            loss = cross_entropy + jnp.sum(\n                xlogy(policy_targets, policy_targets), axis=-1\n            )\n        else:\n            raise AssertionError(\n                f\"Unknown policy loss type: {self._loss_type}.\"\n            )\n\n        # Apply optimistic weighting if configured.\n        if self.opt_value_head is not None:\n            value_pred = predictions.value[self.opt_value_head]\n            target_q = sample.values[self.opt_value_type, 0]\n            loss = loss * 
self._compute_optimistic_weight(value_pred, target_q)\n\n        return loss\n\n\nclass MovesLeftLoss(LossBase):\n    def __init__(self, config: MovesLeftLossConfig) -> None:\n        super().__init__(config)\n        self.value_type = config.value_type\n\n    def __call__(\n        self,\n        predictions: ModelPrediction,\n        sample: TrainingSample,\n    ) -> jax.Array:\n        movesleft_pred = predictions.movesleft[self.head_name]\n        # Extract movesleft from sample.\n        # sample.values shape: [6, 3], component 2 is movesleft.\n        movesleft_targets = sample.values[self.value_type, 2]\n\n        # Scale the loss to a similar range as other losses.\n        scale = 20.0\n        targets = movesleft_targets / scale\n        scaled_predictions = movesleft_pred / scale\n\n        # Apply the Huber loss to the scaled values.\n        huber_loss = optax.huber_loss(\n            predictions=scaled_predictions, targets=targets, delta=10.0 / scale\n        )\n        assert isinstance(huber_loss, jax.Array)\n        return huber_loss.squeeze()\n\n\nclass ValueErrorLoss(LossBase):\n    def __init__(self, config: ValueErrorLossConfig) -> None:\n        super().__init__(config)\n        self.value_type = config.value_type\n        self.propagate_value_gradients = config.propagate_value_gradients\n\n    def __call__(\n        self,\n        predictions: ModelPrediction,\n        sample: TrainingSample,\n    ) -> jax.Array:\n        value_pred = predictions.value[self.head_name]\n        wdl_logits = value_pred[0]\n        error_pred = value_pred[1]\n        assert error_pred is not None\n\n        # Convert WDL to Q value.\n        predicted_q = _compute_q_from_wdl(wdl_logits)\n\n        # Get target Q value.\n        target_q = sample.values[self.value_type, 0]\n\n        # Compute actual squared error.\n        actual_squared_error = jnp.square(predicted_q - target_q)\n\n        # Optionally block gradients to WDL head.\n        if not self.propagate_value_gradients:\n            actual_squared_error = jax.lax.stop_gradient(actual_squared_error)\n\n        # MSE between error prediction and actual error.\n        loss = jnp.square(error_pred - actual_squared_error)\n\n        return loss.squeeze()\n\n\nclass ValueCategoricalLoss(LossBase):\n    def __init__(self, config: ValueCategoricalLossConfig) -> None:\n        super().__init__(config)\n        self.value_type = config.value_type\n\n    def __call__(\n        self,\n        predictions: ModelPrediction,\n        sample: TrainingSample,\n    ) -> jax.Array:\n        value_pred = predictions.value[self.head_name]\n        categorical_logits = value_pred[2]\n        assert categorical_logits is not None\n\n        # Get target Q value from sample.\n        target_q = sample.values[self.value_type, 0]\n\n        # Convert Q to bucket index: map [-1, 1) to [0, num_buckets).\n        num_buckets = categorical_logits.shape[-1]\n        bucket_index = jnp.floor((target_q + 1.0) / 2.0 * num_buckets).astype(\n            jnp.int32\n        )\n        bucket_index = jnp.clip(bucket_index, 0, num_buckets - 1)\n\n        # Create one-hot target.\n        target_one_hot = jax.nn.one_hot(bucket_index, num_buckets)\n\n        # Compute softmax cross-entropy.\n        loss = optax.softmax_cross_entropy(\n            logits=categorical_logits,\n            labels=jax.lax.stop_gradient(target_one_hot),\n        )\n        assert isinstance(loss, jax.Array)\n        return loss\n"
  },
  {
    "path": "src/lczero_training/model/model.py",
    "content": "import dataclasses\nimport math\nfrom typing import Optional, Tuple\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import nnx\n\nfrom proto import model_config_pb2\n\nfrom .embedding import Embedding\nfrom .encoder import EncoderTower\nfrom .movesleft_head import MovesLeftHead\nfrom .policy_head import PolicyHead\nfrom .utils import get_dtype\nfrom .value_head import ValueHead\n\n\n@jax.tree_util.register_dataclass\n@dataclasses.dataclass\nclass ModelPrediction:\n    \"\"\"Output predictions from LczeroModel.\n\n    Fields:\n        value: Dictionary mapping head names to value prediction tuples.\n        policy: Dictionary mapping head names to policy logits.\n        movesleft: Dictionary mapping head names to moves-left predictions.\n    \"\"\"\n\n    value: dict[str, Tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]]\n    policy: dict[str, jax.Array]\n    movesleft: dict[str, jax.Array]\n\n\nclass LczeroModel(nnx.Module):\n    def __init__(self, config: model_config_pb2.ModelConfig, *, rngs: nnx.Rngs):\n        self.config = config\n        self._input_channels = 112\n        deepnorm_beta = math.pow(8.0 * config.encoder.num_blocks, -0.25)\n\n        self.embedding = Embedding(\n            input_channels=self._input_channels,\n            config=config.embedding,\n            defaults=config.defaults,\n            deepnorm_alpha=math.pow(2.0 * config.encoder.num_blocks, -0.25),\n            deepnorm_beta=deepnorm_beta,\n            rngs=rngs,\n        )\n\n        assert self.config.encoder.num_blocks > 0\n\n        self.encoders = EncoderTower(\n            in_features=config.embedding.embedding_size,\n            config=config.encoder,\n            defaults=config.defaults,\n            deepnorm_beta=deepnorm_beta,\n            rngs=rngs,\n        )\n\n        self.value_heads = nnx.Dict(\n            {\n                head_config.name: ValueHead(\n                    in_features=config.embedding.embedding_size,\n                 
   config=head_config,\n                    defaults=config.defaults,\n                    rngs=rngs,\n                )\n                for head_config in config.value_head\n            }\n        )\n\n        # Named to appear before 'policy_heads' alphabetically in pytree state.\n        # This ensures shared embedding appears at parent level during\n        # serialization.\n        self.policy_embedding_shared: Optional[nnx.Linear]\n        if config.HasField(\"shared_policy_embedding_size\"):\n            self.policy_embedding_shared = nnx.Linear(\n                in_features=config.embedding.embedding_size,\n                out_features=config.shared_policy_embedding_size,\n                rngs=rngs,\n            )\n        else:\n            self.policy_embedding_shared = None\n\n        self.policy_heads = nnx.Dict(\n            {\n                head_config.name: PolicyHead(\n                    in_features=config.embedding.embedding_size,\n                    config=head_config,\n                    defaults=config.defaults,\n                    shared_embedding=self.policy_embedding_shared,\n                    rngs=rngs,\n                )\n                for head_config in config.policy_head\n            }\n        )\n        self.movesleft_heads = nnx.Dict(\n            {\n                head_config.name: MovesLeftHead(\n                    in_features=config.embedding.embedding_size,\n                    config=head_config,\n                    defaults=config.defaults,\n                    rngs=rngs,\n                )\n                for head_config in config.movesleft_head\n            }\n        )\n\n    def __call__(self, x: jax.Array) -> ModelPrediction:\n        x = jnp.astype(x, get_dtype(self.config.defaults.compute_dtype))\n        x = jnp.transpose(x, (1, 2, 0))\n        x = jnp.reshape(x, (64, self._input_channels))\n        x = self.embedding(x)\n        x = self.encoders(x)\n\n        value = {name: head(x) for name, head in 
self.value_heads.items()}\n        policy = {name: head(x) for name, head in self.policy_heads.items()}\n        movesleft = {\n            name: head(x) for name, head in self.movesleft_heads.items()\n        }\n\n        return ModelPrediction(value=value, policy=policy, movesleft=movesleft)\n"
  },
  {
    "path": "src/lczero_training/model/movesleft_head.py",
    "content": "import jax\nfrom flax import nnx\n\nfrom proto import model_config_pb2\n\nfrom .utils import get_activation\n\n\nclass MovesLeftHead(nnx.Module):\n    def __init__(\n        self,\n        in_features: int,\n        config: model_config_pb2.MovesLeftHeadConfig,\n        defaults: model_config_pb2.DefaultsConfig,\n        *,\n        rngs: nnx.Rngs,\n    ):\n        self.activation = defaults.activation\n        self.embed = nnx.Linear(\n            in_features=in_features,\n            out_features=config.num_channels,\n            rngs=rngs,\n        )\n\n        self.dense1 = nnx.Linear(\n            in_features=config.num_channels * 64,\n            out_features=128,\n            rngs=rngs,\n        )\n        self.out = nnx.Linear(\n            in_features=128,\n            out_features=1,\n            rngs=rngs,\n        )\n\n    def __call__(self, x: jax.Array) -> jax.Array:\n        x = self.embed(x).flatten()\n        x = get_activation(self.activation)(x)\n        x = self.dense1(x)\n        x = get_activation(self.activation)(x)\n        x = self.out(x)\n        return nnx.relu(x)\n"
  },
  {
    "path": "src/lczero_training/model/policy_head.py",
    "content": "import math\nfrom typing import Optional\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import nnx\n\nfrom proto import model_config_pb2\n\nfrom .utils import get_activation  # , get_policy_map\n\n\nclass PolicyHead(nnx.Module):\n    def __init__(\n        self,\n        in_features: int,\n        config: model_config_pb2.PolicyHeadConfig,\n        defaults: model_config_pb2.DefaultsConfig,\n        shared_embedding: Optional[nnx.Linear] = None,\n        *,\n        rngs: nnx.Rngs,\n    ):\n        assert (shared_embedding is not None) != config.HasField(\n            \"embedding_size\"\n        )\n        self.activation = defaults.activation\n        if shared_embedding is not None:\n            self.tokens = shared_embedding\n            embedding_size = shared_embedding.out_features\n        else:\n            self.tokens = nnx.Linear(\n                in_features=in_features,\n                out_features=config.embedding_size,\n                rngs=rngs,\n            )\n            embedding_size = config.embedding_size\n\n        self.q = nnx.Linear(\n            in_features=embedding_size,\n            out_features=config.d_model,\n            rngs=rngs,\n        )\n\n        self.k = nnx.Linear(\n            in_features=embedding_size,\n            out_features=config.d_model,\n            rngs=rngs,\n        )\n\n        self.dk = math.sqrt(config.d_model)\n        self.promotion_dense = nnx.Linear(\n            in_features=config.d_model,\n            out_features=4,\n            use_bias=False,\n            rngs=rngs,\n        )\n\n    def __call__(self, x: jax.Array) -> jax.Array:\n        x = self.tokens(x)\n        x = get_activation(self.activation)(x)\n\n        q = self.q(x)\n        k = self.k(x)\n        qk = jnp.einsum(\"qd,kd->qk\", q, k)\n\n        promotion_keys = k[-8:, :]\n        promotion_offsets = self.promotion_dense(promotion_keys)\n        promotion_offsets = promotion_offsets.transpose((1, 0)) * self.dk\n        # 
knight offset is added to the other three\n        promotion_offsets = promotion_offsets[:3, :] + promotion_offsets[3:4, :]\n\n        n_promo_logits = qk[-16:-8, -8:]\n        q_promo_logits = jnp.expand_dims(\n            n_promo_logits + promotion_offsets[0:1, :], axis=-1\n        )\n        r_promo_logits = jnp.expand_dims(\n            n_promo_logits + promotion_offsets[1:2, :], axis=-1\n        )\n        b_promo_logits = jnp.expand_dims(\n            n_promo_logits + promotion_offsets[2:3, :], axis=-1\n        )\n        promotion_logits = jnp.concatenate(\n            [q_promo_logits, r_promo_logits, b_promo_logits], axis=-1\n        )\n        policy_attn_logits = qk / self.dk\n        promotion_logits = promotion_logits.reshape((8, 24)) / self.dk\n\n        logits = jnp.concatenate(\n            [policy_attn_logits.flatten(), promotion_logits.flatten()], axis=-1\n        )\n\n        return logits[_policy_map]\n\n\n# fmt: off\n_policy_map = jnp.array([\n 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 17, 18, 24, 27, 32, 36, 40, 45, 48, 54, 56,\n 63, 64, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 80, 81, 82, 83, 89, 92, 97,\n 101, 105, 110, 113, 119, 121, 128, 129, 131, 132, 133, 134, 135, 136, 137, 138,\n 139, 140, 144, 145, 146, 147, 148, 154, 157, 162, 166, 170, 175, 178, 186, 192,\n 193, 194, 196, 197, 198, 199, 201, 202, 203, 204, 205, 209, 210, 211, 212, 213,\n 216, 219, 222, 227, 231, 235, 243, 251, 256, 257, 258, 259, 261, 262, 263, 266,\n 267, 268, 269, 270, 274, 275, 276, 277, 278, 281, 284, 287, 288, 292, 300, 308,\n 316, 320, 321, 322, 323, 324, 326, 327, 331, 332, 333, 334, 335, 339, 340, 341,\n 342, 343, 346, 349, 353, 357, 360, 365, 373, 381, 384, 385, 386, 387, 388, 389,\n 391, 396, 397, 398, 399, 404, 405, 406, 407, 411, 414, 418, 422, 425, 430, 432,\n 438, 446, 448, 449, 450, 451, 452, 453, 454, 461, 462, 463, 469, 470, 471, 476,\n 479, 483, 487, 490, 495, 497, 503, 504, 511, 512, 513, 514, 521, 522, 523, 524,\n 525, 526, 527, 528, 529, 530, 536, 537, 
538, 544, 547, 552, 556, 560, 565, 568,\n 574, 576, 577, 578, 579, 584, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595,\n 600, 601, 602, 603, 609, 612, 617, 621, 625, 630, 633, 639, 640, 641, 642, 643,\n 644, 648, 649, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 664, 665, 666,\n 667, 668, 674, 677, 682, 686, 690, 695, 698, 705, 706, 707, 708, 709, 712, 713,\n 714, 716, 717, 718, 719, 721, 722, 723, 724, 725, 729, 730, 731, 732, 733, 736,\n 739, 742, 747, 751, 755, 763, 770, 771, 772, 773, 774, 776, 777, 778, 779, 781,\n 782, 783, 786, 787, 788, 789, 790, 794, 795, 796, 797, 798, 801, 804, 807, 808,\n 812, 820, 828, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 846, 847, 851,\n 852, 853, 854, 855, 859, 860, 861, 862, 863, 866, 869, 873, 877, 880, 885, 893,\n 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 911, 916, 917, 918, 919, 924,\n 925, 926, 927, 931, 934, 938, 942, 945, 950, 952, 958, 965, 966, 967, 968, 969,\n 970, 971, 972, 973, 974, 981, 982, 983, 989, 990, 991, 996, 999, 1003, 1007,\n 1010, 1015, 1017, 1023, 1024, 1025, 1026, 1032, 1033, 1034, 1041, 1042, 1043,\n 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1056, 1057, 1058, 1064, 1067, 1072,\n 1076, 1080, 1085, 1088, 1089, 1090, 1091, 1096, 1097, 1098, 1099, 1104, 1106,\n 1107, 1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1120, 1121, 1122, 1123,\n 1129, 1132, 1137, 1141, 1145, 1150, 1152, 1153, 1154, 1155, 1156, 1160, 1161,\n 1162, 1163, 1164, 1168, 1169, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178,\n 1179, 1180, 1184, 1185, 1186, 1187, 1188, 1194, 1197, 1202, 1206, 1210, 1215,\n 1217, 1218, 1219, 1220, 1221, 1225, 1226, 1227, 1228, 1229, 1232, 1233, 1234,\n 1236, 1237, 1238, 1239, 1241, 1242, 1243, 1244, 1245, 1249, 1250, 1251, 1252,\n 1253, 1256, 1259, 1262, 1267, 1271, 1275, 1282, 1283, 1284, 1285, 1286, 1290,\n 1291, 1292, 1293, 1294, 1296, 1297, 1298, 1299, 1301, 1302, 1303, 1306, 1307,\n 1308, 1309, 1310, 1314, 1315, 1316, 1317, 1318, 1321, 1324, 1327, 1328, 1332,\n 1340, 1347, 
1348, 1349, 1350, 1351, 1355, 1356, 1357, 1358, 1359, 1360, 1361,\n 1362, 1363, 1364, 1366, 1367, 1371, 1372, 1373, 1374, 1375, 1379, 1380, 1381,\n 1382, 1383, 1386, 1389, 1393, 1397, 1400, 1405, 1412, 1413, 1414, 1415, 1420,\n 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429, 1431, 1436, 1437, 1438,\n 1439, 1444, 1445, 1446, 1447, 1451, 1454, 1458, 1462, 1465, 1470, 1477, 1478,\n 1479, 1485, 1486, 1487, 1488, 1489, 1490, 1491, 1492, 1493, 1494, 1501, 1502,\n 1503, 1509, 1510, 1511, 1516, 1519, 1523, 1527, 1530, 1535, 1536, 1539, 1544,\n 1545, 1546, 1552, 1553, 1554, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568,\n 1569, 1570, 1576, 1577, 1578, 1584, 1587, 1592, 1596, 1601, 1604, 1608, 1609,\n 1610, 1611, 1616, 1617, 1618, 1619, 1624, 1626, 1627, 1628, 1629, 1630, 1631,\n 1632, 1633, 1634, 1635, 1640, 1641, 1642, 1643, 1649, 1652, 1657, 1661, 1666,\n 1669, 1672, 1673, 1674, 1675, 1676, 1680, 1681, 1682, 1683, 1684, 1688, 1689,\n 1691, 1692, 1693, 1694, 1695, 1696, 1697, 1698, 1699, 1700, 1704, 1705, 1706,\n 1707, 1708, 1714, 1717, 1722, 1726, 1728, 1731, 1734, 1737, 1738, 1739, 1740,\n 1741, 1745, 1746, 1747, 1748, 1749, 1752, 1753, 1754, 1756, 1757, 1758, 1759,\n 1761, 1762, 1763, 1764, 1765, 1769, 1770, 1771, 1772, 1773, 1776, 1779, 1782,\n 1787, 1791, 1793, 1796, 1799, 1802, 1803, 1804, 1805, 1806, 1810, 1811, 1812,\n 1813, 1814, 1816, 1817, 1818, 1819, 1821, 1822, 1823, 1826, 1827, 1828, 1829,\n 1830, 1834, 1835, 1836, 1837, 1838, 1841, 1844, 1847, 1848, 1852, 1858, 1861,\n 1867, 1868, 1869, 1870, 1871, 1875, 1876, 1877, 1878, 1879, 1880, 1881, 1882,\n 1883, 1884, 1886, 1887, 1891, 1892, 1893, 1894, 1895, 1899, 1900, 1901, 1902,\n 1903, 1906, 1909, 1913, 1917, 1923, 1926, 1932, 1933, 1934, 1935, 1940, 1941,\n 1942, 1943, 1944, 1945, 1946, 1947, 1948, 1949, 1951, 1956, 1957, 1958, 1959,\n 1964, 1965, 1966, 1967, 1971, 1974, 1978, 1982, 1988, 1991, 1997, 1998, 1999,\n 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2021, 2022, 2023,\n 2029, 2030, 
2031, 2036, 2039, 2043, 2047, 2048, 2052, 2056, 2059, 2064, 2065,\n 2066, 2072, 2073, 2074, 2081, 2082, 2083, 2084, 2085, 2086, 2087, 2088, 2089,\n 2090, 2096, 2097, 2098, 2104, 2107, 2113, 2117, 2121, 2124, 2128, 2129, 2130,\n 2131, 2136, 2137, 2138, 2139, 2144, 2146, 2147, 2148, 2149, 2150, 2151, 2152,\n 2153, 2154, 2155, 2160, 2161, 2162, 2163, 2169, 2172, 2178, 2182, 2186, 2189,\n 2192, 2193, 2194, 2195, 2196, 2200, 2201, 2202, 2203, 2204, 2208, 2209, 2211,\n 2212, 2213, 2214, 2215, 2216, 2217, 2218, 2219, 2220, 2224, 2225, 2226, 2227,\n 2228, 2234, 2237, 2243, 2247, 2248, 2251, 2254, 2257, 2258, 2259, 2260, 2261,\n 2265, 2266, 2267, 2268, 2269, 2272, 2273, 2274, 2276, 2277, 2278, 2279, 2281,\n 2282, 2283, 2284, 2285, 2289, 2290, 2291, 2292, 2293, 2296, 2299, 2302, 2304,\n 2308, 2313, 2316, 2319, 2322, 2323, 2324, 2325, 2326, 2330, 2331, 2332, 2333,\n 2334, 2336, 2337, 2338, 2339, 2341, 2342, 2343, 2346, 2347, 2348, 2349, 2350,\n 2354, 2355, 2356, 2357, 2358, 2361, 2364, 2367, 2369, 2373, 2378, 2381, 2387,\n 2388, 2389, 2390, 2391, 2395, 2396, 2397, 2398, 2399, 2400, 2401, 2402, 2403,\n 2404, 2406, 2407, 2411, 2412, 2413, 2414, 2415, 2419, 2420, 2421, 2422, 2423,\n 2426, 2429, 2434, 2438, 2443, 2446, 2452, 2453, 2454, 2455, 2460, 2461, 2462,\n 2463, 2464, 2465, 2466, 2467, 2468, 2469, 2471, 2476, 2477, 2478, 2479, 2484,\n 2485, 2486, 2487, 2491, 2494, 2499, 2503, 2508, 2511, 2517, 2518, 2519, 2525,\n 2526, 2527, 2528, 2529, 2530, 2531, 2532, 2533, 2534, 2541, 2542, 2543, 2549,\n 2550, 2551, 2556, 2559, 2560, 2565, 2568, 2572, 2576, 2579, 2584, 2585, 2586,\n 2592, 2593, 2594, 2601, 2602, 2603, 2604, 2605, 2606, 2607, 2608, 2609, 2610,\n 2616, 2617, 2618, 2625, 2630, 2633, 2637, 2641, 2644, 2648, 2649, 2650, 2651,\n 2656, 2657, 2658, 2659, 2664, 2666, 2667, 2668, 2669, 2670, 2671, 2672, 2673,\n 2674, 2675, 2680, 2681, 2682, 2683, 2690, 2695, 2698, 2702, 2706, 2709, 2712,\n 2713, 2714, 2715, 2716, 2720, 2721, 2722, 2723, 2724, 2728, 2729, 2731, 2732,\n 2733, 2734, 
2735, 2736, 2737, 2738, 2739, 2740, 2744, 2745, 2746, 2747, 2748,\n 2755, 2763, 2767, 2768, 2771, 2774, 2777, 2778, 2779, 2780, 2781, 2785, 2786,\n 2787, 2788, 2789, 2792, 2793, 2794, 2796, 2797, 2798, 2799, 2801, 2802, 2803,\n 2804, 2805, 2809, 2810, 2811, 2812, 2813, 2820, 2824, 2828, 2833, 2836, 2839,\n 2842, 2843, 2844, 2845, 2846, 2850, 2851, 2852, 2853, 2854, 2856, 2857, 2858,\n 2859, 2861, 2862, 2863, 2866, 2867, 2868, 2869, 2870, 2874, 2875, 2876, 2877,\n 2878, 2880, 2885, 2889, 2893, 2898, 2901, 2907, 2908, 2909, 2910, 2911, 2915,\n 2916, 2917, 2918, 2919, 2920, 2921, 2922, 2923, 2924, 2926, 2927, 2931, 2932,\n 2933, 2934, 2935, 2939, 2940, 2941, 2942, 2943, 2945, 2950, 2954, 2958, 2963,\n 2966, 2972, 2973, 2974, 2975, 2980, 2981, 2982, 2983, 2984, 2985, 2986, 2987,\n 2988, 2989, 2991, 2996, 2997, 2998, 2999, 3004, 3005, 3006, 3007, 3010, 3015,\n 3019, 3023, 3028, 3031, 3037, 3038, 3039, 3045, 3046, 3047, 3048, 3049, 3050,\n 3051, 3052, 3053, 3054, 3061, 3062, 3063, 3069, 3070, 3071, 3072, 3078, 3080,\n 3085, 3088, 3092, 3096, 3099, 3104, 3105, 3106, 3112, 3113, 3114, 3121, 3122,\n 3123, 3124, 3125, 3126, 3127, 3128, 3129, 3130, 3137, 3143, 3145, 3150, 3153,\n 3157, 3161, 3164, 3168, 3169, 3170, 3171, 3176, 3177, 3178, 3179, 3184, 3186,\n 3187, 3188, 3189, 3190, 3191, 3192, 3193, 3194, 3195, 3202, 3210, 3215, 3218,\n 3222, 3226, 3229, 3232, 3233, 3234, 3235, 3236, 3240, 3241, 3242, 3243, 3244,\n 3248, 3249, 3251, 3252, 3253, 3254, 3255, 3256, 3257, 3258, 3259, 3260, 3267,\n 3275, 3283, 3287, 3288, 3291, 3294, 3297, 3298, 3299, 3300, 3301, 3305, 3306,\n 3307, 3308, 3309, 3312, 3313, 3314, 3316, 3317, 3318, 3319, 3321, 3322, 3323,\n 3324, 3325, 3332, 3340, 3344, 3348, 3353, 3356, 3359, 3362, 3363, 3364, 3365,\n 3366, 3370, 3371, 3372, 3373, 3374, 3376, 3377, 3378, 3379, 3381, 3382, 3383,\n 3386, 3387, 3388, 3389, 3390, 3397, 3400, 3405, 3409, 3413, 3418, 3421, 3427,\n 3428, 3429, 3430, 3431, 3435, 3436, 3437, 3438, 3439, 3440, 3441, 3442, 3443,\n 3444, 3446, 
3447, 3451, 3452, 3453, 3454, 3455, 3456, 3462, 3465, 3470, 3474,\n 3478, 3483, 3486, 3492, 3493, 3494, 3495, 3500, 3501, 3502, 3503, 3504, 3505,\n 3506, 3507, 3508, 3509, 3511, 3516, 3517, 3518, 3519, 3521, 3527, 3530, 3535,\n 3539, 3543, 3548, 3551, 3557, 3558, 3559, 3565, 3566, 3567, 3568, 3569, 3570,\n 3571, 3572, 3573, 3574, 3581, 3582, 3583, 3584, 3591, 3592, 3598, 3600, 3605,\n 3608, 3612, 3616, 3619, 3624, 3625, 3626, 3632, 3633, 3634, 3641, 3642, 3643,\n 3644, 3645, 3646, 3647, 3649, 3657, 3663, 3665, 3670, 3673, 3677, 3681, 3684,\n 3688, 3689, 3690, 3691, 3696, 3697, 3698, 3699, 3704, 3706, 3707, 3708, 3709,\n 3710, 3711, 3714, 3722, 3730, 3735, 3738, 3742, 3746, 3749, 3752, 3753, 3754,\n 3755, 3756, 3760, 3761, 3762, 3763, 3764, 3768, 3769, 3771, 3772, 3773, 3774,\n 3775, 3779, 3787, 3795, 3803, 3807, 3808, 3811, 3814, 3817, 3818, 3819, 3820,\n 3821, 3825, 3826, 3827, 3828, 3829, 3832, 3833, 3834, 3836, 3837, 3838, 3839,\n 3844, 3852, 3860, 3864, 3868, 3873, 3876, 3879, 3882, 3883, 3884, 3885, 3886,\n 3890, 3891, 3892, 3893, 3894, 3896, 3897, 3898, 3899, 3901, 3902, 3903, 3909,\n 3917, 3920, 3925, 3929, 3933, 3938, 3941, 3947, 3948, 3949, 3950, 3951, 3955,\n 3956, 3957, 3958, 3959, 3960, 3961, 3962, 3963, 3964, 3966, 3967, 3974, 3976,\n 3982, 3985, 3990, 3994, 3998, 4003, 4006, 4012, 4013, 4014, 4015, 4020, 4021,\n 4022, 4023, 4024, 4025, 4026, 4027, 4028, 4029, 4031, 4032, 4039, 4041, 4047,\n 4050, 4055, 4059, 4063, 4068, 4071, 4077, 4078, 4079, 4085, 4086, 4087, 4088,\n 4089, 4090, 4091, 4092, 4093, 4094, 4096, 4097, 4098, 4099, 4100, 4101, 4120,\n 4121, 4122, 4123, 4124, 4125, 4126, 4127, 4128, 4147, 4148, 4149, 4150, 4151,\n 4152, 4153, 4154, 4155, 4174, 4175, 4176, 4177, 4178, 4179, 4180, 4181, 4182,\n 4201, 4202, 4203, 4204, 4205, 4206, 4207, 4208, 4209, 4228, 4229, 4230, 4231,\n 4232, 4233, 4234, 4235, 4236, 4255, 4256, 4257, 4258, 4259, 4260, 4261, 4262,\n 4263, 4282, 4283, 4284, 4285, 4286, 4287,\n], jnp.int32)\n"
  },
  {
    "path": "src/lczero_training/model/shared.py",
    "content": "import jax\nfrom flax import nnx\nfrom flax.linen import initializers as flax_initializers\n\nfrom proto import net_pb2\n\nfrom .utils import get_activation\n\n\nclass Ffn(nnx.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: int,\n        hidden_activation: net_pb2.NetworkFormat.ActivationFunction,\n        deepnorm_beta: float,\n        *,\n        rngs: nnx.Rngs,\n    ):\n        deepnorm_init = flax_initializers.variance_scaling(\n            scale=deepnorm_beta,\n            mode=\"fan_avg\",\n            distribution=\"truncated_normal\",\n        )\n        out_features = in_features\n        self.linear1 = nnx.Linear(\n            in_features=in_features,\n            out_features=hidden_features,\n            kernel_init=deepnorm_init,\n            rngs=rngs,\n        )\n        self.activation = hidden_activation\n        self.linear2 = nnx.Linear(\n            in_features=hidden_features,\n            out_features=out_features,\n            kernel_init=deepnorm_init,\n            rngs=rngs,\n        )\n\n    def __call__(self, x: jax.Array) -> jax.Array:\n        x = self.linear1(x)\n        x = get_activation(self.activation)(x)\n        x = self.linear2(x)\n        return x\n"
  },
  {
    "path": "src/lczero_training/model/utils.py",
    "content": "from typing import Any\n\nimport jax.numpy as jnp\nfrom flax import nnx\nfrom jax.nn import mish\n\nfrom proto import net_pb2\nfrom proto.hlo_pb2 import XlaShapeProto\n\n\ndef get_activation(\n    activation: net_pb2.NetworkFormat.ActivationFunction,\n) -> Any:\n    return {\n        net_pb2.NetworkFormat.ACTIVATION_MISH: mish,\n        net_pb2.NetworkFormat.ACTIVATION_RELU: nnx.relu,\n        net_pb2.NetworkFormat.ACTIVATION_NONE: lambda x: x,\n        net_pb2.NetworkFormat.ACTIVATION_TANH: nnx.tanh,\n        net_pb2.NetworkFormat.ACTIVATION_SIGMOID: nnx.sigmoid,\n        net_pb2.NetworkFormat.ACTIVATION_SELU: nnx.selu,\n        net_pb2.NetworkFormat.ACTIVATION_SWISH: nnx.swish,\n        net_pb2.NetworkFormat.ACTIVATION_SOFTMAX: nnx.softmax,\n    }[activation]\n\n\ndef get_dtype(dtype: XlaShapeProto.Type) -> jnp.dtype:\n    return {\n        XlaShapeProto.PRED: jnp.bool_,\n        XlaShapeProto.S4: jnp.int4,\n        XlaShapeProto.S8: jnp.int8,\n        XlaShapeProto.S16: jnp.int16,\n        XlaShapeProto.S32: jnp.int32,\n        XlaShapeProto.S64: jnp.int64,\n        XlaShapeProto.U4: jnp.uint4,\n        XlaShapeProto.U8: jnp.uint8,\n        XlaShapeProto.U16: jnp.uint16,\n        XlaShapeProto.U32: jnp.uint32,\n        XlaShapeProto.U64: jnp.uint64,\n        XlaShapeProto.F16: jnp.float16,\n        XlaShapeProto.F32: jnp.float32,\n        XlaShapeProto.BF16: jnp.bfloat16,\n        XlaShapeProto.F64: jnp.float64,\n        XlaShapeProto.F8E5M2: jnp.float8_e5m2,\n        XlaShapeProto.F8E4M3FN: jnp.float8_e4m3fn,\n        XlaShapeProto.F8E4M3B11FNUZ: jnp.float8_e4m3b11fnuz,\n        XlaShapeProto.F8E5M2FNUZ: jnp.float8_e5m2fnuz,\n        XlaShapeProto.F8E4M3FNUZ: jnp.float8_e4m3fnuz,\n        XlaShapeProto.C64: jnp.complex64,\n        XlaShapeProto.C128: jnp.complex128,\n    }[dtype]\n"
  },
  {
    "path": "src/lczero_training/model/value_head.py",
    "content": "from typing import Optional, Tuple\n\nimport jax\nfrom flax import nnx\n\nfrom proto import model_config_pb2\n\nfrom .utils import get_activation\n\n\nclass ValueHead(nnx.Module):\n    def __init__(\n        self,\n        in_features: int,\n        config: model_config_pb2.ValueHeadConfig,\n        defaults: model_config_pb2.DefaultsConfig,\n        *,\n        rngs: nnx.Rngs,\n    ):\n        self.activation = defaults.activation\n        self.has_error_output = config.has_error_output\n        self.num_categorical_buckets = config.num_categorical_buckets\n        self.embed = nnx.Linear(\n            in_features=in_features,\n            out_features=config.num_channels,\n            rngs=rngs,\n        )\n        self.dense1 = nnx.Linear(\n            in_features=config.num_channels * 64,\n            out_features=128,\n            rngs=rngs,\n        )\n        self.wdl = nnx.Linear(\n            in_features=128,\n            out_features=3,\n            rngs=rngs,\n        )\n        if self.has_error_output:\n            self.error = nnx.Linear(\n                in_features=128,\n                out_features=1,\n                rngs=rngs,\n            )\n        if self.num_categorical_buckets > 0:\n            self.categorical = nnx.Linear(\n                in_features=128,\n                out_features=self.num_categorical_buckets,\n                rngs=rngs,\n            )\n\n    def __call__(\n        self, x: jax.Array\n    ) -> Tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:\n        x = self.embed(x).flatten()\n        x = get_activation(self.activation)(x)\n        x = self.dense1(x)\n        x = get_activation(self.activation)(x)\n\n        wdl = self.wdl(x)\n        error = nnx.sigmoid(self.error(x)) if self.has_error_output else None\n        categorical = (\n            self.categorical(x) if self.num_categorical_buckets > 0 else None\n        )\n\n        return (wdl, error, categorical)\n\n    def predict(self, x: jax.Array) -> jax.Array:\n        return nnx.softmax(self(x)[0])\n"
  },
  {
    "path": "src/lczero_training/py.typed",
    "content": ""
  },
  {
    "path": "src/lczero_training/tests/test_protobuf.py",
    "content": "\"\"\"Test protobuf compilation and functionality.\"\"\"\n\n\ndef test_protobuf_import() -> None:\n    \"\"\"Test that protobuf files can be imported.\"\"\"\n    import proto.data_loader_config_pb2 as data_loader_config_pb2\n    import proto.model_config_pb2 as model_config_pb2\n    import proto.root_config_pb2 as root_config_pb2\n    import proto.training_config_pb2 as training_config_pb2\n    import proto.training_metrics_pb2 as training_metrics_pb2\n\n    # Test creating config objects\n    data_loader_config = data_loader_config_pb2.DataLoaderConfig()\n    assert data_loader_config is not None\n\n    root_config = root_config_pb2.RootConfig()\n    assert root_config is not None\n\n    training_config = training_config_pb2.TrainingConfig()\n    assert training_config is not None\n\n    model_config = model_config_pb2.ModelConfig()\n    assert model_config is not None\n\n    metrics = training_metrics_pb2.DataLoaderMetricsProto()\n    assert metrics is not None\n\n\ndef test_protobuf_functionality() -> None:\n    \"\"\"Test basic protobuf functionality.\"\"\"\n    import proto.data_loader_config_pb2 as data_loader_config_pb2\n    import proto.root_config_pb2 as root_config_pb2\n\n    # Create a config and set some values\n    config = data_loader_config_pb2.DataLoaderConfig()\n    stage = config.stage.add()\n    stage.name = \"file_path_provider\"\n    stage.file_path_provider.directory = \"/test/path\"\n\n    # Serialize and deserialize\n    serialized = config.SerializeToString()\n    assert len(serialized) > 0\n\n    config2 = data_loader_config_pb2.DataLoaderConfig()\n    config2.ParseFromString(serialized)\n\n    assert config2.stage[0].file_path_provider.directory == \"/test/path\"\n\n    # Test RootConfig functionality\n    root_config = root_config_pb2.RootConfig()\n    root_config.name = \"test_config\"\n    stage_config = root_config.data_loader.stage.add()\n    stage_config.name = \"file_path_provider\"\n    
stage_config.file_path_provider.directory = \"/test/path\"\n\n    # Serialize and deserialize root config\n    root_serialized = root_config.SerializeToString()\n    assert len(root_serialized) > 0\n\n    root_config2 = root_config_pb2.RootConfig()\n    root_config2.ParseFromString(root_serialized)\n\n    assert root_config2.name == \"test_config\"\n    assert (\n        root_config2.data_loader.stage[0].file_path_provider.directory\n        == \"/test/path\"\n    )\n"
  },
  {
    "path": "src/lczero_training/tests/test_protocol_registry.py",
    "content": "\"\"\"Test script for the protocol registry system.\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport pytest\n\nfrom lczero_training.daemon.protocol.registry import (\n    CLASS_TO_TYPE_MAP,\n    TYPE_TO_CLASS_MAP,\n    register,\n)\n\n\n@pytest.fixture(autouse=True)\ndef clear_registry() -> Any:\n    \"\"\"Clear registry maps before each test.\"\"\"\n    TYPE_TO_CLASS_MAP.clear()\n    CLASS_TO_TYPE_MAP.clear()\n    yield\n    TYPE_TO_CLASS_MAP.clear()\n    CLASS_TO_TYPE_MAP.clear()\n\n\ndef test_basic_registration() -> None:\n    \"\"\"Test basic event type registration.\"\"\"\n\n    @register(\"test_event\")\n    @dataclass\n    class BasicPayload:\n        content: str\n\n    # Check forward mapping\n    assert TYPE_TO_CLASS_MAP[\"test_event\"] == BasicPayload\n    # Check reverse mapping\n    assert CLASS_TO_TYPE_MAP[BasicPayload] == \"test_event\"\n\n\ndef test_duplicate_event_type() -> None:\n    \"\"\"Test that duplicate event types are rejected.\"\"\"\n\n    @register(\"duplicate_event\")\n    @dataclass\n    class FirstPayload:\n        data: str\n\n    with pytest.raises(\n        ValueError, match=r\".*duplicate_event.*already registered.*\"\n    ):\n\n        @register(\"duplicate_event\")  # Should fail\n        @dataclass\n        class SecondPayload:\n            other_data: int\n\n\ndef test_duplicate_class() -> None:\n    \"\"\"Test that duplicate classes are rejected.\"\"\"\n\n    @dataclass\n    class PayloadClass:\n        data: str\n\n    # Register once\n    register(\"first_event\")(PayloadClass)\n\n    # Try to register same class again - should fail\n    with pytest.raises(\n        ValueError, match=r\".*PayloadClass.*already registered.*\"\n    ):\n        register(\"second_event\")(PayloadClass)\n\n\ndef test_non_class_registration() -> None:\n    \"\"\"Test that non-classes are rejected.\"\"\"\n    with pytest.raises(TypeError, match=r\".*can only be used on classes.*\"):\n        # Try to 
register a string instead of a class\n        @register(\"invalid_event\")  # type: ignore[arg-type]\n        def not_a_class() -> None:\n            pass\n\n\ndef test_multiple_registrations() -> None:\n    \"\"\"Test multiple valid registrations work correctly.\"\"\"\n\n    @register(\"event_one\")\n    @dataclass\n    class PayloadOne:\n        data: str\n\n    @register(\"event_two\")\n    @dataclass\n    class PayloadTwo:\n        value: int\n\n    @register(\"event_three\")\n    @dataclass\n    class PayloadThree:\n        items: list\n\n    # Check all mappings exist\n    assert TYPE_TO_CLASS_MAP[\"event_one\"] == PayloadOne\n    assert TYPE_TO_CLASS_MAP[\"event_two\"] == PayloadTwo\n    assert TYPE_TO_CLASS_MAP[\"event_three\"] == PayloadThree\n\n    assert CLASS_TO_TYPE_MAP[PayloadOne] == \"event_one\"\n    assert CLASS_TO_TYPE_MAP[PayloadTwo] == \"event_two\"\n    assert CLASS_TO_TYPE_MAP[PayloadThree] == \"event_three\"\n\n    # Check we have exactly 3 entries in each map\n    assert len(TYPE_TO_CLASS_MAP) == 3\n    assert len(CLASS_TO_TYPE_MAP) == 3\n\n\ndef test_registry_persistence() -> None:\n    \"\"\"Test that registry persists across imports.\"\"\"\n\n    @register(\"persistent_event\")\n    @dataclass\n    class PersistentPayload:\n        data: str\n\n    # Re-import the module\n    from lczero_training.daemon.protocol.registry import (\n        CLASS_TO_TYPE_MAP as imported_class_map,\n    )\n    from lczero_training.daemon.protocol.registry import (\n        TYPE_TO_CLASS_MAP as imported_type_map,\n    )\n\n    # Check the registration persists\n    assert imported_type_map[\"persistent_event\"] == PersistentPayload\n    assert imported_class_map[PersistentPayload] == \"persistent_event\"\n"
  },
  {
    "path": "src/lczero_training/tests/test_weights_tool.py",
    "content": "\"\"\"Test weights tool arithmetic operations.\"\"\"\n\nimport os\nimport tempfile\n\nimport numpy as np\n\nimport proto.net_pb2 as net_pb2\nfrom lczero_training.tools.weight_wrappers import NetWrapper\nfrom lczero_training.tools.weights_tool import load_weights, save_weights\n\n\ndef test_weights_arithmetic() -> None:\n    \"\"\"Test arithmetic operations on simple networks.\"\"\"\n    # Create network A with a single layer containing value 10.0.\n    net_a = net_pb2.Net()\n    net_a.format.weights_encoding = net_pb2.Format.LINEAR16\n    net_a.weights.ip1_val_w.min_val = 10.0\n    net_a.weights.ip1_val_w.max_val = 10.0\n    # LINEAR16 encoding: value 10.0 maps to uint16 value 32767 (mid-point).\n    net_a.weights.ip1_val_w.params = np.array(\n        [32767], dtype=np.uint16\n    ).tobytes()\n\n    # Create network B with a single layer containing value 20.0.\n    net_b = net_pb2.Net()\n    net_b.format.weights_encoding = net_pb2.Format.LINEAR16\n    net_b.weights.ip1_val_w.min_val = 20.0\n    net_b.weights.ip1_val_w.max_val = 20.0\n    net_b.weights.ip1_val_w.params = np.array(\n        [32767], dtype=np.uint16\n    ).tobytes()\n\n    # Wrap the networks.\n    wrapper_a = NetWrapper(net_a)\n    wrapper_b = NetWrapper(net_b)\n\n    # Compute: output = 0.2*A + 0.8*B\n    # Expected: 0.2*10 + 0.8*20 = 2 + 16 = 18\n    output = 0.2 * wrapper_a + 0.8 * wrapper_b\n\n    # Check the result.\n    result_value = output.weights.ip1_val_w.value\n    expected = 18.0\n\n    assert result_value.shape == (1,), (\n        f\"Expected shape (1,), got {result_value.shape}\"\n    )\n    assert np.isclose(result_value[0], expected, rtol=1e-4), (\n        f\"Expected {expected}, got {result_value[0]}\"\n    )\n\n\ndef test_policy_head_replacement() -> None:\n    \"\"\"Test that assigning policy_heads actually replaces the data.\"\"\"\n    # Create network A with policy head value 10.0.\n    net_a = net_pb2.Net()\n    net_a.format.weights_encoding = 
net_pb2.Format.LINEAR16\n    net_a.weights.policy_heads.ip_pol_w.min_val = 10.0\n    net_a.weights.policy_heads.ip_pol_w.max_val = 10.0\n    net_a.weights.policy_heads.ip_pol_w.params = np.array(\n        [32767], dtype=np.uint16\n    ).tobytes()\n\n    # Create network B with policy head value 20.0.\n    net_b = net_pb2.Net()\n    net_b.format.weights_encoding = net_pb2.Format.LINEAR16\n    net_b.weights.policy_heads.ip_pol_w.min_val = 20.0\n    net_b.weights.policy_heads.ip_pol_w.max_val = 20.0\n    net_b.weights.policy_heads.ip_pol_w.params = np.array(\n        [32767], dtype=np.uint16\n    ).tobytes()\n\n    # Wrap and perform assignment.\n    wrapper_a = NetWrapper(net_a)\n    wrapper_b = NetWrapper(net_b)\n\n    # Replace A's policy heads with B's.\n    wrapper_a.weights.policy_heads = wrapper_b.weights.policy_heads\n\n    # Verify in-memory replacement.\n    result_value = wrapper_a.weights.policy_heads.ip_pol_w.value\n    assert np.isclose(result_value[0], 20.0, rtol=1e-4), (\n        f\"Expected 20.0 (B's value), got {result_value[0]}\"\n    )\n\n    # Verify persistence (round-trip).\n    with tempfile.NamedTemporaryFile(suffix=\".pb.gz\", delete=False) as tmp:\n        tmp_path = tmp.name\n    try:\n        save_weights(wrapper_a, tmp_path)\n        reloaded = load_weights(tmp_path)\n        reloaded_value = reloaded.weights.policy_heads.ip_pol_w.value\n        assert np.isclose(reloaded_value[0], 20.0, rtol=1e-4), (\n            f\"After save/load: Expected 20.0, got {reloaded_value[0]}\"\n        )\n    finally:\n        os.unlink(tmp_path)\n\n\ndef test_policy_head_map_assignment() -> None:\n    \"\"\"Test assigning policy_heads when only one side has a policy_head_map.\"\"\"\n    # Create network A with NO policy_head_map.\n    net_a = net_pb2.Net()\n    net_a.format.weights_encoding = net_pb2.Format.LINEAR16\n    net_a.weights.policy_heads.ip_pol_w.min_val = 5.0\n    net_a.weights.policy_heads.ip_pol_w.max_val = 5.0\n    
net_a.weights.policy_heads.ip_pol_w.params = np.array(\n        [32767], dtype=np.uint16\n    ).tobytes()\n\n    # Create network B WITH policy_head_map.\n    net_b = net_pb2.Net()\n    net_b.format.weights_encoding = net_pb2.Format.LINEAR16\n    # Add a policy_head_map entry with required key and value fields.\n    policy_map = net_b.weights.policy_heads.policy_head_map.add()\n    policy_map.key = \"test_policy\"\n    policy_map.value.ip_pol_w.min_val = 15.0\n    policy_map.value.ip_pol_w.max_val = 15.0\n    policy_map.value.ip_pol_w.params = np.array(\n        [32767], dtype=np.uint16\n    ).tobytes()\n\n    # Wrap networks.\n    wrapper_a = NetWrapper(net_a)\n    wrapper_b = NetWrapper(net_b)\n\n    # Replace A's policy heads with B's (which has policy_head_map).\n    wrapper_a.weights.policy_heads = wrapper_b.weights.policy_heads\n\n    # Access policy_head_map through the cache to verify consistency.\n    assert len(wrapper_a.weights.policy_heads.policy_head_map) == 1\n    assert (\n        wrapper_a.weights.policy_heads.policy_head_map[0]._proto.key\n        == \"test_policy\"\n    )\n\n    # Verify can save without serialization errors.\n    with tempfile.NamedTemporaryFile(suffix=\".pb.gz\", delete=False) as tmp:\n        tmp_path = tmp.name\n    try:\n        save_weights(wrapper_a, tmp_path)\n        # Verify round-trip preserves the policy_head_map.\n        reloaded = load_weights(tmp_path)\n        assert len(reloaded.weights.policy_heads.policy_head_map) == 1\n        assert (\n            reloaded.weights.policy_heads.policy_head_map[0]._proto.key\n            == \"test_policy\"\n        )\n        policy_value = reloaded.weights.policy_heads.policy_head_map[\n            0\n        ].value.ip_pol_w.value\n        assert np.isclose(policy_value[0], 15.0, rtol=1e-4)\n    finally:\n        os.unlink(tmp_path)\n\n\ndef test_noop_arithmetic() -> None:\n    \"\"\"Test that 0.3*A + 0.7*A equals A (no-op operation).\"\"\"\n    # Create network A with simple 
weights (value 10.0).\n    net_a = net_pb2.Net()\n    net_a.format.weights_encoding = net_pb2.Format.LINEAR16\n    net_a.weights.ip1_val_w.min_val = 10.0\n    net_a.weights.ip1_val_w.max_val = 10.0\n    net_a.weights.ip1_val_w.params = np.array(\n        [32767], dtype=np.uint16\n    ).tobytes()\n\n    wrapper_a = NetWrapper(net_a)\n\n    # Perform no-op operation: 0.3*A + 0.7*A should equal A.\n    result = 0.3 * wrapper_a + 0.7 * wrapper_a\n\n    # Verify in-memory: result should equal A (10.0).\n    result_value = result.weights.ip1_val_w.value\n    expected = 10.0\n    assert np.isclose(result_value[0], expected, rtol=1e-4), (\n        f\"Expected {expected}, got {result_value[0]}\"\n    )\n\n    # Verify persistence: save and reload.\n    with tempfile.NamedTemporaryFile(suffix=\".pb.gz\", delete=False) as tmp:\n        tmp_path = tmp.name\n    try:\n        save_weights(result, tmp_path)\n        reloaded = load_weights(tmp_path)\n        reloaded_value = reloaded.weights.ip1_val_w.value\n        assert np.isclose(reloaded_value[0], expected, rtol=1e-4), (\n            f\"After save/load: Expected {expected}, got {reloaded_value[0]}\"\n        )\n    finally:\n        os.unlink(tmp_path)\n\n\ndef test_list_item_assignment() -> None:\n    \"\"\"Test that list item assignment works for encoder blocks.\"\"\"\n    # Create network A with encoder block containing value 10.0.\n    net_a = net_pb2.Net()\n    net_a.format.weights_encoding = net_pb2.Format.LINEAR16\n    encoder_a = net_a.weights.encoder.add()\n    encoder_a.mha.q_w.min_val = 10.0\n    encoder_a.mha.q_w.max_val = 10.0\n    encoder_a.mha.q_w.params = np.array([32767], dtype=np.uint16).tobytes()\n\n    # Create network B with encoder block containing value 20.0.\n    net_b = net_pb2.Net()\n    net_b.format.weights_encoding = net_pb2.Format.LINEAR16\n    encoder_b = net_b.weights.encoder.add()\n    encoder_b.mha.q_w.min_val = 20.0\n    encoder_b.mha.q_w.max_val = 20.0\n    encoder_b.mha.q_w.params = 
np.array([32767], dtype=np.uint16).tobytes()\n\n    # Wrap networks.\n    wrapper_a = NetWrapper(net_a)\n    wrapper_b = NetWrapper(net_b)\n\n    # Assign encoder[0] from B to A.\n    wrapper_a.weights.encoder[0] = wrapper_b.weights.encoder[0]\n\n    # Verify in-memory: A's encoder[0] should now have value 20.0.\n    result_value = wrapper_a.weights.encoder[0].mha.q_w.value\n    assert np.isclose(result_value[0], 20.0, rtol=1e-4), (\n        f\"Expected 20.0 (B's value), got {result_value[0]}\"\n    )\n\n    # Verify persistence.\n    with tempfile.NamedTemporaryFile(suffix=\".pb.gz\", delete=False) as tmp:\n        tmp_path = tmp.name\n    try:\n        save_weights(wrapper_a, tmp_path)\n        reloaded = load_weights(tmp_path)\n        reloaded_value = reloaded.weights.encoder[0].mha.q_w.value\n        assert np.isclose(reloaded_value[0], 20.0, rtol=1e-4), (\n            f\"After save/load: Expected 20.0, got {reloaded_value[0]}\"\n        )\n    finally:\n        os.unlink(tmp_path)\n"
  },
  {
    "path": "src/lczero_training/tools/__init__.py",
    "content": "\"\"\"Pure Python tools for weight manipulation.\"\"\"\n\nfrom .weights_tool import load_weights, save_weights\n\n__all__ = [\"load_weights\", \"save_weights\"]\n"
  },
  {
    "path": "src/lczero_training/tools/weight_codecs.py",
    "content": "\"\"\"Encoding and decoding logic for Lc0 weight formats.\"\"\"\n\nimport numpy as np\n\nfrom proto import net_pb2\n\n\ndef decode_linear16(\n    params: bytes, min_val: float, max_val: float, shape: tuple[int, ...]\n) -> np.ndarray:\n    \"\"\"Decode LINEAR16 format to float32 array.\"\"\"\n    raw = np.frombuffer(params, dtype=np.uint16)\n    norm = raw.astype(np.float32) / 65535.0\n    values = norm * max_val + (1.0 - norm) * min_val\n    return values.reshape(shape[::-1]).transpose()\n\n\ndef encode_linear16(arr: np.ndarray) -> tuple[bytes, float, float]:\n    \"\"\"Encode float32 array to LINEAR16 format.\"\"\"\n    flat = arr.T.flatten().astype(np.float32)\n    min_val = float(flat.min())\n    max_val = float(flat.max())\n    rng = max_val - min_val\n\n    if rng < 1e-8:\n        norm = np.full_like(flat, 0.5)\n    else:\n        norm = (flat - min_val) / rng\n\n    quant = np.round(norm * 65535.0).astype(np.uint16)\n    return quant.tobytes(), min_val, max_val\n\n\ndef decode_float16(params: bytes, shape: tuple[int, ...]) -> np.ndarray:\n    \"\"\"Decode FLOAT16 format to float32 array.\"\"\"\n    raw = np.frombuffer(params, dtype=np.float16)\n    values = raw.astype(np.float32)\n    return values.reshape(shape[::-1]).transpose()\n\n\ndef encode_float16(arr: np.ndarray) -> tuple[bytes, float, float]:\n    \"\"\"Encode float32 array to FLOAT16 format.\"\"\"\n    flat = arr.T.flatten().astype(np.float16)\n    return flat.tobytes(), 0.0, 0.0\n\n\ndef decode_bfloat16(params: bytes, shape: tuple[int, ...]) -> np.ndarray:\n    \"\"\"Decode BFLOAT16 format to float32 array via bit manipulation.\"\"\"\n    raw_u16 = np.frombuffer(params, dtype=np.uint16)\n    raw_u32 = raw_u16.astype(np.uint32) << 16\n    values = raw_u32.view(np.float32)\n    return values.reshape(shape[::-1]).transpose()\n\n\ndef encode_bfloat16(arr: np.ndarray) -> tuple[bytes, float, float]:\n    \"\"\"Encode float32 array to BFLOAT16 format via bit manipulation.\"\"\"\n    flat = 
arr.T.flatten().astype(np.float32)\n    u32 = flat.view(np.uint32)\n    u16 = (u32 >> 16).astype(np.uint16)\n    return u16.tobytes(), 0.0, 0.0\n\n\ndef decode_layer(\n    layer: net_pb2.Weights.Layer, fallback_encoding: int\n) -> np.ndarray:\n    \"\"\"Decode a Layer protobuf to float32 NumPy array.\"\"\"\n    encoding = layer.encoding if layer.encoding else fallback_encoding\n\n    if not layer.dims:\n        size = len(layer.params) // 2\n        shape: tuple[int, ...] = (size,)\n    else:\n        shape = tuple(layer.dims)\n\n    if encoding == net_pb2.Weights.Layer.LINEAR16:\n        return decode_linear16(\n            layer.params, layer.min_val, layer.max_val, shape\n        )\n    elif encoding == net_pb2.Weights.Layer.FLOAT16:\n        return decode_float16(layer.params, shape)\n    elif encoding == net_pb2.Weights.Layer.BFLOAT16:\n        return decode_bfloat16(layer.params, shape)\n    else:\n        raise ValueError(f\"Unknown encoding: {encoding}\")\n\n\ndef encode_layer(arr: np.ndarray, encoding: int) -> tuple[bytes, float, float]:\n    \"\"\"Encode a float32 NumPy array to Layer format.\"\"\"\n    if encoding == net_pb2.Weights.Layer.LINEAR16:\n        return encode_linear16(arr)\n    elif encoding == net_pb2.Weights.Layer.FLOAT16:\n        return encode_float16(arr)\n    elif encoding == net_pb2.Weights.Layer.BFLOAT16:\n        return encode_bfloat16(arr)\n    else:\n        raise ValueError(f\"Unknown encoding: {encoding}\")\n"
  },
  {
    "path": "src/lczero_training/tools/weight_wrappers.py",
    "content": "\"\"\"Wrapper classes for pythonic access to Lc0 weight protobufs.\"\"\"\n\nimport gzip\nfrom typing import Any, Iterator\n\nimport numpy as np\nfrom google.protobuf.message import Message\n\nfrom proto import net_pb2\n\nfrom . import weight_codecs\n\n\nclass LayerWrapper:\n    \"\"\"Wraps a net_pb2.Weights.Layer with lazy float32 decoding.\"\"\"\n\n    __slots__ = (\"_proto\", \"_fallback_encoding\", \"_cached_array\", \"_modified\")\n\n    def __init__(\n        self, proto: net_pb2.Weights.Layer, fallback_encoding: int\n    ) -> None:\n        object.__setattr__(self, \"_proto\", proto)\n        object.__setattr__(self, \"_fallback_encoding\", fallback_encoding)\n        object.__setattr__(self, \"_cached_array\", None)\n        object.__setattr__(self, \"_modified\", False)\n\n    _proto: net_pb2.Weights.Layer\n    _fallback_encoding: int\n    _cached_array: np.ndarray | None\n    _modified: bool\n\n    @property\n    def value(self) -> np.ndarray:\n        \"\"\"Decode to float32 on first access.\"\"\"\n        if self._cached_array is None:\n            decoded = weight_codecs.decode_layer(\n                self._proto, self._fallback_encoding\n            )\n            object.__setattr__(self, \"_cached_array\", decoded)\n        assert self._cached_array is not None\n        return self._cached_array\n\n    @value.setter\n    def value(self, arr: np.ndarray) -> None:\n        \"\"\"Set new array value, mark as modified.\"\"\"\n        object.__setattr__(self, \"_cached_array\", arr.astype(np.float32))\n        object.__setattr__(self, \"_modified\", True)\n\n    def commit(self, encoding: int) -> None:\n        \"\"\"Re-encode array to proto if modified.\"\"\"\n        if self._modified and self._cached_array is not None:\n            params, min_val, max_val = weight_codecs.encode_layer(\n                self._cached_array, encoding\n            )\n            self._proto.params = params\n            self._proto.min_val = min_val\n         
   self._proto.max_val = max_val\n            self._proto.encoding = encoding  # type: ignore[assignment]\n            del self._proto.dims[:]\n            self._proto.dims.extend(self._cached_array.shape)\n            object.__setattr__(self, \"_modified\", False)\n\n    def __add__(self, other: \"LayerWrapper\") -> \"LayerWrapper\":\n        if not isinstance(other, LayerWrapper):\n            raise TypeError(\"Can only add LayerWrapper to LayerWrapper\")\n        result = LayerWrapper(net_pb2.Weights.Layer(), self._fallback_encoding)\n        result.value = self.value + other.value\n        return result\n\n    def __sub__(self, other: \"LayerWrapper\") -> \"LayerWrapper\":\n        if not isinstance(other, LayerWrapper):\n            raise TypeError(\"Can only subtract LayerWrapper from LayerWrapper\")\n        result = LayerWrapper(net_pb2.Weights.Layer(), self._fallback_encoding)\n        result.value = self.value - other.value\n        return result\n\n    def __mul__(self, scalar: float) -> \"LayerWrapper\":\n        if not isinstance(scalar, (int, float)):\n            raise TypeError(\"Can only multiply LayerWrapper by scalar\")\n        result = LayerWrapper(net_pb2.Weights.Layer(), self._fallback_encoding)\n        result.value = self.value * scalar\n        return result\n\n    def __rmul__(self, scalar: float) -> \"LayerWrapper\":\n        return self.__mul__(scalar)\n\n\nclass ListWrapper:\n    \"\"\"Wraps protobuf repeated fields.\"\"\"\n\n    __slots__ = (\"_proto_list\", \"_parent\", \"_item_cache\")\n\n    def __init__(self, proto_list: Any, parent: \"NetWrapper\") -> None:\n        self._proto_list = proto_list\n        self._parent = parent\n        self._item_cache: dict[int, Any] = {}\n\n    _proto_list: Any\n    _parent: \"NetWrapper\"\n    _item_cache: dict[int, Any]\n\n    def __len__(self) -> int:\n        return len(self._proto_list)\n\n    def __getitem__(self, idx: int) -> Any:\n        if idx not in self._item_cache:\n            
item_proto = self._proto_list[idx]\n            self._item_cache[idx] = self._parent._wrap_field(item_proto)\n        return self._item_cache[idx]\n\n    def __setitem__(self, idx: int, value: Any) -> None:\n        \"\"\"Support item assignment for list elements.\"\"\"\n        if isinstance(value, NetWrapper):\n            dest_proto = self._proto_list[idx]\n            dest_proto.CopyFrom(value._proto)\n            # Create new wrapper for destination proto to maintain cache consistency.\n            self._item_cache[idx] = NetWrapper(\n                dest_proto, self._parent._fallback_encoding\n            )\n        elif isinstance(value, LayerWrapper):\n            dest_proto = self._proto_list[idx]\n            dest_proto.CopyFrom(value._proto)\n            # Create new wrapper for destination proto to maintain cache consistency.\n            self._item_cache[idx] = LayerWrapper(\n                dest_proto, self._parent._fallback_encoding\n            )\n        else:\n            self._proto_list[idx] = value\n            if idx in self._item_cache:\n                del self._item_cache[idx]\n\n    def __iter__(self) -> Iterator[Any]:\n        for i in range(len(self)):\n            yield self[i]\n\n\nclass NetWrapper:\n    \"\"\"Wraps net_pb2.Net or nested Message types.\"\"\"\n\n    __slots__ = (\"_proto\", \"_fallback_encoding\", \"_attr_cache\")\n\n    def __init__(\n        self, proto_msg: Message, fallback_encoding: int | None = None\n    ) -> None:\n        object.__setattr__(self, \"_proto\", proto_msg)\n        object.__setattr__(\n            self,\n            \"_fallback_encoding\",\n            fallback_encoding or self._detect_encoding(),\n        )\n        object.__setattr__(self, \"_attr_cache\", {})\n\n    _proto: Message\n    _fallback_encoding: int\n    _attr_cache: dict[str, Any]\n\n    def _detect_encoding(self) -> int:\n        \"\"\"Extract encoding from net.format.weights_encoding.\"\"\"\n        if hasattr(self._proto, 
\"format\") and self._proto.HasField(\"format\"):\n            return self._proto.format.weights_encoding\n        return net_pb2.Weights.Layer.LINEAR16\n\n    def __getattr__(self, name: str) -> Any:\n        if name.startswith(\"_\"):\n            return object.__getattribute__(self, name)\n\n        if name in self._attr_cache:\n            return self._attr_cache[name]\n\n        if not hasattr(self._proto, name):\n            raise AttributeError(\n                f\"{type(self._proto).__name__} has no field '{name}'\"\n            )\n\n        value = getattr(self._proto, name)\n        wrapped = self._wrap_field(value)\n        self._attr_cache[name] = wrapped\n        return wrapped\n\n    def _wrap_field(self, value: Any) -> Any:\n        \"\"\"Determine wrapper type based on proto field.\"\"\"\n        if isinstance(value, net_pb2.Weights.Layer):\n            return LayerWrapper(value, self._fallback_encoding)\n        elif isinstance(value, Message):\n            return NetWrapper(value, self._fallback_encoding)\n        elif hasattr(value, \"__len__\") and not isinstance(value, (str, bytes)):\n            return ListWrapper(value, self)\n        else:\n            return value\n\n    def __setattr__(self, name: str, value: Any) -> None:\n        if name.startswith(\"_\"):\n            object.__setattr__(self, name, value)\n            return\n\n        if isinstance(value, NetWrapper):\n            dest_proto = getattr(self._proto, name)\n            dest_proto.CopyFrom(value._proto)\n            # Create new wrapper for destination proto to maintain cache consistency.\n            self._attr_cache[name] = NetWrapper(\n                dest_proto, self._fallback_encoding\n            )\n        elif isinstance(value, LayerWrapper):\n            dest_proto = getattr(self._proto, name)\n            dest_proto.CopyFrom(value._proto)\n            # Create new wrapper for destination proto to maintain cache consistency.\n            self._attr_cache[name] = 
LayerWrapper(\n                dest_proto, self._fallback_encoding\n            )\n        else:\n            setattr(self._proto, name, value)\n            # Invalidate any stale cached value for this field.\n            self._attr_cache.pop(name, None)\n\n    def save(self, path: str, encoding: int | None = None) -> None:\n        \"\"\"Save to file, committing all modified layers.\"\"\"\n        if encoding is None:\n            encoding = net_pb2.Weights.Layer.FLOAT16\n\n        self._commit_all(encoding)\n\n        serialized = self._proto.SerializePartialToString()\n        if path.endswith(\".gz\"):\n            with gzip.open(path, \"wb\") as f:\n                f.write(serialized)\n        else:\n            with open(path, \"wb\") as f:\n                f.write(serialized)\n\n    def _commit_all(self, encoding: int) -> None:\n        \"\"\"Recursively commit all modified LayerWrappers.\"\"\"\n        for cached_value in self._attr_cache.values():\n            if isinstance(cached_value, LayerWrapper):\n                cached_value.commit(encoding)\n            elif isinstance(cached_value, NetWrapper):\n                cached_value._commit_all(encoding)\n            elif isinstance(cached_value, ListWrapper):\n                for item in cached_value:\n                    if isinstance(item, NetWrapper):\n                        item._commit_all(encoding)\n                    elif isinstance(item, LayerWrapper):\n                        item.commit(encoding)\n\n    def __add__(self, other: \"NetWrapper\") -> \"NetWrapper\":\n        \"\"\"Element-wise addition of two networks.\"\"\"\n        if not isinstance(other, NetWrapper):\n            raise TypeError(\"Can only add NetWrapper to NetWrapper\")\n\n        result_proto = type(self._proto)()\n        result_proto.CopyFrom(self._proto)\n        result = NetWrapper(result_proto, self._fallback_encoding)\n        result._add_weights(self, other)\n        return result\n\n    def __sub__(self, other: \"NetWrapper\") -> \"NetWrapper\":\n        \"\"\"Element-wise subtraction of two networks.\"\"\"\n        if not 
isinstance(other, NetWrapper):\n            raise TypeError(\"Can only subtract NetWrapper from NetWrapper\")\n\n        result_proto = type(self._proto)()\n        result_proto.CopyFrom(self._proto)\n        result = NetWrapper(result_proto, self._fallback_encoding)\n        result._sub_weights(self, other)\n        return result\n\n    def __mul__(self, scalar: float) -> \"NetWrapper\":\n        \"\"\"Scalar multiplication.\"\"\"\n        if not isinstance(scalar, (int, float)):\n            raise TypeError(\"Can only multiply NetWrapper by scalar\")\n\n        result_proto = type(self._proto)()\n        result_proto.CopyFrom(self._proto)\n        result = NetWrapper(result_proto, self._fallback_encoding)\n        result._mul_weights(self, scalar)\n        return result\n\n    def __rmul__(self, scalar: float) -> \"NetWrapper\":\n        return self.__mul__(scalar)\n\n    def _add_weights(self, lhs: \"NetWrapper\", rhs: \"NetWrapper\") -> None:\n        \"\"\"Recursively add weights from lhs and rhs into self.\"\"\"\n        for field_desc in lhs._proto.DESCRIPTOR.fields:\n            field_name = field_desc.name\n            if not hasattr(lhs._proto, field_name):\n                continue\n\n            # Skip optional fields that are not set in both inputs.\n            if not field_desc.is_required and not field_desc.is_repeated:\n                lhs_has = lhs._proto.HasField(field_name)\n                rhs_has = rhs._proto.HasField(field_name)\n                if not lhs_has or not rhs_has:\n                    continue\n\n            lhs_val = getattr(lhs, field_name)\n            rhs_val = getattr(rhs, field_name)\n\n            if isinstance(lhs_val, LayerWrapper):\n                self_layer = getattr(self, field_name)\n                self_layer.value = lhs_val.value + rhs_val.value\n            elif isinstance(lhs_val, NetWrapper):\n                self_wrapper = getattr(self, field_name)\n                self_wrapper._add_weights(lhs_val, rhs_val)\n  
          elif isinstance(lhs_val, ListWrapper):\n                self_list = getattr(self, field_name)\n                # Only process indices that exist in both lists.\n                min_len = min(len(lhs_val), len(rhs_val))\n                for i in range(min_len):\n                    if isinstance(lhs_val[i], (NetWrapper, LayerWrapper)):\n                        if isinstance(lhs_val[i], LayerWrapper):\n                            self_list[i].value = (\n                                lhs_val[i].value + rhs_val[i].value\n                            )\n                        else:\n                            self_list[i]._add_weights(lhs_val[i], rhs_val[i])\n                # Truncate to min_len to avoid incomplete entries.\n                del self_list._proto_list[min_len:]\n\n    def _sub_weights(self, lhs: \"NetWrapper\", rhs: \"NetWrapper\") -> None:\n        \"\"\"Recursively subtract weights rhs from lhs into self.\"\"\"\n        for field_desc in lhs._proto.DESCRIPTOR.fields:\n            field_name = field_desc.name\n            if not hasattr(lhs._proto, field_name):\n                continue\n\n            # Skip optional fields that are not set in both inputs.\n            if not field_desc.is_required and not field_desc.is_repeated:\n                lhs_has = lhs._proto.HasField(field_name)\n                rhs_has = rhs._proto.HasField(field_name)\n                if not lhs_has or not rhs_has:\n                    continue\n\n            lhs_val = getattr(lhs, field_name)\n            rhs_val = getattr(rhs, field_name)\n\n            if isinstance(lhs_val, LayerWrapper):\n                self_layer = getattr(self, field_name)\n                self_layer.value = lhs_val.value - rhs_val.value\n            elif isinstance(lhs_val, NetWrapper):\n                self_wrapper = getattr(self, field_name)\n                self_wrapper._sub_weights(lhs_val, rhs_val)\n            elif isinstance(lhs_val, ListWrapper):\n                self_list = 
getattr(self, field_name)\n                # Only process indices that exist in both lists.\n                min_len = min(len(lhs_val), len(rhs_val))\n                for i in range(min_len):\n                    if isinstance(lhs_val[i], (NetWrapper, LayerWrapper)):\n                        if isinstance(lhs_val[i], LayerWrapper):\n                            self_list[i].value = (\n                                lhs_val[i].value - rhs_val[i].value\n                            )\n                        else:\n                            self_list[i]._sub_weights(lhs_val[i], rhs_val[i])\n                # Truncate to min_len to avoid incomplete entries.\n                del self_list._proto_list[min_len:]\n\n    def _mul_weights(self, source: \"NetWrapper\", scalar: float) -> None:\n        \"\"\"Recursively multiply all weights by scalar.\"\"\"\n        for field_desc in source._proto.DESCRIPTOR.fields:\n            field_name = field_desc.name\n            if not hasattr(source._proto, field_name):\n                continue\n\n            # Skip unset optional fields to avoid creating them in output.\n            if (\n                not field_desc.is_required\n                and not field_desc.is_repeated\n                and not source._proto.HasField(field_name)\n            ):\n                continue\n\n            src_val = getattr(source, field_name)\n\n            if isinstance(src_val, LayerWrapper):\n                self_layer = getattr(self, field_name)\n                self_layer.value = src_val.value * scalar\n            elif isinstance(src_val, NetWrapper):\n                self_wrapper = getattr(self, field_name)\n                self_wrapper._mul_weights(src_val, scalar)\n            elif isinstance(src_val, ListWrapper):\n                self_list = getattr(self, field_name)\n                for i in range(len(src_val)):\n                    if isinstance(src_val[i], (NetWrapper, LayerWrapper)):\n                        if 
isinstance(src_val[i], LayerWrapper):\n                            self_list[i].value = src_val[i].value * scalar\n                        else:\n                            self_list[i]._mul_weights(src_val[i], scalar)\n"
  },
  {
    "path": "src/lczero_training/tools/weights_tool.py",
    "content": "\"\"\"Main API for loading and saving Lc0 weight files.\"\"\"\n\nimport gzip\n\nfrom proto import net_pb2\n\nfrom .weight_wrappers import NetWrapper\n\n\ndef load_weights(path: str) -> NetWrapper:\n    \"\"\"Load Lc0 weights file (.pb or .pb.gz).\"\"\"\n    if path.endswith(\".gz\"):\n        with gzip.open(path, \"rb\") as f:\n            contents = f.read()\n    else:\n        with open(path, \"rb\") as f:\n            contents = f.read()\n\n    net = net_pb2.Net()\n    net.ParseFromString(contents)\n    return NetWrapper(net)\n\n\ndef save_weights(\n    wrapper: NetWrapper, path: str, encoding: str = \"FLOAT16\"\n) -> None:\n    \"\"\"Save weights to file.\"\"\"\n    encoding_map = {\n        \"LINEAR16\": net_pb2.Weights.Layer.LINEAR16,\n        \"FLOAT16\": net_pb2.Weights.Layer.FLOAT16,\n        \"BFLOAT16\": net_pb2.Weights.Layer.BFLOAT16,\n    }\n    enc_value = encoding_map[encoding.upper()]\n    wrapper.save(path, enc_value)\n"
  },
  {
    "path": "src/lczero_training/training/__init__.py",
    "content": "\"\"\"Training package for Leela Chess Zero.\"\"\"\n"
  },
  {
    "path": "src/lczero_training/training/backfill_metrics.py",
    "content": "\"\"\"Backfill metrics for existing checkpoints.\"\"\"\n\nimport logging\nfrom typing import Any\n\nfrom flax import nnx\nfrom google.protobuf import text_format\n\nfrom lczero_training.daemon.metrics import (\n    evaluate_batch,\n    load_batch_from_npz,\n)\nfrom lczero_training.model.loss_function import LczeroLoss\nfrom lczero_training.model.model import LczeroModel\nfrom lczero_training.training.migrate_checkpoint import (\n    Migration,\n    get_checkpoint_steps,\n    load_checkpoint,\n    load_migration_rules,\n)\nfrom lczero_training.training.state import TrainingState\nfrom lczero_training.training.tensorboard import TensorboardLogger\nfrom proto.root_config_pb2 import RootConfig\n\nlogger = logging.getLogger(__name__)\n\n\ndef _load_config(config_path: str) -> RootConfig:\n    \"\"\"Load RootConfig from textproto file.\"\"\"\n    config = RootConfig()\n    with open(config_path, \"r\") as f:\n        text_format.Parse(f.read(), config)\n    return config\n\n\ndef _validate_and_get_metrics(\n    root_config: RootConfig, metric_names: list[str]\n) -> dict[str, Any]:\n    \"\"\"Validate metrics exist and are NPZ type.\"\"\"\n    if not root_config.HasField(\"metrics\"):\n        raise ValueError(\"No metrics configuration found in root config\")\n\n    metric_configs = {\n        mc.name: mc\n        for mc in root_config.metrics.metric\n        if mc.name in metric_names and mc.HasField(\"npz_filename\")\n    }\n\n    non_npz = [\n        mc.name\n        for mc in root_config.metrics.metric\n        if mc.name in metric_names and not mc.HasField(\"npz_filename\")\n    ]\n    if non_npz:\n        raise ValueError(f\"Non-NPZ metrics: {', '.join(non_npz)}\")\n\n    missing = set(metric_names) - set(metric_configs.keys())\n    if missing:\n        raise ValueError(f\"Metrics not found: {', '.join(sorted(missing))}\")\n\n    return metric_configs\n\n\ndef _load_and_migrate_checkpoint(\n    checkpoint_path: str,\n    step: int,\n    template: 
TrainingState,\n    rules: list[tuple],\n) -> TrainingState:\n    \"\"\"Load checkpoint and apply migration if rules provided.\"\"\"\n    state, _ = load_checkpoint(checkpoint_path, step)\n    return Migration(state, template).run(rules)\n\n\ndef backfill_metrics(\n    config_path: str,\n    metric_names: list[str],\n    min_step: int | None = None,\n    max_step: int | None = None,\n    migration_config_path: str | None = None,\n) -> None:\n    \"\"\"Backfill metrics for existing checkpoints.\n\n    Args:\n        config_path: Path to the RootConfig textproto file.\n        metric_names: Names of metrics to backfill (must be NPZ metrics).\n        min_step: Minimum checkpoint step (inclusive), or None.\n        max_step: Maximum checkpoint step (inclusive), or None.\n        migration_config_path: Path to CheckpointMigrationConfig file, or None.\n\n    Raises:\n        ValueError: If any metric is not an NPZ metric or doesn't exist.\n    \"\"\"\n    config = _load_config(config_path)\n    metric_configs = _validate_and_get_metrics(config, metric_names)\n\n    # Load batches and create loggers.\n    batches = {\n        name: load_batch_from_npz(mc.npz_filename)\n        for name, mc in metric_configs.items()\n    }\n    loggers = {\n        name: TensorboardLogger(f\"{config.metrics.tensorboard_path}/{name}\")\n        for name in metric_configs\n    }\n\n    # Initialize model components.\n    loss_fn = LczeroLoss(config=config.training.losses)\n    graphdef, _ = nnx.split(\n        LczeroModel(config=config.model, rngs=nnx.Rngs(params=42))\n    )\n\n    # Prepare migration if needed.\n    rules = load_migration_rules(migration_config_path)\n    template = TrainingState.new_from_config(config.model, config.training)\n\n    # Get and process checkpoints.\n    steps = get_checkpoint_steps(\n        config.training.checkpoint.path, min_step, max_step\n    )\n    logger.info(f\"Processing {len(steps)} checkpoints\")\n\n    if not steps:\n        logger.warning(\"No 
checkpoints found in range\")\n        return\n\n    try:\n        for step in steps:\n            logger.info(f\"Step {step}\")\n            state = _load_and_migrate_checkpoint(\n                config.training.checkpoint.path, step, template, rules\n            )\n\n            for name, mc in metric_configs.items():\n                metrics = evaluate_batch(\n                    batches[name],\n                    state.jit_state,\n                    graphdef,\n                    loss_fn,\n                    mc.use_swa_model,\n                )\n                loggers[name].log(step, metrics)\n                logger.info(f\"  {name}: loss={metrics['loss']:.6f}\")\n    finally:\n        for tb_logger in loggers.values():\n            tb_logger.close()\n\n    logger.info(\"Backfill complete\")\n"
  },
  {
    "path": "src/lczero_training/training/dataloader_probe.py",
    "content": "\"\"\"Utilities for exercising the training data loader.\"\"\"\n\nimport logging\nimport time\nfrom contextlib import suppress\nfrom pathlib import Path\nfrom typing import Optional\n\nimport numpy as np\nfrom google.protobuf import text_format\n\nfrom lczero_training.dataloader import DataLoader, make_dataloader\nfrom proto.root_config_pb2 import RootConfig\n\nlogger = logging.getLogger(__name__)\n\n\ndef _stop_loader(loader: DataLoader) -> None:\n    with suppress(Exception):\n        loader.stop()\n\n\ndef _store_batches(path: str, batches: list) -> None:\n    output = Path(path)\n    if output.parent:\n        output.parent.mkdir(parents=True, exist_ok=True)\n    logger.info(\"Writing %d batches to %s\", len(batches), output)\n    container = np.empty(len(batches), dtype=object)\n    container[:] = batches\n    np.savez(output, batches=container)\n\n\ndef probe_dataloader(\n    config_filename: str, num_batches: int, npz_output: Optional[str] = None\n) -> None:\n    \"\"\"Measure latency and throughput for the configured data loader.\n\n    Args:\n        config_filename: Path to the root configuration proto file.\n        num_batches: Total number of batches to fetch from the loader.\n        npz_output: Optional path to store fetched batches as an .npz archive.\n    \"\"\"\n\n    if num_batches < 1:\n        raise ValueError(\"num_batches must be at least 1\")\n\n    config = RootConfig()\n    logger.info(\"Reading configuration from proto file\")\n    with open(config_filename, \"r\") as config_file:\n        text_format.Parse(config_file.read(), config)\n\n    logger.info(\"Creating data loader\")\n    loader = make_dataloader(config.data_loader)\n\n    collected_batches: list = []\n    collect_enabled = npz_output is not None\n    first_batch_time = 0.0\n    remaining_batches = num_batches - 1\n    try:\n        logger.info(\"Fetching first batch\")\n        start_time = time.perf_counter()\n        first_batch = loader.get_next()\n        
if collect_enabled:\n            collected_batches.append(first_batch)\n        first_batch_time = time.perf_counter() - start_time\n        logger.info(\"Time to first batch: %.3f seconds\", first_batch_time)\n\n        if remaining_batches <= 0:\n            logger.info(\"Only fetched first batch; skipping throughput\")\n            return\n\n        logger.info(\n            \"Fetching %d additional batches for throughput measurement\",\n            remaining_batches,\n        )\n        throughput_start = time.perf_counter()\n        for _ in range(remaining_batches):\n            batch = loader.get_next()\n            if collect_enabled:\n                collected_batches.append(batch)\n        throughput_duration = time.perf_counter() - throughput_start\n\n        if throughput_duration <= 0:\n            logger.warning(\"Measured non-positive duration; skipping rate\")\n            return\n\n        batches_per_second = remaining_batches / throughput_duration\n        logger.info(\n            \"Throughput excluding first batch: %.2f batches/second\",\n            batches_per_second,\n        )\n        logger.info(\n            \"Total time excluding first batch: %.3f seconds\",\n            throughput_duration,\n        )\n    finally:\n        if collect_enabled and npz_output is not None and collected_batches:\n            _store_batches(npz_output, collected_batches)\n        _stop_loader(loader)\n"
  },
  {
    "path": "src/lczero_training/training/describe.py",
    "content": "import logging\nimport sys\nfrom pathlib import PurePosixPath\n\nimport jax\nimport jax.numpy as jnp\nimport orbax.checkpoint as ocp\nfrom flax import nnx\nfrom google.protobuf import text_format\n\nfrom lczero_training.training.state import TrainingState\nfrom proto.root_config_pb2 import RootConfig\n\nlogger = logging.getLogger(__name__)\n\n\ndef describe(\n    config_filename: str,\n    shapes: bool = False,\n    values: bool = False,\n    weight_paths: bool = False,\n) -> None:\n    config = RootConfig()\n    logger.info(\"Reading configuration from proto file\")\n    with open(config_filename, \"r\") as f:\n        text_format.Parse(f.read(), config)\n\n    if config.training.checkpoint.path is None:\n        logger.error(\"Checkpoint path must be set in the configuration.\")\n        sys.exit(1)\n\n    checkpoint_mgr = ocp.CheckpointManager(\n        config.training.checkpoint.path,\n        options=ocp.CheckpointManagerOptions(\n            create=True,\n        ),\n    )\n\n    logger.info(\"Creating state from configuration\")\n    empty_state = TrainingState.new_from_config(\n        model_config=config.model,\n        training_config=config.training,\n    )\n    logger.info(\"Restoring checkpoint\")\n    training_state = checkpoint_mgr.restore(\n        None, args=ocp.args.PyTreeRestore(empty_state)\n    )\n    logger.info(\"Restored checkpoint\")\n\n    assert isinstance(training_state, TrainingState)\n\n    if values:\n        logger.info(\"Dumping training state values\")\n        print(\"Training state:\")\n        print(training_state)\n\n    if shapes:\n        logger.info(\"Extracting training state shapes\")\n        shapes = jax.tree.map(jnp.shape, training_state)\n        print(\"Training state shapes:\")\n        print(shapes)\n\n    if weight_paths:\n        paths = []\n\n        def _collect(path: tuple[object, ...], _: nnx.Variable) -> bool:\n            paths.append(str(PurePosixPath(*map(str, path))))\n            return 
False\n\n        nnx.map_state(_collect, training_state.jit_state.model_state)\n        for p in sorted(\n            paths,\n            key=lambda p: tuple(\n                int(c) if c.isdigit() else c for c in p.split(\"/\")\n            ),\n        ):\n            print(p)\n"
  },
  {
    "path": "src/lczero_training/training/eval.py",
    "content": "# Description: Evaluation script for comparing model outputs and calculating losses.\n#\n# This script provides functionalities to evaluate a trained model by processing\n# data samples, calculating losses, and comparing outputs against an ONNX model.\n# It supports dumping tensors and results to various formats for analysis.\n\nimport json\nimport logging\nimport math\nimport shelve\nimport sys\nfrom dataclasses import dataclass\nfrom datetime import datetime\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Generator,\n    Optional,\n    Sequence,\n    TextIO,\n    Tuple,\n    cast,\n)\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport orbax.checkpoint as ocp\nfrom flax import nnx\nfrom google.protobuf import text_format\n\nfrom lczero_training.dataloader import (\n    DataLoader,\n    make_dataloader,\n)\nfrom lczero_training.model.loss_function import LczeroLoss\nfrom lczero_training.model.model import LczeroModel, ModelPrediction\nfrom lczero_training.training.state import TrainingSample, TrainingState\nfrom proto import data_loader_config_pb2\nfrom proto.root_config_pb2 import RootConfig\n\nlogger = logging.getLogger(__name__)\n\nRELATIVE_DIFFERENCE_EPSILON = 1e-4\nHEADS = (\"wdl\", \"policy\", \"movesleft\")\n\n\n@dataclass\nclass DiffRecord:\n    \"\"\"Stores information about the difference between JAX and ONNX outputs.\"\"\"\n\n    batch: int\n    sample: int\n    index: Tuple[int, ...]\n    diff: float\n    jax_value: float\n    onnx_value: float\n\n\ndef _tensor_to_list(obj: Any) -> Any:\n    \"\"\"Recursively converts JAX/Numpy arrays to Python lists for serialization.\"\"\"\n    if hasattr(obj, \"tolist\"):\n        return obj.tolist()\n    if isinstance(obj, dict):\n        return {key: _tensor_to_list(value) for key, value in obj.items()}\n    return obj\n\n\n# --- Diff statistics helpers ---\n\n\ndef _bin_counts(values: np.ndarray) -> Dict[str, Any]:\n    \"\"\"Counts values into bins based on powers of 
2.\"\"\"\n    flat = np.asarray(values).ravel()\n    zero_count = int(np.count_nonzero(flat == 0.0))\n    non_zero = flat[flat != 0.0]\n\n    bins: Dict[int, int] = {}\n    if non_zero.size > 0:\n        exponents = np.floor(np.log2(non_zero)).astype(int)\n        unique, counts = np.unique(exponents, return_counts=True)\n        bins = {int(exp): int(count) for exp, count in zip(unique, counts)}\n\n    return {\"zero\": zero_count, \"bins\": bins}\n\n\ndef _format_bound(value: float) -> str:\n    \"\"\"Formats a float for display in statistics.\"\"\"\n    if value >= 1 and math.isclose(value, round(value)):\n        return str(int(round(value)))\n    return f\"{value:.6g}\"\n\n\ndef _format_stats(stats: Dict[str, Any]) -> str:\n    \"\"\"Formats bin statistics into a readable string.\"\"\"\n    lines = [f\"  zero={stats['zero']}\"]\n    for exponent in sorted(stats[\"bins\"].keys(), reverse=True):\n        lower = 2**exponent\n        upper = 2 ** (exponent + 1)\n        lines.append(\n            f\"  [{_format_bound(lower)}; {_format_bound(upper)})=\"\n            f\"{stats['bins'][exponent]}\"\n        )\n    return \"\\n\".join(lines)\n\n\ndef _collect_diff_statistics(\n    jax_output: np.ndarray, onnx_output: np.ndarray\n) -> Tuple[np.ndarray, Dict[str, Any], Optional[Dict[str, Any]]]:\n    \"\"\"Collects absolute and relative difference statistics.\"\"\"\n    abs_diff = np.abs(jax_output - onnx_output)\n    abs_stats = _bin_counts(abs_diff)\n\n    mask = np.abs(jax_output) >= RELATIVE_DIFFERENCE_EPSILON\n    if not np.any(mask):\n        return abs_diff, abs_stats, None\n\n    rel_diff = np.zeros_like(abs_diff, dtype=np.float64)\n    np.divide(abs_diff, np.abs(jax_output), out=rel_diff, where=mask)\n    rel_stats = _bin_counts(rel_diff[mask])\n    return abs_diff, abs_stats, rel_stats\n\n\nclass Dumper:\n    \"\"\"Handles dumping of evaluation artifacts.\"\"\"\n\n    def __init__(\n        self,\n        to_stdout: bool,\n        to_file: Optional[str],\n    
    to_shelve: Optional[str],\n        to_json: Optional[str],\n    ):\n        self.to_stdout = to_stdout\n        self.shelve_path = to_shelve\n        self.json_path = to_json\n        self.file_handle: Optional[TextIO] = (\n            open(to_file, \"w\") if to_file else None\n        )\n\n    def dump_tensors(self, tensors: dict, prefix: str) -> None:\n        \"\"\"Dumps tensors to stdout or a text file.\"\"\"\n        if not self.to_stdout and not self.file_handle:\n            return\n\n        lines = [f\"# === {prefix} TENSORS ===\"]\n        for name, tensor in tensors.items():\n            lines.append(f\"{name} = {str(_tensor_to_list(tensor))}\")\n        lines.append(\"\")\n        output_text = \"\\n\".join(lines)\n\n        if self.to_stdout:\n            print(output_text)\n        if self.file_handle:\n            self.file_handle.write(output_text)\n            self.file_handle.flush()\n\n    def dump_structured(self, batch: dict, outputs: dict, losses: dict) -> None:\n        \"\"\"Dumps results to structured formats like JSON or shelve.\"\"\"\n        if not self.shelve_path and not self.json_path:\n            return\n\n        all_data = {\n            **_tensor_to_list(batch),\n            **_tensor_to_list(outputs),\n            **_tensor_to_list(losses),\n        }\n        key = f\"sample-{datetime.now().strftime('%Y%m%d-%H%M%S')}\"\n\n        if self.shelve_path:\n            self._dump_to_shelve(key, all_data)\n        if self.json_path:\n            self._dump_to_json(key, all_data)\n\n    def _dump_to_shelve(self, key: str, data: dict) -> None:\n        assert self.shelve_path is not None\n        with shelve.open(self.shelve_path) as db:\n            db[key] = data\n        logger.info(\"Dumped data to shelve with key: %s\", key)\n\n    def _dump_to_json(self, key: str, data: dict) -> None:\n        assert self.json_path is not None\n        try:\n            with open(self.json_path, \"r\") as f:\n                json_data = 
json.load(f)\n        except (FileNotFoundError, json.JSONDecodeError):\n            json_data = {}\n        json_data[key] = data\n        with open(self.json_path, \"w\") as f:\n            json.dump(json_data, f, indent=2)\n        logger.info(\"Dumped data to JSON with key: %s\", key)\n\n    def close(self) -> None:\n        if self.file_handle:\n            self.file_handle.close()\n\n\nclass OnnxComparator:\n    \"\"\"Handles comparison of JAX model outputs with ONNX model outputs.\"\"\"\n\n    def __init__(self, onnx_model_path: str):\n        try:\n            import onnxruntime as ort\n        except ImportError as exc:\n            raise RuntimeError(\n                \"onnxruntime is required for ONNX comparison.\"\n            ) from exc\n\n        self.session = ort.InferenceSession(onnx_model_path)\n        inputs = self.session.get_inputs()\n        if not inputs:\n            raise ValueError(\"ONNX model must define at least one input.\")\n        self.input_name = inputs[0].name\n        logger.info(\"Loaded ONNX model for comparison from %s\", onnx_model_path)\n\n        self.head_mapping_logged = False\n        self.worst_records: Dict[str, Optional[DiffRecord]] = {\n            head: None for head in HEADS\n        }\n        self.onnx_outputs: Dict[str, jax.Array] = {}\n\n    def compare(\n        self,\n        jax_outputs: Dict[str, jax.Array],\n        onnx_inputs_np: np.ndarray,\n        sample_index: int,\n    ) -> None:\n        \"\"\"Runs comparison for a single sample.\"\"\"\n        raw_onnx_outputs = self.session.run(\n            None, {self.input_name: onnx_inputs_np}\n        )\n        if len(raw_onnx_outputs) != 3:\n            raise ValueError(\n                \"Expected three outputs (wdl, policy, movesleft) from ONNX model.\"\n            )\n\n        jax_outputs_np = {k: np.asarray(v) for k, v in jax_outputs.items()}\n        aligned_onnx, head_indices = self._align_onnx_outputs(\n            jax_outputs_np, 
raw_onnx_outputs\n        )\n\n        if not self.head_mapping_logged:\n            order = \", \".join(f\"{h}=output[{head_indices[h]}]\" for h in HEADS)\n            logger.info(\"Aligned ONNX outputs to heads as: %s\", order)\n            self.head_mapping_logged = True\n\n        self.onnx_outputs = {\n            \"onnx_value_pred\": jnp.asarray(aligned_onnx[\"wdl\"]),\n            \"onnx_policy_pred\": jnp.asarray(aligned_onnx[\"policy\"]),\n            \"onnx_movesleft_pred\": jnp.asarray(aligned_onnx[\"movesleft\"]),\n        }\n\n        for head in HEADS:\n            record = self._log_diff_stats(\n                head, jax_outputs_np[head], aligned_onnx[head], sample_index\n            )\n            current = self.worst_records[head]\n            if current is None or record.diff > current.diff:\n                self.worst_records[head] = record\n\n    def log_summary(self) -> None:\n        \"\"\"Logs the worst difference found for each head.\"\"\"\n        for head in HEADS:\n            record = self.worst_records[head]\n            if record:\n                logger.info(\n                    \"Worst ONNX abs diff for %s head: \"\n                    \"batch=%d sample=%d index=%s diff=%0.6g \"\n                    \"jax=%0.6g onnx=%0.6g\",\n                    head,\n                    record.batch,\n                    record.sample,\n                    record.index,\n                    record.diff,\n                    record.jax_value,\n                    record.onnx_value,\n                )\n\n    def _log_diff_stats(\n        self,\n        head: str,\n        jax_output: np.ndarray,\n        onnx_output: np.ndarray,\n        sample_index: int,\n    ) -> DiffRecord:\n        if jax_output.shape != onnx_output.shape:\n            raise ValueError(\n                f\"Shape mismatch for {head} head: \"\n                f\"JAX {jax_output.shape} vs ONNX {onnx_output.shape}.\"\n            )\n\n        abs_diff, abs_stats, rel_stats = 
_collect_diff_statistics(\n            jax_output, onnx_output\n        )\n\n        logger.info(\n            \"Batch %d %s head ONNX abs diff stats:\\n%s\",\n            sample_index,\n            head,\n            _format_stats(abs_stats),\n        )\n        if rel_stats:\n            logger.info(\n                \"Batch %d %s head ONNX rel diff stats:\\n%s\",\n                sample_index,\n                head,\n                _format_stats(rel_stats),\n            )\n        else:\n            logger.info(\n                \"Batch %d %s head ONNX rel diff stats: skipped (all |jax| < %.1e)\",\n                sample_index,\n                head,\n                RELATIVE_DIFFERENCE_EPSILON,\n            )\n\n        max_loc = np.unravel_index(int(np.argmax(abs_diff)), abs_diff.shape)\n        return DiffRecord(\n            batch=sample_index,\n            sample=int(max_loc[0]) if max_loc else 0,\n            index=tuple(int(i) for i in max_loc),\n            diff=float(abs_diff[max_loc]),\n            jax_value=float(jax_output[max_loc]),\n            onnx_value=float(onnx_output[max_loc]),\n        )\n\n    def _align_onnx_outputs(\n        self,\n        jax_outputs: Dict[str, np.ndarray],\n        onnx_outputs: Sequence[np.ndarray],\n    ) -> Tuple[Dict[str, np.ndarray], Dict[str, int]]:\n        \"\"\"Matches ONNX outputs to heads regardless of ordering differences.\"\"\"\n        remaining = [(i, np.asarray(o)) for i, o in enumerate(onnx_outputs)]\n        aligned: Dict[str, np.ndarray] = {}\n        indices: Dict[str, int] = {}\n\n        def pop_match(\n            predicate: Callable[[np.ndarray], bool],\n        ) -> Optional[Tuple[int, np.ndarray]]:\n            for i, candidate in enumerate(remaining):\n                if predicate(candidate[1]):\n                    return remaining.pop(i)\n            return None\n\n        for head in HEADS:\n            shape = jax_outputs[head].shape\n            match = pop_match(lambda arr: arr.shape == 
shape)\n            if match:\n                idx, array = match\n                aligned[head], indices[head] = array, idx\n\n        for head in HEADS:\n            if head in aligned:\n                continue\n            size = jax_outputs[head].size\n            match = pop_match(lambda arr: arr.size == size)\n            if match:\n                idx, array = match\n                aligned[head], indices[head] = array, idx\n\n        if len(aligned) != len(HEADS):\n            rem_shapes = [arr.shape for _, arr in remaining]\n            raise ValueError(\n                \"Could not align ONNX outputs with JAX outputs. \"\n                f\"Aligned: {list(aligned.keys())}; Unmatched shapes: {rem_shapes}\"\n            )\n\n        for head in HEADS:\n            aligned[head] = self._reshape_output(\n                aligned[head], jax_outputs[head].shape, head\n            )\n        return aligned, indices\n\n    def _reshape_output(\n        self, array: np.ndarray, target_shape: Tuple[int, ...], head: str\n    ) -> np.ndarray:\n        if array.shape == target_shape:\n            return array\n        if array.size != int(np.prod(target_shape)):\n            raise ValueError(\n                f\"Cannot reshape ONNX output for {head}: source shape \"\n                f\"{array.shape}, target shape {target_shape}.\"\n            )\n        try:\n            return np.reshape(array, target_shape)\n        except ValueError as exc:\n            raise ValueError(\n                f\"Failed to reshape ONNX output for {head} to {target_shape}.\"\n            ) from exc\n\n\nclass Evaluation:\n    \"\"\"Orchestrates the model evaluation process.\"\"\"\n\n    def __init__(self, loss_fn: LczeroLoss):\n        self.loss_fn = loss_fn\n\n    def run(\n        self,\n        model: LczeroModel,\n        datagen: Generator[tuple[np.ndarray, ...], None, None],\n        num_samples: int,\n        dumper: Dumper,\n        onnx_comparator: Optional[OnnxComparator],\n    
    softmax_jax_wdl: bool,\n    ) -> None:\n        loss_vfn = jax.vmap(self._loss_for_grad, in_axes=(None, 0), out_axes=0)\n        model_output_vfn = jax.vmap(\n            self._model_for_output, in_axes=(None, 0), out_axes=0\n        )\n\n        for i in range(num_samples):\n            logger.info(\"Processing sample %d/%d\", i, num_samples)\n            self._process_sample(\n                model,\n                datagen,\n                i,\n                dumper,\n                onnx_comparator,\n                loss_vfn,\n                model_output_vfn,\n                softmax_jax_wdl,\n            )\n            logger.info(\"Sample %d complete\", i)\n\n    def _process_sample(\n        self,\n        model: LczeroModel,\n        datagen: Generator[tuple[np.ndarray, ...], None, None],\n        sample_idx: int,\n        dumper: Dumper,\n        onnx_comparator: Optional[OnnxComparator],\n        loss_vfn: Callable[\n            [LczeroModel, TrainingSample],\n            Tuple[jax.Array, Dict[str, jax.Array]],\n        ],\n        model_output_vfn: Callable[\n            [LczeroModel, jax.Array],\n            ModelPrediction,\n        ],\n        softmax_jax_wdl: bool,\n    ) -> None:\n        batch_tuple = next(datagen)\n        logger.info(\"Fetched batch from dataloader\")\n\n        # DataLoader now returns tuple: (inputs, probabilities, values)\n        batch = {\n            \"inputs\": cast(jax.Array, jnp.asarray(batch_tuple[0])),\n            \"probabilities\": cast(jax.Array, jnp.asarray(batch_tuple[1])),\n            \"values\": cast(jax.Array, jnp.asarray(batch_tuple[2])),\n        }\n        dumper.dump_tensors(batch, \"INPUT\")\n\n        predictions = model_output_vfn(model, cast(jax.Array, batch[\"inputs\"]))\n        value_preds = predictions.value\n        policy_preds = predictions.policy\n        movesleft_preds = predictions.movesleft\n\n        # Flatten all head outputs for dumping\n        outputs = {}\n        for name, 
pred_tuple in value_preds.items():\n            pred = pred_tuple[0]\n            if softmax_jax_wdl:\n                pred = jax.nn.softmax(pred, axis=-1)\n            outputs[f\"value_pred/{name}\"] = pred\n        for name, pred in policy_preds.items():\n            outputs[f\"policy_pred/{name}\"] = pred\n        for name, pred in movesleft_preds.items():\n            outputs[f\"movesleft_pred/{name}\"] = pred\n\n        if onnx_comparator:\n            # Compare only legacy heads\n            jax_outputs_for_onnx = {\n                \"wdl\": value_preds[\"winner\"][0],\n                \"policy\": policy_preds[\"vanilla\"],\n                \"movesleft\": movesleft_preds[\"main\"],\n            }\n            onnx_inputs_np = np.asarray(batch[\"inputs\"]).copy()\n            onnx_inputs_np[:, 109, ...] *= 99\n            onnx_comparator.compare(\n                jax_outputs_for_onnx, onnx_inputs_np, sample_idx\n            )\n            outputs.update(onnx_comparator.onnx_outputs)\n\n        dumper.dump_tensors(outputs, \"OUTPUT\")\n\n        # Convert batch dict to TrainingSample for loss function\n        batch_sample = TrainingSample(\n            inputs=batch[\"inputs\"],\n            probabilities=batch[\"probabilities\"],\n            values=batch[\"values\"],\n        )\n        per_sample_loss, unweighted_losses = loss_vfn(model, batch_sample)\n        losses = {\n            \"per_sample_data_loss\": per_sample_loss,\n            \"unweighted_losses\": unweighted_losses,\n        }\n        dumper.dump_tensors(losses, \"LOSSES\")\n        dumper.dump_structured(batch, outputs, losses)\n\n    def _loss_for_grad(\n        self, model_arg: LczeroModel, sample_arg: TrainingSample\n    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:\n        return self.loss_fn(model_arg, sample_arg)\n\n    @staticmethod\n    def _model_for_output(\n        model_arg: LczeroModel, inputs_arg: jax.Array\n    ) -> ModelPrediction:\n        return 
model_arg(inputs_arg)\n\n\ndef from_dataloader(\n    loader: DataLoader,\n) -> Generator[tuple[np.ndarray, ...], None, None]:\n    \"\"\"Yields batches from a DataLoader indefinitely.\"\"\"\n    while True:\n        yield loader.get_next()\n\n\ndef _load_model_from_checkpoint(config: RootConfig) -> LczeroModel:\n    \"\"\"Loads a model from the latest checkpoint.\"\"\"\n    if not config.training.checkpoint.path:\n        logger.error(\"Checkpoint path must be set in the configuration.\")\n        sys.exit(1)\n\n    mgr = ocp.CheckpointManager(\n        config.training.checkpoint.path,\n        options=ocp.CheckpointManagerOptions(create=True),\n    )\n    state = TrainingState.new_from_config(config.model, config.training)\n    restored_state = mgr.restore(\n        mgr.latest_step(), args=ocp.args.PyTreeRestore(state)\n    )\n    logger.info(\"Restored checkpoint from %s\", config.training.checkpoint.path)\n\n    assert isinstance(restored_state, TrainingState)\n    model_graph, _ = nnx.split(\n        LczeroModel(config.model, rngs=nnx.Rngs(params=42))\n    )\n    return nnx.merge(model_graph, restored_state.jit_state.model_state)\n\n\ndef _get_dataloader_config(\n    config: RootConfig, batch_size_override: Optional[int]\n) -> data_loader_config_pb2.DataLoaderConfig:\n    \"\"\"Gets the dataloader config, overriding batch size if specified.\"\"\"\n    dl_config = config.data_loader\n    if batch_size_override is None:\n        return dl_config\n\n    for stage in dl_config.stage:\n        if stage.HasField(\"tensor_generator\"):\n            stage.tensor_generator.batch_size = batch_size_override\n            logger.info(\"Overriding batch size to %d\", batch_size_override)\n            return dl_config\n\n    raise ValueError(\n        \"tensor_generator stage is required to override batch size\"\n    )\n\n\ndef eval(\n    config_filename: str,\n    num_samples: Optional[int] = None,\n    batch_size_override: Optional[int] = None,\n    dump_to_stdout: bool = 
False,\n    dump_to_file: Optional[str] = None,\n    dump_to_shelve: Optional[str] = None,\n    dump_to_json: Optional[str] = None,\n    onnx_model: Optional[str] = None,\n    softmax_jax_wdl: bool = True,\n) -> None:\n    \"\"\"Main function to run the evaluation.\"\"\"\n    jnp.set_printoptions(threshold=sys.maxsize, suppress=False)\n\n    config = RootConfig()\n    logger.info(\"Reading configuration from: %s\", config_filename)\n    with open(config_filename, \"r\") as f:\n        text_format.Parse(f.read(), config)\n\n    model = _load_model_from_checkpoint(config)\n    dl_config = _get_dataloader_config(config, batch_size_override)\n    evaluation = Evaluation(loss_fn=LczeroLoss(config=config.training.losses))\n    dumper = Dumper(dump_to_stdout, dump_to_file, dump_to_shelve, dump_to_json)\n    onnx_comparator = OnnxComparator(onnx_model) if onnx_model else None\n\n    samples_to_process = num_samples if num_samples is not None else 10\n    logger.info(\"Starting evaluation with %d samples\", samples_to_process)\n\n    try:\n        evaluation.run(\n            model=model,\n            datagen=from_dataloader(make_dataloader(dl_config)),\n            num_samples=samples_to_process,\n            dumper=dumper,\n            onnx_comparator=onnx_comparator,\n            softmax_jax_wdl=softmax_jax_wdl,\n        )\n    finally:\n        dumper.close()\n        if onnx_comparator:\n            onnx_comparator.log_summary()\n\n    logger.info(\"Evaluation complete\")\n"
  },
  {
    "path": "src/lczero_training/training/init.py",
    "content": "import gzip\nimport logging\nimport os\nimport sys\nfrom typing import Optional\n\nimport orbax.checkpoint as ocp\nfrom flax import nnx\nfrom google.protobuf import text_format\n\nfrom lczero_training.convert.leela_to_jax import (\n    LeelaImportOptions,\n    fix_older_weights_file,\n    leela_to_jax,\n)\nfrom lczero_training.convert.leela_to_modelconfig import leela_to_modelconfig\nfrom lczero_training.training.state import TrainingState\nfrom proto import hlo_pb2, net_pb2\nfrom proto.model_config_pb2 import ModelConfig\nfrom proto.root_config_pb2 import RootConfig\n\nlogger = logging.getLogger(__name__)\n\n\ndef _load_lc0_model_state(\n    path: str,\n    expected_config: ModelConfig,\n    compute_dtype: hlo_pb2.XlaShapeProto.Type,\n    ignore_config_mismatch: bool = False,\n) -> tuple[nnx.State, int]:\n    \"\"\"Load lc0 weights, validate config, return (model_state, training_steps).\"\"\"\n    lc0_weights = net_pb2.Net()\n    with gzip.open(path, \"rb\") as f:\n        lc0_weights.ParseFromString(f.read())\n    fix_older_weights_file(lc0_weights)\n\n    leela_config = leela_to_modelconfig(\n        lc0_weights, hlo_pb2.XlaShapeProto.F32, compute_dtype\n    )\n    if leela_config != expected_config:\n        if ignore_config_mismatch:\n            logger.warning(\n                \"The provided lczero model configuration \"\n                \"differs from the one in the config file (ignored).\"\n            )\n        else:\n            logger.error(\n                \"The provided lczero model configuration \"\n                \"differs from the one in the config file.\"\n            )\n            logger.error(f\"Config file model config: {expected_config}\")\n            logger.error(f\"Leela model config: {leela_config}\")\n            sys.exit(1)\n\n    import_options = LeelaImportOptions(\n        weights_dtype=hlo_pb2.XlaShapeProto.F32, compute_dtype=compute_dtype\n    )\n    model_state = leela_to_jax(lc0_weights, import_options)\n    
return model_state, lc0_weights.training_params.training_steps\n\n\ndef init(\n    config_filename: str,\n    lczero_model: Optional[str],\n    seed: int = 42,\n    dry_run: bool = False,\n    swa_initial_nets: int = 0,\n    override_training_steps: Optional[int] = None,\n    overwrite: bool = False,\n    no_copy_swa: bool = False,\n    ignore_config_mismatch: bool = False,\n) -> None:\n    \"\"\"\n    Initializes a new training run.\n    \"\"\"\n\n    config = RootConfig()\n    logger.info(\"Reading configuration from proto file\")\n    with open(config_filename, \"r\") as f:\n        text_format.Parse(f.read(), config)\n\n    checkpoint_path = config.training.checkpoint.path\n    checkpoint_exists = checkpoint_path and os.path.exists(checkpoint_path)\n\n    if not dry_run and checkpoint_exists and not overwrite:\n        logger.error(f\"Checkpoint path {checkpoint_path} already exists.\")\n        sys.exit(1)\n\n    logger.info(\"Creating initial training state from configuration\")\n    training_state = TrainingState.new_from_config(\n        model_config=config.model,\n        training_config=config.training,\n    )\n\n    if checkpoint_exists:\n        logger.info(f\"Loading from existing checkpoint: {checkpoint_path}\")\n        source_mgr = ocp.CheckpointManager(\n            checkpoint_path,\n            options=ocp.CheckpointManagerOptions(create=False),\n        )\n        training_state = source_mgr.restore(\n            source_mgr.latest_step(),\n            args=ocp.args.PyTreeRestore(training_state),\n        )\n\n    swa_enabled = config.training.HasField(\"swa\")\n\n    if lczero_model is None:\n        if override_training_steps is not None:\n            training_state = training_state.with_updated_step(\n                override_training_steps\n            )\n        if swa_enabled and swa_initial_nets > 0:\n            training_state = training_state.replace(\n                jit_state=training_state.jit_state.replace(\n                    
num_averages=float(swa_initial_nets),\n                )\n            )\n    else:\n        logger.info(f\"Loading lczero model: {lczero_model}\")\n        model_state, lc0_steps = _load_lc0_model_state(\n            lczero_model,\n            config.model,\n            config.model.defaults.compute_dtype,\n            ignore_config_mismatch,\n        )\n        step = (\n            override_training_steps\n            if override_training_steps is not None\n            else lc0_steps\n        )\n        new_swa_state = (\n            training_state.jit_state.swa_state\n            if no_copy_swa\n            else (model_state if swa_enabled else None)\n        )\n        training_state = training_state.replace(\n            jit_state=training_state.jit_state.replace(\n                model_state=model_state,\n                swa_state=new_swa_state,\n                num_averages=float(swa_initial_nets) if swa_enabled else 0.0,\n            )\n        ).with_updated_step(step)\n\n    if dry_run:\n        logger.info(\n            f\"Would save checkpoint to {config.training.checkpoint.path} \"\n            f\"at step {training_state.jit_state.step}\"\n        )\n    else:\n        checkpoint_mgr = ocp.CheckpointManager(\n            config.training.checkpoint.path,\n            options=ocp.CheckpointManagerOptions(create=True),\n        )\n\n        step = training_state.jit_state.step\n        if step in checkpoint_mgr.all_steps():\n            logger.info(f\"Deleting existing checkpoint at step {step}\")\n            checkpoint_mgr.delete(step)\n            checkpoint_mgr.wait_until_finished()\n\n        logger.info(\n            f\"Saving checkpoint to {config.training.checkpoint.path} at step {step}\"\n        )\n        checkpoint_mgr.save(step=step, args=ocp.args.PyTreeSave(training_state))\n        checkpoint_mgr.wait_until_finished()\n    logger.info(\"Initialization complete.\")\n"
  },
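The `num_averages` field that `init` seeds from `--swa_initial_nets` sets the effective count of networks already folded into the SWA average, so an imported net can be made "heavier" against early updates. Assuming the standard SWA running mean (the helper name `swa_update` is illustrative; the real update runs inside the JIT-compiled training step), a stdlib sketch of that role:

```python
def swa_update(
    avg: float, new_weight: float, num_averages: float
) -> tuple[float, float]:
    """One SWA running-mean step: fold `new_weight` into `avg`.

    `num_averages` plays the same role as `jit_state.num_averages` above:
    with n nets already averaged, the new weights contribute 1 / (n + 1),
    so seeding a larger n makes each update move the average less.
    """
    avg = avg + (new_weight - avg) / (num_averages + 1.0)
    return avg, num_averages + 1.0


# Fresh average: folding in 1.0 then 3.0 yields their mean, 2.0.
avg, n = swa_update(0.0, 1.0, 0.0)
avg, n = swa_update(avg, 3.0, n)
```

With a seeded count of 5, a single update moves the average only one sixth of the way toward the new weights, which is the intended effect of importing an lc0 net as if several averages had already been taken.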
  {
    "path": "src/lczero_training/training/lr_schedule.py",
    "content": "from typing import Callable, Sequence\n\nimport jax.numpy as jnp\nimport optax\n\nfrom proto.training_config_pb2 import LrSchedule\n\n\ndef _create_rule_fn(rule: LrSchedule) -> Callable:\n    \"\"\"\n    Creates a JAX-compatible function for a single LR schedule rule.\n\n    All data from the protobuf is extracted here and captured by the closure of\n    the returned function. This avoids protobuf parsing inside the main schedule\n    function which will be JIT-compiled.\n    \"\"\"\n    start_step = float(rule.starting_step)\n    durations = list(rule.duration_steps)\n    lrs = list(rule.lr)\n    is_looping = rule.loop\n\n    # Handle simple cases where the LR is constant for this rule.\n    if not durations or not lrs:\n        lr_val = jnp.asarray(lrs[-1] if lrs else 0.0, dtype=jnp.float32)\n        # Return a simple lambda that ignores the step and returns the constant value.\n        return lambda step: lr_val\n\n    period = sum(durations)\n    if period == 0.0:\n        lr_val = jnp.asarray(lrs[-1], dtype=jnp.float32)\n        return lambda step: lr_val\n\n    # Pre-calculate JAX arrays for use in the schedule function.\n    transitions = [\n        (\n            rule.transition[i]\n            if i < len(rule.transition)\n            else LrSchedule.Transition.CONSTANT\n        )\n        for i in range(len(durations))\n    ]\n\n    durs_j = jnp.asarray(durations, dtype=jnp.float32)\n    ends_j = jnp.cumsum(durs_j)\n    starts_j = ends_j - durs_j\n\n    # Interpolation start/end LRs for each segment.\n    a_vals = [\n        lrs[i] if i < len(lrs) else lrs[-1] for i in range(len(durations))\n    ]\n    b_vals = [\n        lrs[i + 1] if (i + 1) < len(lrs) else a_vals[i]\n        for i in range(len(durations))\n    ]\n    lrs_a_j = jnp.asarray(a_vals, dtype=jnp.float32)\n    lrs_b_j = jnp.asarray(b_vals, dtype=jnp.float32)\n\n    trans_j = jnp.asarray(transitions, dtype=jnp.int32)\n    last_lr_j = jnp.asarray(lrs[-1], dtype=jnp.float32)\n    
period_j = jnp.asarray(period, dtype=jnp.float32)\n\n    def rule_fn(step: jnp.ndarray) -> jnp.ndarray:\n        \"\"\"JAX-compatible function evaluating the LR for a given step.\"\"\"\n        rel_step = step - start_step\n\n        if is_looping:\n            rel_step = jnp.mod(rel_step, period_j)\n\n        # Find active segment and calculate interpolation factor `t`.\n        is_in_segment = (\n            (rel_step >= starts_j) & (rel_step < ends_j) & (durs_j > 0)\n        )\n        # Use maximum() to avoid division by zero for zero-duration segments.\n        t = jnp.clip((rel_step - starts_j) / jnp.maximum(durs_j, 1.0), 0.0, 1.0)\n\n        # Calculate interpolated values for all segments for all transition types.\n        lin = lrs_a_j + (lrs_b_j - lrs_a_j) * t\n        cos = lrs_a_j + 0.5 * (1.0 - jnp.cos(jnp.pi * t)) * (lrs_b_j - lrs_a_j)\n\n        # Select interpolation type for each segment. Default to CONSTANT (lrs_a_j).\n        interp_vals = jnp.where(\n            trans_j == LrSchedule.Transition.LINEAR, lin, lrs_a_j\n        )\n        interp_vals = jnp.where(\n            trans_j == LrSchedule.Transition.COSINE, cos, interp_vals\n        )\n\n        # Select the value from the active segment by masking.\n        lr = jnp.sum(interp_vals * is_in_segment)\n\n        # If not in any segment (e.g., gap between segments), use the last LR value.\n        lr = jnp.where(jnp.any(is_in_segment), lr, last_lr_j)\n\n        if not is_looping:\n            # For non-looping rules, if past the end, clamp to the last LR.\n            lr = jnp.where(rel_step >= period_j, last_lr_j, lr)\n\n        return lr\n\n    return rule_fn\n\n\ndef make_lr_schedule(schedules: Sequence[LrSchedule]) -> optax.Schedule:\n    \"\"\"\n    Creates a learning rate schedule from a sequence of LrSchedule protobufs.\n\n    The schedule is composed of multiple rules, each active for a certain\n    range of training steps.\n    \"\"\"\n    if not schedules:\n        return lambda 
count: jnp.asarray(0.0, dtype=jnp.float32)\n\n    rule_fns = [_create_rule_fn(rule) for rule in schedules]\n    start_steps = jnp.asarray(\n        [rule.starting_step for rule in schedules], dtype=jnp.float32\n    )\n\n    # Determine the learning rate to use for steps before the first rule begins.\n    first_lrs = jnp.asarray(\n        [rule.lr[0] if rule.lr else 0.0 for rule in schedules],\n        dtype=jnp.float32,\n    )\n    earliest_rule_idx = jnp.argmin(start_steps)\n    pre_start_lr = first_lrs[earliest_rule_idx]\n    min_start_step = start_steps[earliest_rule_idx]\n\n    def schedule(count: jnp.ndarray) -> jnp.ndarray:\n        \"\"\"The actual schedule function passed to Optax.\"\"\"\n        step = jnp.asarray(count, dtype=jnp.float32)\n\n        # Find the index of the active rule. The active rule is the one with the\n        # largest starting_step that is less than or equal to the current step.\n        eligible_mask = step >= start_steps\n        # Replace non-eligible start_steps with a large negative number so they\n        # are ignored by argmax.\n        effective_starts = jnp.where(eligible_mask, start_steps, -1.0)\n        active_rule_idx = jnp.argmax(effective_starts)\n\n        all_lrs = jnp.stack([fn(step) for fn in rule_fns])\n        lr = all_lrs[active_rule_idx]\n\n        # If the current step is before any rule starts, use the pre-start LR.\n        return jnp.where(step < min_start_step, pre_start_lr, lr)\n\n    return schedule\n"
  },
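The masked segment selection in `rule_fn` (boolean masks, `jnp.where`, and a sum over segments) is the JIT-friendly form of an ordinary piecewise lookup. A plain-Python analogue (no JAX; `piecewise_lr` is an illustrative name, and unlike the proto-driven code it expects `len(lrs) == len(durations) + 1` rather than padding missing entries):

```python
import math


def piecewise_lr(step, durations, lrs, cosine=False):
    """Plain-Python analogue of rule_fn: segment i lasts durations[i] steps
    and interpolates from lrs[i] to lrs[i + 1]; past the final segment the
    last LR holds, matching the non-looping clamp."""
    start = 0.0
    for i, dur in enumerate(durations):
        if step < start + dur:
            t = (step - start) / dur  # Interpolation factor in [0, 1).
            a, b = lrs[i], lrs[i + 1]
            if cosine:
                # Half-cosine ease from a to b, as in the COSINE transition.
                return a + 0.5 * (1.0 - math.cos(math.pi * t)) * (b - a)
            return a + (b - a) * t  # LINEAR transition.
        start += dur
    return lrs[-1]  # Past the end: clamp to the last LR.
```

For example, `durations=[100, 100]` with `lrs=[0.1, 0.01, 0.001]` decays linearly from 0.1 to 0.01 over the first hundred steps, then to 0.001, then stays there. The JAX version computes all segments unconditionally and masks because data-dependent branching is not traceable under `jit`.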
  {
    "path": "src/lczero_training/training/migrate_checkpoint.py",
    "content": "from typing import Any, Dict, Iterable, List, Set, Tuple\n\nimport jax\nimport numpy as np\nimport orbax.checkpoint as ocp\nfrom flax import serialization\nfrom google.protobuf import text_format\nfrom orbax.checkpoint.utils import tuple_path_from_keypath\n\nfrom lczero_training.training import state as state_lib\nfrom proto import checkpoint_migration_config_pb2, root_config_pb2\n\n\ndef _str_to_key_path(path_str: str) -> tuple[str, ...]:\n    return tuple(path_str.split(\".\"))\n\n\ndef _load_new_state(\n    root_config: root_config_pb2.RootConfig, serialized_model: bool\n) -> Any:\n    new_state = state_lib.TrainingState.new_from_config(\n        root_config.model, root_config.training\n    )\n    if serialized_model:\n        return serialization.to_state_dict(new_state)\n    return new_state\n\n\ndef load_checkpoint(\n    checkpoint_path: str, checkpoint_step: int | None = None\n) -> Tuple[Any, int]:\n    \"\"\"Load a checkpoint from the given path.\n\n    Args:\n        checkpoint_path: Path to the checkpoint directory.\n        checkpoint_step: Step to load, or None to load the latest.\n\n    Returns:\n        Tuple of (checkpoint_state, checkpoint_step).\n    \"\"\"\n    manager = ocp.CheckpointManager(\n        checkpoint_path,\n        options=ocp.CheckpointManagerOptions(create=False),\n    )\n    if checkpoint_step is None:\n        checkpoint_step = manager.latest_step()\n    if checkpoint_step is None:\n        raise ValueError(f\"No checkpoints found in {checkpoint_path}\")\n    return manager.restore(checkpoint_step), checkpoint_step\n\n\ndef get_checkpoint_steps(\n    checkpoint_path: str,\n    min_step: int | None = None,\n    max_step: int | None = None,\n) -> list[int]:\n    \"\"\"Get all checkpoint steps in the given range.\n\n    Args:\n        checkpoint_path: Path to the checkpoint directory.\n        min_step: Minimum step (inclusive), or None for no minimum.\n        max_step: Maximum step (inclusive), or None for no 
maximum.\n\n    Returns:\n        List of checkpoint steps in ascending order.\n    \"\"\"\n    manager = ocp.CheckpointManager(\n        checkpoint_path,\n        options=ocp.CheckpointManagerOptions(create=False),\n    )\n    all_steps = sorted(manager.all_steps())\n    filtered_steps = []\n    for step in all_steps:\n        if min_step is not None and step < min_step:\n            continue\n        if max_step is not None and step > max_step:\n            continue\n        filtered_steps.append(step)\n    filtered_steps.sort()\n    return filtered_steps\n\n\ndef _load_old_state(\n    checkpoint_path: str, checkpoint_step: int | None\n) -> Tuple[Any, int]:\n    return load_checkpoint(checkpoint_path, checkpoint_step)\n\n\ndef load_migration_rules(rules_file: str | None) -> List[Tuple[Any, Any]]:\n    \"\"\"Load migration rules from a CheckpointMigrationConfig file.\n\n    Args:\n        rules_file: Path to the CheckpointMigrationConfig textproto file,\n            or None to return empty rules.\n\n    Returns:\n        List of (from_path, to_path) tuples representing migration rules.\n    \"\"\"\n    rules = []\n    if rules_file:\n        migration_config = (\n            checkpoint_migration_config_pb2.CheckpointMigrationConfig()\n        )\n        with open(rules_file, \"r\") as f:\n            text_format.Parse(f.read(), migration_config)\n        for rule_proto in migration_config.rule:\n            from_path = (\n                _str_to_key_path(rule_proto.from_path)\n                if rule_proto.from_path\n                else None\n            )\n            to_path = (\n                _str_to_key_path(rule_proto.to_path)\n                if rule_proto.to_path\n                else None\n            )\n            rules.append((from_path, to_path))\n    return rules\n\n\ndef _format_value(value: Any) -> str:\n    if isinstance(value, (np.ndarray, jax.Array)):\n        return f\"{value.dtype}{value.shape}\"\n    return repr(value)\n\n\ndef 
_format_path_diff(\n    unhandled_source: Set[Tuple[str, ...]],\n    unhandled_dest: Set[Tuple[str, ...]],\n    old_paths: Dict[Tuple[str, ...], Any],\n    new_paths: Dict[Tuple[str, ...], Any],\n) -> str:\n    diff = []\n    for p in sorted(list(unhandled_source | unhandled_dest)):\n        p_str = \".\".join(p)\n        if p in unhandled_source:\n            diff.append(f\"- {p_str}: {_format_value(old_paths[p])}\")\n        if p in unhandled_dest:\n            diff.append(f\"+ {p_str}: {_format_value(new_paths[p])}\")\n    return \"\\n\".join(diff)\n\n\nclass Migration:\n    def __init__(self, old_state: Any, new_state: Any):\n        old_leaves = jax.tree_util.tree_leaves_with_path(old_state)\n        new_flat, self.new_treedef = jax.tree_util.tree_flatten_with_path(\n            new_state\n        )\n\n        self.old_paths: Dict[Tuple[str, ...], Any] = {\n            tuple_path_from_keypath(path): value for path, value in old_leaves\n        }\n        self.new_paths: Dict[Tuple[str, ...], Any] = {\n            tuple_path_from_keypath(path): value for path, value in new_flat\n        }\n        self.new_leaves: List[Any] = [value for _, value in new_flat]\n        self.new_path_to_idx: Dict[Tuple[str, ...], int] = {\n            tuple_path_from_keypath(path): i\n            for i, (path, _) in enumerate(new_flat)\n        }\n\n        self.source_paths: Set[Tuple[str, ...]] = set(self.old_paths.keys())\n        self.dest_paths: Set[Tuple[str, ...]] = set(self.new_path_to_idx.keys())\n\n        print(f\"{len(self.source_paths & self.dest_paths)} common keys\")\n        print(f\"{len(self.source_paths - self.dest_paths)} keys disappeared\")\n        print(f\"{len(self.dest_paths - self.source_paths)} keys appeared\")\n\n        self.errors: List[str] = []\n\n    def _apply_move_rule(\n        self, from_path: Tuple[str, ...], to_path: Tuple[str, ...]\n    ) -> None:\n        if from_path == to_path:\n            self.errors.append(\n                
f\"from_path and to_path are the same: {from_path}\"\n            )\n            return\n\n        source_prefixed = {\n            p for p in self.source_paths if p[: len(from_path)] == from_path\n        }\n        dest_prefixed = {\n            p for p in self.dest_paths if p[: len(to_path)] == to_path\n        }\n\n        if not source_prefixed:\n            self.errors.append(f\"from_path {from_path} not found in old state\")\n        if not dest_prefixed:\n            self.errors.append(f\"to_path {to_path} not found in new state\")\n\n        for p in source_prefixed:\n            new_p = to_path + p[len(from_path) :]\n            if new_p in self.dest_paths:\n                idx = self.new_path_to_idx[new_p]\n                self.new_leaves[idx] = self.old_paths[p]\n                self.source_paths.remove(p)\n                self.dest_paths.remove(new_p)\n            else:\n                self.errors.append(f\"Path {new_p} not found in new state\")\n\n    def _apply_ignore_rule(self, from_path: Tuple[str, ...]) -> None:\n        source_prefixed = {\n            p for p in self.source_paths if p[: len(from_path)] == from_path\n        }\n        if not source_prefixed:\n            self.errors.append(f\"from_path {from_path} not found in old state\")\n        self.source_paths -= source_prefixed\n\n    def _apply_keep_rule(self, to_path: Tuple[str, ...]) -> None:\n        dest_prefixed = {\n            p for p in self.dest_paths if p[: len(to_path)] == to_path\n        }\n        if not dest_prefixed:\n            self.errors.append(f\"to_path {to_path} not found in new state\")\n        self.dest_paths -= dest_prefixed\n\n    def apply_rules(self, rules: List[Tuple[Any, Any]]) -> None:\n        for from_path, to_path in rules:\n            if from_path and to_path:\n                self._apply_move_rule(from_path, to_path)\n            elif from_path:\n                self._apply_ignore_rule(from_path)\n            elif to_path:\n                
self._apply_keep_rule(to_path)\n\n    def run(self, rules: List[Tuple[Any, Any]]) -> Any:\n        self.apply_rules(rules)\n\n        # Copy remaining paths\n        copied_paths = self.source_paths & self.dest_paths\n        for p in copied_paths:\n            idx = self.new_path_to_idx[p]\n            self.new_leaves[idx] = self.old_paths[p]\n        self.source_paths -= copied_paths\n        self.dest_paths -= copied_paths\n\n        unhandled_source = self.source_paths\n        unhandled_dest = self.dest_paths\n\n        if unhandled_source or unhandled_dest:\n            self.errors.append(\n                \"Unmapped paths:\\n\"\n                + _format_path_diff(\n                    unhandled_source,\n                    unhandled_dest,\n                    self.old_paths,\n                    self.new_paths,\n                )\n            )\n\n        if self.errors:\n            raise ValueError(\"\\n\".join(self.errors))\n\n        return getattr(self.new_treedef, \"unflatten\")(self.new_leaves)\n\n\ndef _save_checkpoint(\n    migrated_state: Any,\n    new_checkpoint_path: str,\n    new_checkpoint_step: int,\n    overwrite: bool,\n) -> None:\n    manager = ocp.CheckpointManager(\n        new_checkpoint_path,\n        ocp.PyTreeCheckpointer(),\n        options=ocp.CheckpointManagerOptions(\n            create=True, save_interval_steps=1, todelete_subdir=\"trash\"\n        ),\n    )\n\n    if new_checkpoint_step in manager.all_steps():\n        if overwrite:\n            manager.delete(new_checkpoint_step)\n            manager.wait_until_finished()\n        else:\n            raise ValueError(\n                f\"Checkpoint already exists at {new_checkpoint_step} in \"\n                f\"{new_checkpoint_path}. 
\"\n                \"Use --overwrite to overwrite.\"\n            )\n\n    manager.save(new_checkpoint_step, migrated_state)\n    manager.wait_until_finished()\n    print(\n        f\"New checkpoint saved successfully to {new_checkpoint_path} at step \"\n        f\"{new_checkpoint_step}.\"\n    )\n\n\ndef _dump_paths(paths: Iterable[Tuple[str, ...]], field: str) -> None:\n    def key(p: Tuple[str, ...]) -> tuple:\n        return tuple(int(c) if c.isdigit() else c for c in p)\n\n    for path in sorted(paths, key=key):\n        print(f'rule {{ {field}: \"{\".\".join(path)}\" }}')\n\n\ndef migrate_checkpoint(\n    config: str,\n    new_checkpoint: str | None,\n    overwrite: bool,\n    rules_file: str | None,\n    serialized_model: bool,\n    checkpoint_step: int | None,\n    new_checkpoint_step: int | None,\n    dump_source_paths: bool = False,\n    dump_destination_paths: bool = False,\n) -> None:\n    \"\"\"Migrates a checkpoint to a new training state.\"\"\"\n    root_config = root_config_pb2.RootConfig()\n    with open(config, \"r\") as f:\n        text_format.Parse(f.read(), root_config)\n\n    new_state = _load_new_state(root_config, serialized_model)\n    old_state, old_checkpoint_step = _load_old_state(\n        root_config.training.checkpoint.path, checkpoint_step\n    )\n    rules = load_migration_rules(rules_file)\n\n    migration = Migration(old_state, new_state)\n\n    if dump_source_paths:\n        _dump_paths(migration.old_paths.keys(), \"from_path\")\n    if dump_destination_paths:\n        _dump_paths(migration.new_paths.keys(), \"to_path\")\n    if dump_source_paths or dump_destination_paths:\n        return\n\n    migrated_state = migration.run(rules)\n\n    if new_checkpoint:\n        checkpoint_path = new_checkpoint\n    elif overwrite:\n        checkpoint_path = root_config.training.checkpoint.path\n    else:\n        print(\"Migration check successful.\")\n        return\n\n    if new_checkpoint_step is None:\n        new_checkpoint_step = 
old_checkpoint_step\n\n    _save_checkpoint(\n        migrated_state, checkpoint_path, new_checkpoint_step, overwrite\n    )\n"
  },
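At its core, `Migration._apply_move_rule` is a prefix rewrite over dictionaries mapping flattened pytree paths (tuples of keys) to leaves. A minimal stdlib sketch of that operation (plain dicts stand in for the flattened checkpoint states; `apply_move` is an illustrative name):

```python
def apply_move(old, new, from_path, to_path):
    """Copy every leaf in `old` whose path starts with `from_path` into
    `new` under the same suffix re-rooted at `to_path`, mirroring
    Migration._apply_move_rule. Returns the re-rooted paths that had no
    slot in the new state (reported as errors in the real code)."""
    missing = []
    for path, leaf in old.items():
        if path[: len(from_path)] != from_path:
            continue  # Not under the moved prefix.
        new_path = to_path + path[len(from_path):]
        if new_path in new:
            new[new_path] = leaf
        else:
            missing.append(new_path)
    return missing


old = {("opt", "mu", "w"): 1, ("opt", "mu", "b"): 2, ("step",): 7}
new = {("optimizer", "mu", "w"): 0, ("optimizer", "mu", "b"): 0, ("step",): 0}
unmatched = apply_move(old, new, ("opt",), ("optimizer",))
```

The real implementation additionally removes handled paths from `source_paths`/`dest_paths` so that anything left over after all rules and the identity copy is reported as an unmapped diff.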
  {
    "path": "src/lczero_training/training/optimizer.py",
    "content": "from functools import partial\n\nimport jax\nimport jax.numpy as jnp\nimport optax\nfrom flax import nnx\n\nfrom lczero_training.training.utils import make_weights_mask\nfrom proto.training_config_pb2 import OptimizerConfig\n\n_STATES_WITH_COUNT = (\n    optax.ScaleByAdamState,\n    optax.ScaleByScheduleState,\n)\n\n\ndef update_optimizer_step(\n    opt_state: optax.OptState, step: int\n) -> optax.OptState:\n    \"\"\"Updates all step counters in the optimizer state tree.\"\"\"\n    step_array = jnp.array(step, dtype=jnp.int32)\n\n    def has_count(x: object) -> bool:\n        return hasattr(x, \"_fields\") and \"count\" in x._fields\n\n    def update_count(x: optax.OptState) -> optax.OptState:\n        if not has_count(x):\n            return x\n        assert isinstance(x, _STATES_WITH_COUNT), (\n            f\"Unexpected state type with 'count' field: {type(x).__name__}\"\n        )\n        return x._replace(count=step_array)\n\n    return jax.tree_util.tree_map(update_count, opt_state, is_leaf=has_count)\n\n\ndef make_gradient_transformation(\n    config: OptimizerConfig,\n    *,\n    max_grad_norm: float | None = None,\n    lr_schedule: optax.Schedule,\n) -> optax.GradientTransformation:\n    if config.HasField(\"nadamw\"):\n        nadamw = config.nadamw\n        tx = optax.nadamw(\n            lr_schedule,\n            b1=nadamw.beta_1,\n            b2=nadamw.beta_2,\n            eps=nadamw.epsilon,\n            weight_decay=nadamw.weight_decay,\n            mask=partial(make_weights_mask, nadamw.decay_selector),\n        )\n    elif config.HasField(\"nadam\"):\n        nadam = config.nadam\n        tx = optax.nadam(\n            lr_schedule,\n            b1=nadam.beta_1,\n            b2=nadam.beta_2,\n            eps=nadam.epsilon,\n        )\n    elif config.HasField(\"sgd\"):\n        sgd = config.sgd\n        tx = optax.sgd(\n            lr_schedule,\n            momentum=sgd.momentum if sgd.momentum else None,\n            
nesterov=sgd.nesterov,\n        )\n    else:\n        raise ValueError(\n            \"Unsupported optimizer type: {}\".format(\n                config.WhichOneof(\"optimizer_type\")\n            )\n        )\n    if max_grad_norm is not None and max_grad_norm > 0:\n        tx = optax.chain(optax.clip_by_global_norm(max_grad_norm), tx)\n    if config.HasField(\"freeze_selector\"):\n        freeze_mask = partial(make_weights_mask, config.freeze_selector)\n\n        def trainable_mask(p: nnx.State) -> nnx.State:\n            return jax.tree.map(lambda x: not x, freeze_mask(p))\n\n        tx = optax.chain(\n            optax.masked(tx, trainable_mask),\n            optax.masked(optax.set_to_zero(), freeze_mask),\n        )\n    return tx\n"
  },
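`update_optimizer_step` works by treating any namedtuple exposing a `count` field as a leaf and rebuilding it via `_replace`. A JAX-free sketch of the same traversal over nested containers (the names `ScaleState` and `set_counts` are illustrative; `jax.tree_util.tree_map` with `is_leaf=has_count` does this walk in the real code):

```python
from typing import Any, NamedTuple


class ScaleState(NamedTuple):
    """Stand-in for optax states such as ScaleByAdamState."""
    count: int
    mu: Any


def set_counts(state: Any, step: int) -> Any:
    """Recursively rebuild `state`, replacing `count` on any namedtuple
    that has one; all other containers are rebuilt unchanged."""
    if hasattr(state, "_fields"):  # Namedtuple check, as in has_count().
        if "count" in state._fields:
            return state._replace(count=step)
        return type(state)(*(set_counts(v, step) for v in state))
    if isinstance(state, (list, tuple)):
        return type(state)(set_counts(v, step) for v in state)
    return state


opt_state = (ScaleState(count=0, mu=[1.0, 2.0]), "other")
updated = set_counts(opt_state, 100)
```

Because `_replace` returns a new namedtuple, the original optimizer state is left untouched, which matches the functional-update style the JAX code relies on.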
  {
    "path": "src/lczero_training/training/overfit.py",
    "content": "\"\"\"Overfitting utility for quickly validating training setup.\"\"\"\n\nimport csv\nimport logging\nfrom contextlib import suppress\nfrom functools import partial\nfrom typing import Any\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax import nnx\nfrom google.protobuf import text_format\nfrom jax import tree_util\n\nfrom lczero_training.dataloader import DataLoader, make_dataloader\nfrom lczero_training.model.loss_function import LczeroLoss\nfrom lczero_training.model.model import LczeroModel\nfrom lczero_training.training.lr_schedule import make_lr_schedule\nfrom lczero_training.training.optimizer import make_gradient_transformation\nfrom lczero_training.training.state import (\n    TrainingBatch,\n    TrainingSample,\n    TrainingState,\n)\nfrom lczero_training.training.training import Training\nfrom proto.root_config_pb2 import RootConfig\n\nlogger = logging.getLogger(__name__)\n\n\ndef _stop_loader(loader: DataLoader) -> None:\n    with suppress(Exception):\n        loader.stop()\n\n\ndef _prepare_batch(batch_tuple: tuple) -> TrainingBatch:\n    # DataLoader now returns tuple: (inputs, probabilities, values)\n    return TrainingBatch(\n        inputs=jnp.asarray(batch_tuple[0]),\n        probabilities=jnp.asarray(batch_tuple[1]),\n        values=jnp.asarray(batch_tuple[2]),\n    )\n\n\ndef _make_eval_step(graphdef: nnx.GraphDef, loss_fn: LczeroLoss) -> Any:\n    @partial(nnx.jit, static_argnames=())\n    def eval_step(\n        model_state: nnx.State, batch: TrainingBatch\n    ) -> tuple[jax.Array, Any]:\n        model = nnx.merge(graphdef, model_state)\n\n        def loss_for_batch(\n            model_arg: LczeroModel, sample_arg: TrainingSample\n        ) -> tuple[jax.Array, Any]:\n            return loss_fn(model_arg, sample_arg)\n\n        loss_vfn = jax.vmap(loss_for_batch, in_axes=(None, 0), out_axes=0)\n        # vmap automatically distributes TrainingBatch over batch dimension,\n        # calling loss_for_batch 
with TrainingSample (single samples).\n        per_sample_loss, unweighted_losses = loss_vfn(model, batch)  # type: ignore[arg-type]\n        mean_loss = jnp.mean(per_sample_loss)\n        mean_unweighted = tree_util.tree_map(jnp.mean, unweighted_losses)\n        return mean_loss, mean_unweighted\n\n    return eval_step\n\n\ndef overfit(\n    *,\n    config_filename: str,\n    num_steps: int,\n    coin_flip: bool = False,\n    csv_file: str | None = None,\n) -> None:\n    \"\"\"Runs an overfitting loop to validate training.\"\"\"\n\n    if num_steps <= 0:\n        raise ValueError(\"num_steps must be a positive integer\")\n\n    if jax.device_count() > 1:\n        raise ValueError(\n            f\"Overfit utility does not support multi-GPU training. \"\n            f\"Detected {jax.device_count()} devices. \"\n            f\"Please set CUDA_VISIBLE_DEVICES to use only one GPU.\"\n        )\n\n    config = RootConfig()\n    logger.info(\"Reading configuration from proto file\")\n    with open(config_filename, \"r\") as config_file:\n        text_format.Parse(config_file.read(), config)\n\n    logger.info(\"Creating data loader and fetching batches\")\n    loader = make_dataloader(config.data_loader)\n    try:\n        batch_a = loader.get_next()\n        batch_b = loader.get_next() if coin_flip else None\n    finally:\n        _stop_loader(loader)\n\n    prepared_batch_a = _prepare_batch(batch_a)\n    prepared_batch_b = _prepare_batch(batch_b) if batch_b is not None else None\n\n    logger.info(\"Creating training state from configuration\")\n    training_state = TrainingState.new_from_config(\n        model_config=config.model,\n        training_config=config.training,\n    )\n\n    graphdef, _ = nnx.split(\n        LczeroModel(config=config.model, rngs=nnx.Rngs(params=42))\n    )\n\n    jit_state = training_state.jit_state\n    lr_sched = make_lr_schedule(config.training.lr_schedule)\n    optimizer_tx = make_gradient_transformation(\n        
config.training.optimizer,\n        max_grad_norm=getattr(config.training, \"max_grad_norm\", 0.0),\n        lr_schedule=lr_sched,\n    )\n\n    loss_fn = LczeroLoss(config=config.training.losses)\n    training = Training(\n        optimizer_tx=optimizer_tx,\n        graphdef=graphdef,\n        loss_fn=loss_fn,\n    )\n    eval_step = _make_eval_step(graphdef, loss_fn)\n\n    csv_handle = None\n    csv_writer: Any | None = None\n    if csv_file is not None:\n        logger.info(\"Writing overfit results to %s\", csv_file)\n        csv_handle = open(csv_file, \"w\", newline=\"\")\n        csv_writer = csv.writer(csv_handle)\n        csv_writer.writerow(\n            [\n                \"step\",\n                \"train_batch\",\n                \"train_loss\",\n                \"train_unweighted\",\n                \"eval_batch\",\n                \"eval_loss\",\n                \"eval_unweighted\",\n            ]\n        )\n        csv_handle.flush()\n\n    def log_step(\n        *,\n        step_value: int,\n        train_batch_name: str,\n        train_loss: float,\n        train_unweighted: Any,\n        eval_batch_name: str | None,\n        eval_loss: float | None,\n        eval_unweighted: Any | None,\n    ) -> None:\n        if eval_batch_name is None or eval_loss is None:\n            logger.info(\n                \"Step %d: batch=%s train_loss=%f, unweighted_losses=%s\",\n                step_value,\n                train_batch_name,\n                train_loss,\n                train_unweighted,\n            )\n        else:\n            logger.info(\n                (\n                    \"Step %d: trained %s train_loss=%f, unweighted_losses=%s; \"\n                    \"evaluated %s eval_loss=%f, eval_unweighted=%s\"\n                ),\n                step_value,\n                train_batch_name,\n                train_loss,\n                train_unweighted,\n                eval_batch_name,\n                eval_loss,\n                
eval_unweighted,\n            )\n\n        if csv_writer is not None and csv_handle is not None:\n            csv_writer.writerow(\n                [\n                    step_value,\n                    train_batch_name,\n                    train_loss,\n                    repr(train_unweighted),\n                    eval_batch_name or \"\",\n                    \"\" if eval_loss is None else eval_loss,\n                    \"\" if eval_unweighted is None else repr(eval_unweighted),\n                ]\n            )\n            csv_handle.flush()\n\n    try:\n        if coin_flip:\n            if prepared_batch_b is None:\n                raise RuntimeError(\n                    \"Coin flip mode requires two batches but only one was fetched\"\n                )\n\n            logger.info(\n                \"Starting coin-flip overfit: %d steps on batch A then %d on batch B\",\n                num_steps,\n                num_steps,\n            )\n\n            def run_phase(\n                train_batch: TrainingBatch,\n                train_name: str,\n                eval_batch: TrainingBatch,\n                eval_name: str,\n            ) -> None:\n                nonlocal jit_state\n                for _ in range(num_steps):\n                    jit_state, metrics = training.train_step(\n                        optimizer_tx,\n                        jit_state,\n                        train_batch,\n                    )\n                    loss = metrics[\"loss\"]\n                    unweighted_losses = metrics[\"unweighted_losses\"]\n                    loss_value, unweighted_host = jax.device_get(\n                        (loss, unweighted_losses)\n                    )\n                    loss_value = float(np.asarray(loss_value))\n                    unweighted_host = tree_util.tree_map(\n                        lambda x: float(np.asarray(x)), unweighted_host\n                    )\n\n                    eval_loss, eval_unweighted = eval_step(\n      
                  jit_state.model_state, eval_batch\n                    )\n                    eval_loss, eval_unweighted = jax.device_get(\n                        (eval_loss, eval_unweighted)\n                    )\n                    eval_loss_value = float(np.asarray(eval_loss))\n                    eval_unweighted_host = tree_util.tree_map(\n                        lambda x: float(np.asarray(x)), eval_unweighted\n                    )\n\n                    step_value = int(\n                        np.asarray(jax.device_get(jit_state.step)).flat[0]\n                    )\n                    log_step(\n                        step_value=step_value,\n                        train_batch_name=train_name,\n                        train_loss=loss_value,\n                        train_unweighted=unweighted_host,\n                        eval_batch_name=eval_name,\n                        eval_loss=eval_loss_value,\n                        eval_unweighted=eval_unweighted_host,\n                    )\n\n            run_phase(prepared_batch_a, \"A\", prepared_batch_b, \"B\")\n            run_phase(prepared_batch_b, \"B\", prepared_batch_a, \"A\")\n        else:\n            logger.info(\"Starting overfit loop for %d steps\", num_steps)\n            for _ in range(num_steps):\n                jit_state, metrics = training.train_step(\n                    optimizer_tx,\n                    jit_state,\n                    prepared_batch_a,\n                )\n                loss = metrics[\"loss\"]\n                unweighted_losses = metrics[\"unweighted_losses\"]\n                loss_value, unweighted_host = jax.device_get(\n                    (loss, unweighted_losses)\n                )\n                loss_value = float(np.asarray(loss_value))\n                unweighted_host = tree_util.tree_map(\n                    lambda x: float(np.asarray(x)), unweighted_host\n                )\n                step_value = int(\n                    
np.asarray(jax.device_get(jit_state.step)).flat[0]\n                )\n                log_step(\n                    step_value=step_value,\n                    train_batch_name=\"single\",\n                    train_loss=loss_value,\n                    train_unweighted=unweighted_host,\n                    eval_batch_name=None,\n                    eval_loss=None,\n                    eval_unweighted=None,\n                )\n    finally:\n        if csv_handle is not None:\n            csv_handle.close()\n"
  },
  {
    "path": "src/lczero_training/training/state.py",
    "content": "import dataclasses\nimport logging\nfrom typing import Any, Optional, Union\n\nimport jax\nimport jax.numpy as jnp\nimport jax.sharding as jshard\nimport numpy as np\nimport optax\nfrom flax import nnx\nfrom flax.struct import dataclass\n\nfrom lczero_training.model.model import LczeroModel\nfrom lczero_training.training.lr_schedule import make_lr_schedule\nfrom lczero_training.training.optimizer import (\n    make_gradient_transformation,\n    update_optimizer_step,\n)\nfrom proto.model_config_pb2 import ModelConfig\nfrom proto.training_config_pb2 import TrainingConfig\n\nlogger = logging.getLogger(__name__)\n\n\n@jax.tree_util.register_dataclass\n@dataclasses.dataclass\nclass TrainingSample:\n    \"\"\"Single training sample without batch dimension.\n\n    Used for vmap over individual samples in loss computation.\n\n    Fields:\n        inputs: Input planes tensor [112, 8, 8]\n        probabilities: Policy probabilities tensor [1858]\n        values: Combined values tensor [6, 3] where:\n            - Index 0: result [result_q, result_d, plies_left]\n            - Index 1: best [best_q, best_d, best_m]\n            - Index 2: played [played_q, played_d, played_m]\n            - Index 3: orig [orig_q, orig_d, orig_m] (may contain NaN)\n            - Index 4: root [root_q, root_d, root_m]\n            - Index 5: st [q_st, d_st, NaN]\n    \"\"\"\n\n    inputs: jax.Array\n    probabilities: jax.Array\n    values: jax.Array\n\n\n@jax.tree_util.register_dataclass\n@dataclasses.dataclass\nclass TrainingBatch:\n    \"\"\"Batch of training data with inputs, probabilities, and values tensors.\n\n    Fields:\n        inputs: Input planes tensor [batch, 112, 8, 8]\n        probabilities: Policy probabilities tensor [batch, 1858]\n        values: Combined values tensor [batch, 6, 3] where:\n            - Index 0: result [result_q, result_d, plies_left]\n            - Index 1: best [best_q, best_d, best_m]\n            - Index 2: played [played_q, played_d, 
played_m]\n            - Index 3: orig [orig_q, orig_d, orig_m] (may contain NaN)\n            - Index 4: root [root_q, root_d, root_m]\n            - Index 5: st [q_st, d_st, NaN]\n    \"\"\"\n\n    inputs: Union[jax.Array, jshard.NamedSharding]\n    probabilities: Union[jax.Array, jshard.NamedSharding]\n    values: Union[jax.Array, jshard.NamedSharding]\n\n    @classmethod\n    def from_tuple(\n        cls, tensor_tuple: tuple[np.ndarray, ...]\n    ) -> \"TrainingBatch\":\n        \"\"\"Create TrainingBatch from tuple returned by DataLoader.\"\"\"\n        if len(tensor_tuple) != 3:\n            raise ValueError(\n                f\"Expected tuple of 3 tensors, got {len(tensor_tuple)}\"\n            )\n        return cls(\n            inputs=jnp.asarray(tensor_tuple[0]),\n            probabilities=jnp.asarray(tensor_tuple[1]),\n            values=jnp.asarray(tensor_tuple[2]),\n        )\n\n\n@dataclass\nclass JitTrainingState:\n    step: int\n    model_state: nnx.State\n    opt_state: Optional[optax.OptState]\n    # SWA state mirrors model_state structure when enabled; None otherwise.\n    # Marked non-pytree to exclude from JIT/pjit inputs and device transfers.\n    swa_state: Optional[nnx.State]\n    # Effective number of model snapshots accumulated into SWA (can be fractional).\n    num_averages: float\n\n    def replace(self, **changes: Any) -> \"JitTrainingState\":\n        \"\"\"Returns a new instance of the class with the specified changes.\"\"\"\n        return dataclasses.replace(self, **changes)\n\n\n@dataclass\nclass TrainingState:\n    jit_state: JitTrainingState\n    num_heads: int\n    # Last chunk source that was available when the last epoch started training.\n    last_chunk_source: str = \"\"\n\n    def replace(self, **changes: Any) -> \"TrainingState\":\n        \"\"\"Returns a new instance of the class with the specified changes.\"\"\"\n        return dataclasses.replace(self, **changes)\n\n    def with_updated_step(self, step: int) -> 
\"TrainingState\":\n        \"\"\"Returns a copy with updated step in both jit_state and optimizer.\"\"\"\n        updated_opt_state = (\n            update_optimizer_step(self.jit_state.opt_state, step)\n            if self.jit_state.opt_state is not None\n            else None\n        )\n        return self.replace(\n            jit_state=self.jit_state.replace(\n                step=step,\n                opt_state=updated_opt_state,\n            )\n        )\n\n    @staticmethod\n    def new_from_config(\n        model_config: ModelConfig, training_config: TrainingConfig\n    ) -> \"TrainingState\":\n        rngs = nnx.Rngs(params=42)\n        model_state = nnx.state(LczeroModel(config=model_config, rngs=rngs))\n        lr_sched = make_lr_schedule(training_config.lr_schedule)\n        opt_state = make_gradient_transformation(\n            training_config.optimizer,\n            max_grad_norm=getattr(training_config, \"max_grad_norm\", 0.0),\n            lr_schedule=lr_sched,\n        ).init(model_state)\n        jit_state = JitTrainingState(\n            step=0,\n            model_state=model_state,\n            opt_state=opt_state,\n            swa_state=model_state,\n            num_averages=0.0,\n        )\n        return TrainingState(\n            jit_state=jit_state,\n            num_heads=model_config.encoder.heads,\n        )\n"
  },
  {
    "path": "src/lczero_training/training/tensorboard.py",
    "content": "\"\"\"Utilities for writing training metrics to TensorBoard event files.\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom collections.abc import Mapping\nfrom typing import Any, Dict\n\nimport jax\nimport numpy as np\nfrom tensorboardX import SummaryWriter\n\nlogger = logging.getLogger(__name__)\n\nMetricsDict = Dict[str, Any]\n\n\ndef _to_ndarray(value: Any) -> np.ndarray:\n    try:\n        return np.asarray(jax.device_get(value))\n    except TypeError:\n        return np.asarray(value)\n\n\ndef _to_scalar(value: Any) -> float | None:\n    array = _to_ndarray(value)\n    if array.ndim == 0 or array.size == 1:\n        return float(array.reshape(()))\n    logger.warning(\n        \"Skipping non-scalar metric with shape %s when logging to TensorBoard.\",\n        array.shape,\n    )\n    return None\n\n\ndef _flatten_metrics(\n    metrics: Mapping[str, Any], prefix: str = \"\"\n) -> Dict[str, float]:\n    scalars: Dict[str, float] = {}\n    for key, value in metrics.items():\n        tag = f\"{prefix}{key}\" if prefix else key\n        if isinstance(value, Mapping):\n            scalars.update(_flatten_metrics(value, f\"{tag}/\"))\n            continue\n        scalar = _to_scalar(value)\n        if scalar is not None:\n            scalars[tag] = scalar\n    return scalars\n\n\ndef _to_step(step: Any) -> int:\n    return int(_to_ndarray(step).reshape(()))\n\n\nclass TensorboardLogger:\n    \"\"\"Writes scalar training metrics to TensorBoard.\"\"\"\n\n    def __init__(self, logdir: str) -> None:\n        self._writer = SummaryWriter(logdir)\n\n    def log(self, step: int, metrics: MetricsDict) -> None:\n        global_step = _to_step(step)\n        for tag, value in _flatten_metrics(metrics).items():\n            self._writer.add_scalar(\n                tag=tag, scalar_value=value, global_step=global_step\n            )\n\n    def close(self) -> None:\n        self._writer.close()\n"
  },
  {
    "path": "src/lczero_training/training/test_lr_schedule.py",
    "content": "from typing import Callable, List\n\nimport jax.numpy as jnp\nimport pytest\n\nfrom lczero_training.training.lr_schedule import make_lr_schedule\nfrom proto import training_config_pb2 as pb\n\n\ndef _sched(\n    schedules: List[pb.LrSchedule],\n) -> Callable[[jnp.ndarray], jnp.ndarray]:\n    return make_lr_schedule(schedules)\n\n\ndef _val(s: Callable[[jnp.ndarray], jnp.ndarray], t: int | float) -> float:\n    return float(s(jnp.asarray(t, dtype=jnp.float32)))\n\n\ndef test_rule_selection_by_starting_step() -> None:\n    r0 = pb.LrSchedule(\n        starting_step=0,\n        duration_steps=[5],\n        lr=[0.1, 0.2],\n    )\n    r1 = pb.LrSchedule(\n        starting_step=10,\n        duration_steps=[5],\n        lr=[0.5, 0.5],\n    )\n    sched = _sched([r0, r1])\n    assert _val(sched, 0) == pytest.approx(0.1)\n    assert _val(sched, 9) == pytest.approx(0.2)\n    assert _val(sched, 10) == pytest.approx(0.5)\n    assert _val(sched, 100) == pytest.approx(0.5)\n\n\ndef test_default_constant_transition_and_tail() -> None:\n    r = pb.LrSchedule(\n        starting_step=0,\n        duration_steps=[5],\n        lr=[0.3, 0.8],  # no transition specified -> CONSTANT\n    )\n    sched = _sched([r])\n    for t in range(5):\n        assert _val(sched, t) == pytest.approx(0.3)\n    # Beyond period (no loop) yields last lr\n    assert _val(sched, 6) == pytest.approx(0.8)\n\n\ndef test_linear_then_hold() -> None:\n    r = pb.LrSchedule(\n        starting_step=0,\n        duration_steps=[3, 7],\n        lr=[0.0, 0.9, 0.9],\n        transition=[pb.LrSchedule.Transition.LINEAR],\n    )\n    sched = _sched([r])\n    assert _val(sched, 0) == pytest.approx(0.0)\n    assert _val(sched, 1) == pytest.approx(0.3)\n    assert _val(sched, 2) == pytest.approx(0.6)\n    assert _val(sched, 3) == pytest.approx(0.9)\n    assert _val(sched, 8) == pytest.approx(0.9)\n\n\ndef test_looping_constant_segments() -> None:\n    r = pb.LrSchedule(\n        starting_step=0,\n        
duration_steps=[3, 2],\n        lr=[1.0, 2.0, 3.0],\n        loop=True,\n    )\n    sched = _sched([r])\n    assert _val(sched, 0) == pytest.approx(1.0)\n    assert _val(sched, 2) == pytest.approx(1.0)\n    assert _val(sched, 3) == pytest.approx(2.0)\n    assert _val(sched, 5) == pytest.approx(1.0)  # 5 % (3+2) == 0\n\n\ndef test_zero_duration_is_skipped() -> None:\n    r = pb.LrSchedule(\n        starting_step=0,\n        duration_steps=[0, 5],\n        lr=[1.0, 2.0, 3.0],\n        transition=[\n            pb.LrSchedule.Transition.LINEAR,\n            pb.LrSchedule.Transition.LINEAR,\n        ],\n    )\n    sched = _sched([r])\n    # Interpolates over the second interval [2.0 -> 3.0]\n    assert _val(sched, 0) == pytest.approx(2.0)\n    assert _val(sched, 2) == pytest.approx(2.4)\n    assert _val(sched, 5) == pytest.approx(3.0)\n\n\ndef test_chain_zero_durations_then_linear() -> None:\n    r = pb.LrSchedule(\n        starting_step=0,\n        duration_steps=[0, 0, 0, 5],\n        lr=[1.0, 2.0, 3.0, 4.0, 5.0],\n        transition=[\n            pb.LrSchedule.Transition.LINEAR,\n            pb.LrSchedule.Transition.LINEAR,\n            pb.LrSchedule.Transition.LINEAR,\n            pb.LrSchedule.Transition.LINEAR,\n        ],\n    )\n    sched = _sched([r])\n    # Should use last interval [4.0 -> 5.0] linearly across 5 steps\n    assert _val(sched, 0) == pytest.approx(4.0)\n    assert _val(sched, 2) == pytest.approx(4.4)\n    assert _val(sched, 4) == pytest.approx(4.8)\n\n\ndef test_cosine() -> None:\n    r = pb.LrSchedule(\n        starting_step=0,\n        duration_steps=[4],\n        lr=[0.0, 1.0],\n        transition=[pb.LrSchedule.Transition.COSINE],\n    )\n    sched = _sched([r])\n    # t=0 -> 0.0, t=0.5 -> 0.5, t=1.0 -> 1.0 approximately\n    assert _val(sched, 0) == pytest.approx(0.0)\n    assert _val(sched, 2) == pytest.approx(0.5, abs=1e-6)\n    assert _val(sched, 4) == pytest.approx(1.0)\n\n\ndef test_before_first_rule_uses_earliest_first_lr() -> None:\n 
   r = pb.LrSchedule(\n        starting_step=5,\n        duration_steps=[3],\n        lr=[0.1, 0.2],\n    )\n    sched = _sched([r])\n    # Before the first rule starts, schedule returns the first lr of earliest rule.\n    assert _val(sched, 0) == pytest.approx(0.1)\n"
  },
  {
    "path": "src/lczero_training/training/training.py",
    "content": "import dataclasses\nimport logging\nfrom datetime import datetime\nfrom functools import partial\nfrom typing import Any, Callable, Dict, Generator, Optional, Tuple, cast\n\nimport jax\nimport jax.numpy as jnp\nimport jax.sharding as jshard\nimport numpy as np\nimport optax\nfrom flax import nnx\nfrom jax import tree_util\nfrom jax.sharding import PartitionSpec as P\n\nfrom lczero_training.dataloader import DataLoader\nfrom lczero_training.model.loss_function import LczeroLoss\nfrom lczero_training.model.model import LczeroModel\nfrom lczero_training.training.state import (\n    JitTrainingState,\n    TrainingBatch,\n    TrainingSample,\n)\nfrom proto import training_config_pb2 as training_config_pb2\n\nMetricsDict = Dict[str, Any]\n\n\n@dataclasses.dataclass\nclass StepHookData:\n    \"\"\"Data passed to the step hook callback during training.\"\"\"\n\n    global_step: int\n    local_step: int\n    steps_per_epoch: int\n    metrics: MetricsDict\n    jit_state: JitTrainingState\n\n\nStepHook = Callable[[StepHookData], None]\n\nlogger = logging.getLogger(__name__)\n\n\ndef from_dataloader(\n    loader: DataLoader,\n) -> Generator[tuple[np.ndarray, ...], None, None]:\n    while True:\n        yield loader.get_next()\n\n\nclass Training:\n    optimizer_tx: optax.GradientTransformation\n    train_step: Callable[\n        [optax.GradientTransformation, JitTrainingState, TrainingBatch],\n        Tuple[JitTrainingState, MetricsDict],\n    ]\n    _swa_config: Optional[training_config_pb2.SWAConfig]\n    _dp_sharding: Optional[jshard.NamedSharding]\n\n    def __init__(\n        self,\n        optimizer_tx: optax.GradientTransformation,\n        graphdef: nnx.GraphDef,\n        loss_fn: LczeroLoss,\n        swa_config: Optional[training_config_pb2.SWAConfig] = None,\n    ):\n        self.optimizer_tx = optimizer_tx\n        self._swa_config = swa_config\n        self._dp_sharding = None\n\n        jit_kwargs: Dict[str, Any] = {\n            
\"static_argnames\": (\"optimizer_tx\",),\n            \"donate_argnames\": (\"jit_state\",),\n        }\n        if jax.device_count() > 1:\n            num_devices = jax.device_count()\n            logger.info(\n                f\"Multi-GPU training enabled: {num_devices} devices detected\"\n            )\n            mesh = jshard.Mesh(jax.devices(), axis_names=(\"batch\",))\n            replicated = jshard.NamedSharding(mesh, P())\n            dp_sharding = jshard.NamedSharding(mesh, P(\"batch\"))\n            self._dp_sharding = dp_sharding\n\n            batch_sharding = TrainingBatch(\n                inputs=dp_sharding,\n                probabilities=dp_sharding,\n                values=dp_sharding,\n            )\n            in_shardings = (replicated, batch_sharding)\n            out_shardings = replicated\n\n            jit_kwargs[\"in_shardings\"] = in_shardings\n            jit_kwargs[\"out_shardings\"] = out_shardings\n\n        @partial(jax.jit, **jit_kwargs)\n        def _step(\n            optimizer_tx: optax.GradientTransformation,\n            jit_state: JitTrainingState,\n            batch: TrainingBatch,\n        ) -> Tuple[JitTrainingState, MetricsDict]:\n            model = nnx.merge(graphdef, jit_state.model_state)\n\n            def loss_for_grad(\n                model_arg: LczeroModel, sample_arg: TrainingSample\n            ) -> Tuple[jax.Array, Dict[str, jax.Array]]:\n                return loss_fn(model_arg, sample_arg)\n\n            loss_vfn = jax.vmap(\n                loss_for_grad,\n                in_axes=(None, 0),  # (model_arg, sample_arg)\n                out_axes=0,\n            )\n\n            def mean_loss_for_grad(\n                model_arg: LczeroModel, batch_arg: TrainingBatch\n            ) -> Tuple[jax.Array, Dict[str, jax.Array]]:\n                # vmap automatically distributes TrainingBatch over batch dimension,\n                # calling loss_for_grad with TrainingSample (single samples).\n                
per_sample_data_loss, unweighted_losses = loss_vfn(\n                    model_arg,\n                    batch_arg,  # type: ignore[arg-type]\n                )\n                mean_loss = jnp.mean(per_sample_data_loss)\n                return mean_loss, unweighted_losses\n\n            grad_fn = nnx.value_and_grad(mean_loss_for_grad, has_aux=True)\n            (mean_loss, unweighted_losses), mean_grads = grad_fn(model, batch)\n            grad_norm = optax.global_norm(mean_grads)\n\n            assert jit_state.opt_state is not None\n            updates, new_opt_state = optimizer_tx.update(\n                mean_grads, jit_state.opt_state, jit_state.model_state\n            )\n            new_model_state = optax.apply_updates(\n                jit_state.model_state, updates\n            )\n\n            new_jit_state = jit_state.replace(\n                step=jit_state.step + 1,\n                model_state=new_model_state,\n                opt_state=new_opt_state,\n            )\n\n            mean_unweighted = tree_util.tree_map(jnp.mean, unweighted_losses)\n            metrics: MetricsDict = {\n                \"loss\": mean_loss,\n                \"unweighted_losses\": mean_unweighted,\n                \"grad_norm\": grad_norm,\n            }\n            return new_jit_state, metrics\n\n        self.train_step = cast(\n            Callable[\n                [optax.GradientTransformation, JitTrainingState, TrainingBatch],\n                Tuple[JitTrainingState, MetricsDict],\n            ],\n            _step,\n        )\n\n    @staticmethod\n    @jax.jit\n    def _swa_tree_map(\n        alpha: jax.Array,\n        beta: jax.Array,\n        swa_state: nnx.State,\n        model_state: nnx.State,\n    ) -> nnx.State:\n        return tree_util.tree_map(\n            lambda a, b: alpha * a + beta * b, swa_state, model_state\n        )\n\n    def update_swa(\n        self, jit_state: JitTrainingState, weight: float\n    ) -> JitTrainingState:\n        \"\"\"Update 
SWA using the provided weight for the current model.\n\n        Assumes `jit_state.swa_state` is initialized and `_swa_config` present.\n        \"\"\"\n        logger.info(\n            \"Updating SWA model, weight=%f, num_averages=%f\",\n            weight,\n            jit_state.num_averages,\n        )\n        assert self._swa_config is not None\n        assert jit_state.swa_state is not None\n        assert weight > 0.0\n        max_num_averages = self._swa_config.num_averages\n        denom = jit_state.num_averages + weight\n        alpha = jit_state.num_averages / denom\n        beta = weight / denom\n        new_swa_state = self._swa_tree_map(\n            jnp.array(alpha),\n            jnp.array(beta),\n            jit_state.swa_state,\n            jit_state.model_state,\n        )\n        new_num_averages = min(\n            max_num_averages, jit_state.num_averages + weight\n        )\n        return jit_state.replace(\n            swa_state=new_swa_state, num_averages=new_num_averages\n        )\n\n    def maybe_update_swa(\n        self,\n        jit_state: JitTrainingState,\n        steps_completed: int,\n        total_steps: int,\n    ) -> JitTrainingState:\n        \"\"\"Optionally update SWA based on configured schedule and epoch progress.\n\n        Returns the original jit_state when no update is scheduled.\n        \"\"\"\n        if self._swa_config is None:\n            return jit_state\n        period_steps = self._swa_config.period_steps\n        assert period_steps > 0\n        if steps_completed % period_steps == 0:\n            return self.update_swa(jit_state, 1.0)\n        if steps_completed == total_steps:\n            remainder = total_steps % period_steps\n            return self.update_swa(jit_state, remainder / period_steps)\n        return jit_state\n\n    def _validate_and_prepare_batch(\n        self, tensor_tuple: tuple[np.ndarray, ...]\n    ) -> TrainingBatch:\n        logger.info(\"Fetched batch from dataloader\")\n\n        
# Convert tuple to TrainingBatch\n        batch = TrainingBatch.from_tuple(tensor_tuple)\n\n        # Ensure batch.inputs is jax.Array for shape access\n        assert isinstance(batch.inputs, jax.Array)\n        batch_size = batch.inputs.shape[0]\n        if self._dp_sharding is not None:\n            num_devices = jax.device_count()\n            if batch_size % num_devices != 0:\n                raise ValueError(\n                    f\"Batch size {batch_size} must be divisible by device \"\n                    f\"count {num_devices} for multi-GPU training. \"\n                    f\"Per-device batch size would be \"\n                    f\"{batch_size / num_devices:.2f}\"\n                )\n            per_device_batch_size = batch_size // num_devices\n            logger.info(\n                f\"Multi-GPU batch: {batch_size} total \"\n                f\"({per_device_batch_size} per device)\"\n            )\n            batch = jax.device_put(batch, self._dp_sharding)\n\n        return batch\n\n    def _log_step_metrics(\n        self,\n        step_value: int,\n        local_step: int,\n        num_steps: int,\n        metrics: MetricsDict,\n    ) -> None:\n        loss = float(metrics[\"loss\"])\n        unweighted_losses = {\n            k: float(v) for k, v in metrics[\"unweighted_losses\"].items()\n        }\n        grad_norm = float(metrics[\"grad_norm\"])\n        logger.info(\n            f\"Step {step_value} ({local_step}/{num_steps}), Loss: {loss}, \"\n            f\"Unweighted losses: {unweighted_losses}, Grad norm: {grad_norm}\"\n        )\n\n    def _execute_step_hook(\n        self,\n        step_hook: Optional[StepHook],\n        step_value: int,\n        local_step: int,\n        num_steps: int,\n        metrics: MetricsDict,\n        jit_state: JitTrainingState,\n    ) -> None:\n        if step_hook is None:\n            return\n        hook_data = StepHookData(\n            
global_step=step_value,\n            local_step=local_step,\n            steps_per_epoch=num_steps,\n            metrics=metrics,\n            jit_state=jit_state,\n        )\n        step_hook(hook_data)\n\n    def run(\n        self,\n        jit_state: JitTrainingState,\n        datagen: Generator[tuple[np.ndarray, ...], None, None],\n        num_steps: int,\n        step_hook: Optional[StepHook] = None,\n        memory_profile_dir: Optional[str] = None,\n    ) -> JitTrainingState:\n        assert jit_state.opt_state is not None\n        if self._dp_sharding is not None:\n            replicated = jshard.NamedSharding(self._dp_sharding.mesh, P())\n            jit_state = jax.device_put(jit_state, replicated)\n        for local_step in range(num_steps):\n            logger.info(f\"Starting step {jit_state.step}\")\n            if memory_profile_dir is not None:\n                jax.profiler.save_device_memory_profile(\n                    f\"{memory_profile_dir}/\"\n                    f\"{datetime.now().strftime('%Y%m%d-%H%M%S')}\"\n                    f\"_before_{int(jit_state.step)}.prof\"\n                )\n            batch = self._validate_and_prepare_batch(next(datagen))\n            jit_state, metrics = self.train_step(\n                self.optimizer_tx, jit_state, batch\n            )\n            step_value = int(\n                np.asarray(jax.device_get(jit_state.step)).reshape(())\n            )\n            jit_state = self.maybe_update_swa(\n                jit_state, local_step + 1, num_steps\n            )\n            self._execute_step_hook(\n                step_hook, step_value, local_step, num_steps, metrics, jit_state\n            )\n            self._log_step_metrics(step_value, local_step, num_steps, metrics)\n        return jit_state\n"
  },
  {
    "path": "src/lczero_training/training/tune_lr.py",
    "content": "import csv\nimport logging\nimport sys\nfrom contextlib import nullcontext\nfrom functools import partial\nfrom typing import Callable, Dict, List, Tuple, cast\n\nimport jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\nimport optax\nimport orbax.checkpoint as ocp\nfrom flax import nnx\nfrom google.protobuf import text_format\nfrom jax import tree_util\n\nfrom lczero_training.dataloader import make_dataloader\nfrom lczero_training.model.loss_function import LczeroLoss\nfrom lczero_training.model.model import LczeroModel\nfrom lczero_training.training.state import TrainingState\nfrom proto.root_config_pb2 import RootConfig\n\nfrom .training import Training, from_dataloader\n\nlogger = logging.getLogger(__name__)\n\n\ndef _prepare_batch(batch_tuple: tuple) -> Dict:\n    # DataLoader now returns tuple: (inputs, probabilities, values)\n    return {\n        \"inputs\": batch_tuple[0],\n        \"probabilities\": batch_tuple[1],\n        \"values\": batch_tuple[2],\n    }\n\n\ndef _make_optimizer_with_schedule(\n    training_state: TrainingState,\n    config: RootConfig,\n    schedule: optax.Schedule,\n) -> optax.GradientTransformation:\n    max_grad_norm = getattr(config.training, \"max_grad_norm\", 0.0)\n    opt_config = config.training.optimizer\n\n    if opt_config.HasField(\"nadamw\"):\n        conf = opt_config.nadamw\n        tx: optax.GradientTransformation = optax.nadamw(\n            schedule,\n            b1=conf.beta_1,\n            b2=conf.beta_2,\n            eps=conf.epsilon,\n            weight_decay=conf.weight_decay,\n        )\n    else:\n        raise ValueError(\n            f\"Unsupported optimizer type: {opt_config.WhichOneof('optimizer_type')}\"\n        )\n\n    if max_grad_norm > 0:\n        tx = optax.chain(optax.clip_by_global_norm(max_grad_norm), tx)\n\n    if training_state.jit_state.opt_state is None:\n        raise ValueError(\"Optimizer state must be available in the checkpoint.\")\n\n    return tx\n\n\ndef 
_make_eval_step(\n    graphdef: nnx.GraphDef, loss_fn: LczeroLoss\n) -> Callable[[nnx.State, Dict], jax.Array]:\n    @partial(nnx.jit, static_argnames=())\n    def eval_step(model_state: nnx.State, batch: Dict) -> jax.Array:\n        model = nnx.merge(graphdef, model_state)\n\n        def calculate_loss(\n            model_arg: LczeroModel, batch_arg: Dict\n        ) -> Tuple[jax.Array, Dict[str, jax.Array]]:\n            return loss_fn(model_arg, **batch_arg)\n\n        loss_vfn = jax.vmap(calculate_loss, in_axes=(None, 0), out_axes=0)\n        per_sample_data_loss, _ = loss_vfn(model, batch)\n        return jnp.mean(per_sample_data_loss)\n\n    return cast(Callable[[nnx.State, Dict], jax.Array], eval_step)\n\n\ndef _plot_results(results: List[Tuple[float, float]], plot_output: str) -> None:\n    logger.info(\"Saving plot to %s\", plot_output)\n    lrs, losses = zip(*results)\n    plt.figure(figsize=(10, 6))\n    plt.plot(lrs, losses, marker=\"o\")\n    plt.xscale(\"log\")\n    plt.xlabel(\"Learning Rate\")\n    plt.ylabel(\"Loss\")\n    plt.title(\"Learning Rate Finder\")\n    plt.grid(True, which=\"both\", linestyle=\"--\")\n    plt.tight_layout()\n    plt.savefig(plot_output, dpi=150)\n    plt.close()\n\n\ndef tune_lr(\n    *,\n    config_filename: str,\n    start_lr: float,\n    num_steps: int,\n    multiplier: float = 1.01,\n    warmup_steps: int = 0,\n    warmup_lr: float | None = None,\n    csv_output: str | None = None,\n    plot_output: str | None = None,\n    num_test_batches: int = 0,\n) -> None:\n    if num_steps <= 0 or start_lr <= 0 or multiplier <= 0 or warmup_steps < 0:\n        logger.error(\n            \"num_steps, start_lr, and multiplier must be positive, \"\n            \"and warmup_steps non-negative.\"\n        )\n        sys.exit(1)\n    if warmup_steps > 0 and (warmup_lr is None or warmup_lr <= 0):\n        logger.error(\"warmup_lr must be a positive value when warmup_steps > 0\")\n        sys.exit(1)\n\n    config = RootConfig()\n    
logger.info(\"Reading configuration from %s\", config_filename)\n    with open(config_filename, \"r\") as f:\n        text_format.Parse(f.read(), config)\n\n    if not config.training.checkpoint.path:\n        logger.error(\"Checkpoint path must be set in the configuration.\")\n        sys.exit(1)\n\n    checkpoint_mgr = ocp.CheckpointManager(\n        config.training.checkpoint.path,\n        options=ocp.CheckpointManagerOptions(create=False),\n    )\n\n    logger.info(\"Creating state from configuration\")\n    empty_state = TrainingState.new_from_config(\n        model_config=config.model, training_config=config.training\n    )\n\n    logger.info(\"Restoring checkpoint from %s\", config.training.checkpoint.path)\n    training_state = checkpoint_mgr.restore(\n        checkpoint_mgr.latest_step(), args=ocp.args.PyTreeRestore(empty_state)\n    )\n    if training_state is None:\n        logger.error(\"No checkpoint found.\")\n        sys.exit(1)\n\n    assert isinstance(training_state, TrainingState)\n    logger.info(\"Restored checkpoint at step %d\", training_state.jit_state.step)\n\n    model, _ = nnx.split(\n        LczeroModel(config=config.model, rngs=nnx.Rngs(params=42))\n    )\n\n    datagen = from_dataloader(make_dataloader(config.data_loader))\n\n    # Prepare fixed validation batches only if requested (num_test_batches > 0).\n    use_validation = num_test_batches > 0\n    if use_validation:\n        logger.info(\"Fetching %d validation batches\", num_test_batches)\n        validation_batches = [\n            tree_util.tree_map(jnp.asarray, _prepare_batch(next(datagen)))\n            for _ in range(num_test_batches)\n        ]\n    else:\n        validation_batches = []\n\n    loss_fn = LczeroLoss(config=config.training.losses)\n    eval_step = _make_eval_step(model, loss_fn)\n\n    def avg_val_loss() -> float:\n        assert use_validation\n        total_loss = 0.0\n        for vb in validation_batches:\n            total_loss += float(\n                
eval_step(training_state.jit_state.model_state, vb)\n            )\n        return total_loss / float(num_test_batches)\n\n    def train_one_step(\n        training: Training, tx: optax.GradientTransformation\n    ) -> float:\n        nonlocal training_state\n        batch = tree_util.tree_map(jnp.asarray, _prepare_batch(next(datagen)))\n        new_jit_state, metrics = training.train_step(\n            tx, training_state.jit_state, batch\n        )\n        training_state = training_state.replace(jit_state=new_jit_state)\n        return float(metrics[\"loss\"])  # training batch loss\n\n    def run_phase(\n        *,\n        steps: int,\n        schedule: optax.Schedule,\n        lr_at: Callable[[int], float],\n        label: str,\n        on_result: Callable[[float, float, float | None], None],\n    ) -> None:\n        start_step = training_state.jit_state.step\n\n        def offset_schedule(count: jax.Array) -> jax.Array:\n            return schedule(count - start_step)\n\n        tx = _make_optimizer_with_schedule(\n            training_state, config, offset_schedule\n        )\n        training = Training(optimizer_tx=tx, graphdef=model, loss_fn=loss_fn)\n        for i in range(steps):\n            current_lr = lr_at(i)\n            logger.info(\n                \"%s step %d/%d at lr %.8f\", label, i + 1, steps, current_lr\n            )\n            train_loss = train_one_step(training, tx)\n            val_loss = avg_val_loss() if use_validation else None\n            on_result(current_lr, train_loss, val_loss)\n            if use_validation:\n                logger.info(\n                    \"%s at lr %.8f: train=%.6f, val=%.6f\",\n                    label,\n                    current_lr,\n                    train_loss,\n                    cast(float, val_loss),\n                )\n            else:\n                logger.info(\n                    \"%s train loss at lr %.8f: %.6f\",\n                    label,\n                    current_lr,\n      
              train_loss,\n                )\n\n    results: List[Tuple[float, float]] = []\n    with (\n        open(csv_output, \"w\", newline=\"\") if csv_output else nullcontext()\n    ) as csv_file:\n        writer = csv.writer(csv_file) if csv_file else None\n        if writer:\n            if use_validation:\n                writer.writerow([\"lr\", \"train_loss\", \"val_loss\"])\n            else:\n                writer.writerow([\"lr\", \"train_loss\"])\n\n        def on_result(\n            lr: float, train_loss: float, val_loss: float | None\n        ) -> None:\n            results.append((lr, train_loss))\n            if writer and csv_file:\n                if use_validation:\n                    writer.writerow([lr, train_loss, val_loss])\n                else:\n                    writer.writerow([lr, train_loss])\n                csv_file.flush()\n\n        phases = []\n        if warmup_steps > 0 and warmup_lr is not None:\n            phases.append(\n                {\n                    \"label\": \"Warmup\",\n                    \"steps\": warmup_steps,\n                    \"schedule\": optax.constant_schedule(warmup_lr),\n                    \"lr_at\": lambda _: float(warmup_lr),\n                }\n            )\n        phases.append(\n            {\n                \"label\": \"Sweep\",\n                \"steps\": num_steps,\n                \"schedule\": optax.exponential_decay(\n                    start_lr, transition_steps=1, decay_rate=multiplier\n                ),\n                \"lr_at\": lambda i: start_lr * (multiplier**i),\n            }\n        )\n\n        for phase_params in phases:\n            run_phase(**phase_params, on_result=on_result)\n\n    if plot_output:\n        _plot_results(results, plot_output)\n"
  },
  {
    "path": "src/lczero_training/training/utils.py",
    "content": "from pathlib import PurePosixPath\n\nfrom flax import nnx\n\nfrom proto.training_config_pb2 import WeightsSelector\n\n\ndef make_weights_mask(\n    selector: WeightsSelector, params: nnx.State\n) -> nnx.State:\n    \"\"\"Creates a boolean mask based on WeightsSelector. True = include weight.\"\"\"\n\n    def mask_fn(path: tuple[object, ...], _variable: nnx.Variable) -> bool:\n        p = PurePosixPath(*map(str, path))\n        for rule in selector.rule:\n            if p.full_match(rule.match):\n                return rule.include\n        return selector.otherwise_include\n\n    return nnx.map_state(mask_fn, params)\n"
  },
  {
    "path": "src/lczero_training/tui/__init__.py",
    "content": "# ABOUTME: TUI package initialization for the training dashboard.\n# ABOUTME: Exports main TrainingTuiApp class for external use.\n\nfrom .app import TrainingTuiApp\n\n__all__ = [\"TrainingTuiApp\"]\n"
  },
  {
    "path": "src/lczero_training/tui/app.py",
    "content": "# ABOUTME: Main TUI application class implementing the training dashboard.\n# ABOUTME: Uses Textual framework to create a full-screen interface with four panes.\n\nimport argparse\nimport logging\nimport os\nimport signal\nimport subprocess\nimport sys\nfrom datetime import datetime\nfrom typing import Iterable, Optional\n\nimport anyio\nfrom anyio.streams.text import TextReceiveStream, TextSendStream\nfrom textual.app import App, ComposeResult, SystemCommand\nfrom textual.containers import Horizontal\nfrom textual.screen import Screen\nfrom textual.widgets import Footer, Static\n\nfrom ..daemon.protocol.communicator import AsyncCommunicator\nfrom ..daemon.protocol.messages import (\n    StartTrainingImmediatelyPayload,\n    StartTrainingPayload,\n    TrainingStatusPayload,\n)\nfrom .data_pipeline_pane import DataPipelinePane\nfrom .log_pane import StreamingLogPane\nfrom .training_widgets import TrainingScheduleWidget\n\nlogger = logging.getLogger(__name__)\n\n\nclass HeaderBar(Static):\n    \"\"\"Empty header bar.\"\"\"\n\n    def compose(self) -> ComposeResult:\n        return\n        yield  # unreachable, but makes the function a generator\n\n\nclass JAXTrainingPane(Static):\n    \"\"\"Right pane showing JAX training status and metrics.\"\"\"\n\n    def compose(self) -> ComposeResult:\n        yield Static(\n            \"JAX Training Status\\n\\n\"\n            \"Live training metrics will be displayed here when active:\\n\"\n            \"• Epoch Progress\\n• Performance Metrics\\n• Loss Values\",\n            classes=\"jax-training-content\",\n        )\n\n\nclass TrainingTuiApp(App):\n    \"\"\"Main TUI application for the training dashboard.\n\n    This creates a full-screen interface with four main panes:\n    - Header bar with uptime and status\n    - Data pipeline pane (main/left)\n    - Training status pane (right)\n    - Log pane (bottom)\n    \"\"\"\n\n    CSS_PATH = \"app.tcss\"\n\n    @staticmethod\n    def add_arguments(parser: 
argparse.ArgumentParser) -> None:\n        \"\"\"Adds all required command-line arguments to the given parser.\"\"\"\n        parser.add_argument(\n            \"--config\",\n            required=True,\n            help=\"Path to the training configuration file\",\n        )\n        parser.add_argument(\n            \"--logfile\",\n            help=\"Path to the log file for saving TUI output\",\n        )\n        parser.add_argument(\n            \"--io-dump\",\n            help=\"Path to file for dumping raw daemon IO for debugging\",\n        )\n        parser.add_argument(\n            \"--daemon-flag\",\n            action=\"append\",\n            default=[],\n            dest=\"daemon_flags\",\n            help=\"Extra argument to pass to the daemon (repeatable)\",\n        )\n\n    _log_stream: TextReceiveStream\n    _daemon_process: anyio.abc.Process\n    _communicator: AsyncCommunicator\n    _config_file: str\n    _logfile: Optional[str]\n    _data_pipeline_pane: DataPipelinePane\n    _training_schedule_widget: TrainingScheduleWidget\n\n    BINDINGS = [\n        (\"q\", \"quit\", \"Quit\"),\n        (\"ctrl+c\", \"quit\", \"Quit\"),\n    ]\n\n    def __init__(self, args: Optional[argparse.Namespace] = None) -> None:\n        \"\"\"\n        Initializes the app.\n        If 'args' is provided, it's used directly.\n        If 'args' is None, fallback to parsing sys.argv.\n        \"\"\"\n        super().__init__()\n\n        if args is None:\n            # Fallback for when run by \"textual run\"\n            parser = argparse.ArgumentParser()\n            TrainingTuiApp.add_arguments(parser)\n            args, _ = parser.parse_known_args()\n\n        # Consume configuration from the args object\n        self._config_file: str = args.config\n        self._logfile: Optional[str] = args.logfile\n        self._io_dump_file: Optional[str] = args.io_dump\n        self._daemon_flags: list[str] = args.daemon_flags\n\n    async def on_load(self) -> None:\n        
\"\"\"Start the daemon process and communicator when the app loads.\"\"\"\n        # Create the daemon process via Python module execution to avoid PATH reliance.\n        env = None\n        if \"TF_CPP_MIN_LOG_LEVEL\" not in os.environ:\n            env = {**os.environ, \"TF_CPP_MIN_LOG_LEVEL\": \"0\"}\n\n        self._daemon_process = await anyio.open_process(\n            [\n                sys.executable,\n                \"-m\",\n                \"lczero_training.commands.daemon\",\n                *self._daemon_flags,\n            ],\n            stdin=subprocess.PIPE,\n            stdout=subprocess.PIPE,\n            stderr=subprocess.PIPE,\n            env=env,\n        )\n\n        assert self._daemon_process.stderr is not None\n        assert self._daemon_process.stdin is not None\n        assert self._daemon_process.stdout is not None\n\n        # Set up streams and communicator\n        self._log_stream = TextReceiveStream(self._daemon_process.stderr)\n        io_dump = None\n        if self._io_dump_file:\n            io_dump = open(self._io_dump_file, \"a\", buffering=1)\n            ts = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n            io_dump.write(f\"======= {ts} =======\\n\")\n        self._communicator = AsyncCommunicator(\n            handler=self,\n            input_stream=TextReceiveStream(self._daemon_process.stdout),\n            output_stream=TextSendStream(self._daemon_process.stdin),\n            io_dump=io_dump,\n        )\n\n    def compose(self) -> ComposeResult:\n        \"\"\"Compose the main UI layout.\"\"\"\n        yield HeaderBar()\n\n        self._data_pipeline_pane = DataPipelinePane()\n        self._data_pipeline_pane.border_title = \"Training data pipeline\"\n        yield self._data_pipeline_pane\n\n        # Horizontal split below the data pipeline pane\n        with Horizontal(id=\"training-status-container\"):\n            self._training_schedule_widget = TrainingScheduleWidget()\n            
self._training_schedule_widget.border_title = \"Training Schedule\"\n            yield self._training_schedule_widget\n\n            jax_training_pane = JAXTrainingPane()\n            jax_training_pane.border_title = \"JAX Training Status\"\n            yield jax_training_pane\n\n        yield StreamingLogPane(\n            stream=self._log_stream, logfile_path=self._logfile\n        )\n        yield Footer()\n\n    async def _monitor_daemon_process(self) -> None:\n        \"\"\"Notify and log when the daemon process exits unexpectedly.\"\"\"\n        await self._daemon_process.wait()\n        rc = self._daemon_process.returncode\n        sig = -rc if rc is not None and rc < 0 else None\n        if sig is not None:\n            msg = f\"Daemon killed by signal {sig} ({signal.Signals(sig).name})\"\n        else:\n            msg = f\"Daemon exited with code {rc}\"\n        logger.warning(msg)\n        self.notify(msg, severity=\"warning\", timeout=60)\n        if self._logfile:\n            with open(self._logfile, \"a\") as f:\n                f.write(f\"{msg}\\n\")\n\n    def on_mount(self) -> None:\n        \"\"\"Start the communicator when the app mounts.\"\"\"\n        self.run_worker(self._communicator.run(), exclusive=True)\n        self.run_worker(self._send_start_training(), exclusive=False)\n        self.run_worker(self._monitor_daemon_process(), exclusive=False)\n\n    async def _send_start_training(self) -> None:\n        \"\"\"Send StartTrainingPayload with the config file.\"\"\"\n        payload = StartTrainingPayload(config_filepath=self._config_file)\n        await self._communicator.send(payload)\n\n    async def _command_start_training_immediately(self) -> None:\n        \"\"\"Trigger immediate training without waiting for additional chunks.\"\"\"\n\n        payload = StartTrainingImmediatelyPayload()\n        await self._communicator.send(payload)\n        self.notify(\"Requested immediate training start.\")\n\n    async def action_quit(self) -> 
None:  # type: ignore\n        \"\"\"Handle quit action.\"\"\"\n        self._daemon_process.send_signal(signal.SIGINT)\n        with anyio.move_on_after(20) as scope:\n            await self._daemon_process.wait()\n        if scope.cancelled_caught:\n            self._daemon_process.terminate()\n        self.exit()\n\n    def get_system_commands(self, screen: Screen) -> Iterable[SystemCommand]:\n        \"\"\"Add application-specific commands to the command palette.\"\"\"\n\n        yield from super().get_system_commands(screen)\n        yield SystemCommand(\n            \"Start training immediately\",\n            \"Begin the next training cycle without waiting for chunks.\",\n            self._command_start_training_immediately,\n        )\n\n    async def on_training_status(self, payload: TrainingStatusPayload) -> None:\n        \"\"\"Handle training status updates.\"\"\"\n        self._data_pipeline_pane.update_metrics(\n            payload.dataloader_1_second, payload.dataloader_total\n        )\n\n        # Update training schedule widget\n        self._training_schedule_widget.update_training_schedule(\n            payload.training_schedule\n        )\n"
  },
  {
    "path": "src/lczero_training/tui/app.tcss",
    "content": "Screen {\n    background: $primary;\n}\n\nHeaderBar {\n    dock: top;\n    height: 1;\n    background: $primary-darken-1;\n    color: $text;\n}\n\nDataPipelinePane {\n    background: $surface;\n    color: $text;\n    border: none;\n    margin: 1;\n    padding: 0 1;\n    layout: vertical;\n    overflow-y: auto;\n    height: 27;\n}\n\n#training-status-container {\n    height: 12;\n    layout: horizontal;\n}\n\nTrainingScheduleWidget {\n    background: $surface;\n    color: $text;\n    border: solid $primary;\n    margin: 1;\n    width: 1fr;\n}\n\nJAXTrainingPane {\n    background: $surface;\n    color: $text;\n    border: solid $primary;\n    margin: 1;\n    width: 1fr;\n}\n\nRichLog {\n    height: 1fr;\n    background: $panel;\n    color: $text;\n    margin: 0;\n    width: 100%;\n    overflow-y: scroll;\n}\n\n.stage-content, .queue-content {\n    text-align: left;\n}\n\n.dataloader-row {\n    layout: grid;\n    grid-size: 6 8;\n    grid-columns: 1fr 1fr 1fr 1fr 1fr 1fr;\n    grid-rows: 1;\n    grid-gutter: 0;\n    border: none;\n    padding: 0;\n    margin: 0;\n    height: auto;\n    min-height: 1;\n    width: 100%;\n}\n\n.stage-row {\n    background: $surface;\n    color: $text;\n    grid-size: 5 8;\n    grid-columns: 1fr 1fr 1fr 1fr 1fr;\n}\n\n.queue-row {\n    background: $surface-darken-2;\n    color: $text;\n    grid-size: 6 1;\n}\n\n.statistics-row {\n    background: $surface-darken-1;\n    color: $text;\n    grid-size: 1 1;\n}\n\n.statistics-full {\n    background: $surface-darken-1;\n    color: $text;\n    padding: 0 1;\n    margin: 0;\n    height: auto;\n    min-height: 1;\n    width: 100%;\n}\n\n.row-label {\n    padding: 0 1 0 0;\n    margin: 0;\n    height: 1;\n    max-height: 1;\n    text-style: bold;\n    color: $text;\n    content-align: left middle;\n    text-wrap: nowrap;\n    overflow: hidden;\n    text-overflow: ellipsis;\n}\n\n.grid-spacer {\n    width: 1;\n    height: 1;\n    max-height: 1;\n    background: 
transparent;\n}\n\n.metric-chip {\n    background: $surface-darken-1;\n    color: $text;\n    padding: 0 1;\n    margin: 0 1 0 0;\n    height: 1;\n    max-height: 1;\n    border: none;\n    content-align: center middle;\n    width: auto;\n    min-width: 10;\n    text-wrap: nowrap;\n    overflow: hidden;\n}\n\n.load-chip {\n    background: $primary-darken-2;\n    color: $text;\n}\n\n.info-chip {\n    background: $surface-darken-1;\n    color: $text;\n}\n\n.warning-chip {\n    background: $warning;\n    color: $text;\n}\n\n.queue-name-chip {\n    background: transparent;\n    color: $text-muted;\n    padding: 0;\n    margin-right: 1;\n    height: 1;\n    max-height: 1;\n    text-wrap: nowrap;\n    overflow: hidden;\n}\n\n.queue-fill {\n    height: 1;\n    max-height: 1;\n    width: 1fr;\n    margin-right: 1;\n}\n\n.queue-fill-text {\n    background: transparent;\n    color: $text;\n    margin: 0;\n    padding: 0;\n    height: 1;\n    max-height: 1;\n    text-wrap: nowrap;\n    overflow: hidden;\n    width: auto;\n}\n\n.queue-rate--zero {\n    color: $error;\n}\n\n.queue-rate, .queue-total {\n    text-wrap: nowrap;\n}\n\n\n.jax-training-content {\n    padding: 1;\n}\n\n/* Training widgets styles */\n.time-label {\n    width: auto;\n    padding: 0;\n    margin-right: 1;\n    text-align: left;\n}\n\n.time-progress-bar {\n    width: 1fr;\n    height: 1;\n}\n\n.time-progress-bar .bar--indeterminate {\n    background: $panel;\n}\n\n.time-progress-bar .bar--bar {\n    background: $panel;\n}\n\n.time-ratio {\n    width: auto;\n    padding: 0;\n    text-align: right;\n}\n\nTimeProgressWidget {\n    height: 1;\n    layout: horizontal;\n}\n\nChunksProgressWidget {\n    height: 1;\n    layout: horizontal;\n}\n\n#uptime-stage-display, #epochs-display {\n    height: 1;\n}\n\n.log-content {\n    padding: 0;\n}\n"
  },
  {
    "path": "src/lczero_training/tui/data_pipeline_pane.py",
    "content": "# ABOUTME: Data pipeline pane widget for displaying DataLoader metrics.\n# ABOUTME: Shows a grid of pipeline stages and queues with their metrics.\n\nfrom collections.abc import Iterable\nfrom typing import Any\n\nfrom textual.app import ComposeResult\nfrom textual.containers import Container\n\nimport proto.training_metrics_pb2 as training_metrics_pb2\n\nfrom .dataloader_widgets import QueueWidget, StageWidget, StatisticsRowWidget\n\nFRIENDLY_STAGE_NAMES = {\n    \"file_path_provider\": \"File discovery\",\n    \"chunk_source_loader\": \"Chunk source loader\",\n    \"shuffling_chunk_pool\": \"Shuffling chunk pool\",\n    \"chunk_rescorer\": \"Chunk rescorer\",\n    \"chunk_splitter\": \"Chunk splitter\",\n    \"chunk_unpacker\": \"Chunk unpacker\",\n    \"shuffling_frame_sampler\": \"Shuffling frame sampler\",\n    \"tensor_generator\": \"Batched tensor generator\",\n}\n\n\nclass DataPipelinePane(Container):\n    \"\"\"Main pane showing data pipeline flow and statistics as a grid.\"\"\"\n\n    def __init__(self, **kwargs: Any) -> None:\n        super().__init__(**kwargs)\n        self._stage_widgets: dict[str, StageWidget] = {}\n        self._queue_widgets: dict[str, dict[str, QueueWidget]] = {}\n        self._queue_order: dict[str, list[str]] = {}\n        self._statistics_widgets: dict[str, dict[str, StatisticsRowWidget]] = {}\n        self._statistics_order: dict[str, list[str]] = {}\n        self._stage_order: list[str] = []\n\n    def compose(self) -> ComposeResult:\n        \"\"\"The pane starts empty and rows are added when metrics arrive.\"\"\"\n        yield from ()\n\n    def _friendly_title(self, stage_key: str) -> str:\n        return FRIENDLY_STAGE_NAMES.get(\n            stage_key, stage_key.replace(\"_\", \" \").title()\n        )\n\n    def _ensure_stage_widget(\n        self,\n        stage_key: str,\n    ) -> tuple[StageWidget, bool]:\n        created = False\n        if stage_key not in self._stage_widgets:\n            
stage_widget = StageWidget(\n                stage_key=stage_key,\n                fallback_name=self._friendly_title(stage_key),\n            )\n            self._stage_widgets[stage_key] = stage_widget\n            self._queue_widgets[stage_key] = {}\n            self._queue_order[stage_key] = []\n            self._stage_order.append(stage_key)\n            created = True\n        else:\n            stage_widget = self._stage_widgets[stage_key]\n\n        return stage_widget, created\n\n    def _ensure_queue_widgets(\n        self,\n        stage_key: str,\n        stage_metric: training_metrics_pb2.StageMetricProto,\n    ) -> list[QueueWidget]:\n        stage_queue_widgets = self._queue_widgets.setdefault(stage_key, {})\n        queue_order = self._queue_order.setdefault(stage_key, [])\n        new_widgets: list[QueueWidget] = []\n\n        friendly = self._friendly_title(stage_key)\n        for index, queue_metric in enumerate(stage_metric.queue_metrics):\n            queue_identifier = queue_metric.name or f\"__index__{index}\"\n            if queue_identifier in stage_queue_widgets:\n                continue\n            label_suffix = (\n                f\" {queue_metric.name}\"\n                if queue_metric.name\n                else f\" #{index + 1}\"\n            )\n            queue_widget = QueueWidget(\n                stage_key=stage_key,\n                stage_name=f\"{friendly} queue{label_suffix}\",\n                queue_name=queue_identifier,\n            )\n            stage_queue_widgets[queue_identifier] = queue_widget\n            queue_order.append(queue_identifier)\n            new_widgets.append(queue_widget)\n\n        return new_widgets\n\n    def _ensure_statistics_widgets(\n        self,\n        stage_key: str,\n        stage_metric: training_metrics_pb2.StageMetricProto,\n    ) -> list[StatisticsRowWidget]:\n        stage_statistics_widgets = self._statistics_widgets.setdefault(\n            stage_key, {}\n        )\n        
statistics_order = self._statistics_order.setdefault(stage_key, [])\n        new_widgets: list[StatisticsRowWidget] = []\n\n        for statistics_metric in stage_metric.statistics_metrics:\n            metric_name = statistics_metric.name or \"\"\n            if metric_name in stage_statistics_widgets:\n                continue\n            label = metric_name.replace(\"_\", \" \").title()\n            statistics_widget = StatisticsRowWidget(\n                stage_key=stage_key,\n                metric_name=metric_name,\n                label=label,\n            )\n            stage_statistics_widgets[metric_name] = statistics_widget\n            statistics_order.append(metric_name)\n            new_widgets.append(statistics_widget)\n\n        return new_widgets\n\n    def _mount_widgets(\n        self, widgets: Iterable[StageWidget | QueueWidget | StatisticsRowWidget]\n    ) -> None:\n        async def _do_mount() -> None:\n            await self.mount(*widgets)\n\n        self.call_later(_do_mount)\n\n    def _ensure_rows(\n        self,\n        metrics: training_metrics_pb2.DataLoaderMetricsProto,\n    ) -> None:\n        new_widgets: list[StageWidget | QueueWidget | StatisticsRowWidget] = []\n        for stage_metric in metrics.stage_metrics:\n            stage_key = stage_metric.name\n            if not stage_key:\n                continue\n\n            stage_widget, created = self._ensure_stage_widget(stage_key)\n            if created:\n                new_widgets.append(stage_widget)\n\n            statistics_widgets = self._ensure_statistics_widgets(\n                stage_key, stage_metric\n            )\n            new_widgets.extend(statistics_widgets)\n\n            queue_widgets = self._ensure_queue_widgets(stage_key, stage_metric)\n            new_widgets.extend(queue_widgets)\n\n        if new_widgets:\n            self._mount_widgets(new_widgets)\n\n    def update_metrics(\n        self,\n        dataloader_1_second: 
training_metrics_pb2.DataLoaderMetricsProto | None,\n        dataloader_total: training_metrics_pb2.DataLoaderMetricsProto | None,\n    ) -> None:\n        \"\"\"Update all pipeline stages and queues with new metrics.\"\"\"\n\n        metrics_for_layout = dataloader_total or dataloader_1_second\n        if metrics_for_layout:\n            self._ensure_rows(metrics_for_layout)\n\n        for stage_key in self._stage_order:\n            stage_widget = self._stage_widgets.get(stage_key)\n            if stage_widget:\n                stage_widget.update_metrics(\n                    dataloader_1_second, dataloader_total\n                )\n\n            for statistics_key in self._statistics_order.get(stage_key, []):\n                statistics_widget = self._statistics_widgets[stage_key].get(\n                    statistics_key\n                )\n                if statistics_widget:\n                    statistics_widget.update_metrics(\n                        dataloader_1_second, dataloader_total\n                    )\n\n            for queue_key in self._queue_order.get(stage_key, []):\n                queue_widget = self._queue_widgets[stage_key].get(queue_key)\n                if queue_widget:\n                    queue_widget.update_metrics(\n                        dataloader_1_second, dataloader_total\n                    )\n"
  },
  {
    "path": "src/lczero_training/tui/dataloader_widgets.py",
    "content": "\"\"\"Widgets that render data loader metrics without stage-specific logic.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Any, Dict\n\nfrom textual.app import ComposeResult\nfrom textual.widget import Widget\nfrom textual.widgets import ProgressBar, Static\n\nimport proto.training_metrics_pb2 as training_metrics_pb2\n\n\ndef _find_stage_metric(\n    metrics: training_metrics_pb2.DataLoaderMetricsProto | None,\n    stage_key: str,\n) -> training_metrics_pb2.StageMetricProto | None:\n    if not metrics:\n        return None\n    for stage_metric in metrics.stage_metrics:\n        if stage_metric.name == stage_key:\n            return stage_metric\n    return None\n\n\ndef _collect_metric_names(\n    stage_1s: training_metrics_pb2.StageMetricProto | None,\n    stage_total: training_metrics_pb2.StageMetricProto | None,\n    attribute: str,\n) -> list[str]:\n    names: list[str] = []\n\n    def _add_from(\n        stage_metric: training_metrics_pb2.StageMetricProto | None,\n    ) -> None:\n        if not stage_metric:\n            return\n        for metric in getattr(stage_metric, attribute):\n            name = metric.name if metric.name else \"\"\n            if name not in names:\n                names.append(name)\n\n    _add_from(stage_total)\n    _add_from(stage_1s)\n    return names\n\n\ndef _find_load_metric(\n    stage_metric: training_metrics_pb2.StageMetricProto | None,\n    metric_name: str,\n) -> training_metrics_pb2.LoadMetricProto | None:\n    if not stage_metric:\n        return None\n    for load_metric in stage_metric.load_metrics:\n        if (load_metric.name or \"\") == metric_name:\n            return load_metric\n    return None\n\n\ndef _find_count_metric(\n    stage_metric: training_metrics_pb2.StageMetricProto | None,\n    metric_name: str,\n) -> training_metrics_pb2.CountMetricProto | None:\n    if not stage_metric:\n        return None\n    for count_metric in stage_metric.count_metrics:\n        if 
(count_metric.name or \"\") == metric_name:\n            return count_metric\n    return None\n\n\ndef _find_gauge_metric(\n    stage_metric: training_metrics_pb2.StageMetricProto | None,\n    metric_name: str,\n) -> training_metrics_pb2.GaugeMetricProto | None:\n    if not stage_metric:\n        return None\n    for gauge_metric in stage_metric.gauge_metrics:\n        if (gauge_metric.name or \"\") == metric_name:\n            return gauge_metric\n    return None\n\n\ndef _find_statistics_metric(\n    stage_metric: training_metrics_pb2.StageMetricProto | None,\n    metric_name: str,\n) -> training_metrics_pb2.StatisticsProtoDouble | None:\n    if not stage_metric:\n        return None\n    for stats_metric in stage_metric.statistics_metrics:\n        if (stats_metric.name or \"\") == metric_name:\n            return stats_metric\n    return None\n\n\ndef _get_queue_metric(\n    stage_metric: training_metrics_pb2.StageMetricProto | None,\n    queue_name: str | None,\n) -> training_metrics_pb2.QueueMetricProto | None:\n    if not stage_metric or not stage_metric.queue_metrics:\n        return None\n    if queue_name is None:\n        return stage_metric.queue_metrics[0]\n    if queue_name.startswith(\"__index__\"):\n        try:\n            index = int(queue_name.removeprefix(\"__index__\"))\n        except ValueError:\n            index = -1\n        if 0 <= index < len(stage_metric.queue_metrics):\n            return stage_metric.queue_metrics[index]\n    for queue_metric in stage_metric.queue_metrics:\n        if (queue_metric.name or \"\") == queue_name:\n            return queue_metric\n    return None\n\n\ndef format_si(value: int, precision: int = 1) -> str:\n    if value == 0:\n        return \"0\"\n    units = [\n        (1_000_000_000_000, \"T\"),\n        (1_000_000_000, \"G\"),\n        (1_000_000, \"M\"),\n        (1_000, \"k\"),\n    ]\n    for threshold, unit in units:\n        if value >= threshold:\n            result = value / threshold\n          
  if precision == 0:\n                return f\"{int(result)}{unit}\"\n            return f\"{result:.{precision}f}{unit}\".rstrip(\"0\").rstrip(\".\")\n    return str(value)\n\n\ndef format_full_number(value: int) -> str:\n    if value < 10_000:\n        return str(value)\n    return f\"{value:_}\".replace(\"_\", \"'\")\n\n\ndef _format_load(\n    load_metric: training_metrics_pb2.LoadMetricProto | None,\n    label: str,\n) -> str:\n    if not load_metric:\n        return f\"{label} --\"\n    total_part = (\n        f\"{load_metric.total_seconds:.0f}\"\n        if load_metric.total_seconds > 0\n        else \"--\"\n    )\n    return f\"{label} {load_metric.load_seconds:.1f}/{total_part}s\"\n\n\ndef _format_count(\n    count_metric_1s: training_metrics_pb2.CountMetricProto | None,\n    count_metric_total: training_metrics_pb2.CountMetricProto | None,\n    label: str,\n) -> str:\n    if count_metric_1s and count_metric_total:\n        rate = format_si(count_metric_1s.count)\n        total = format_full_number(count_metric_total.count)\n        return f\"{label} {rate}/s ({total} total)\"\n    if count_metric_total:\n        return f\"{label} {format_full_number(count_metric_total.count)}\"\n    if count_metric_1s:\n        return f\"{label} {format_si(count_metric_1s.count)}/s\"\n    return f\"{label} --\"\n\n\ndef _format_gauge(\n    gauge_metric: training_metrics_pb2.GaugeMetricProto | None,\n    label: str,\n) -> str:\n    if not gauge_metric:\n        return f\"{label} --\"\n    if gauge_metric.HasField(\"capacity\"):\n        value_text = format_full_number(gauge_metric.value)\n        capacity_text = format_full_number(gauge_metric.capacity)\n        return f\"{label} {value_text}/{capacity_text}\"\n    return f\"{label} {format_full_number(gauge_metric.value)}\"\n\n\ndef _format_statistics(\n    stats_1s: training_metrics_pb2.StatisticsProtoDouble | None,\n    stats_total: training_metrics_pb2.StatisticsProtoDouble | None,\n    label: str,\n) -> str:\n    
parts = [label]\n    if stats_1s and stats_1s.count > 0:\n        avg_1s = stats_1s.sum / stats_1s.count\n        parts.append(\n            f\"per 1s: avg {avg_1s:.1f} (min {stats_1s.min:.1f}, max {stats_1s.max:.1f}, count {format_si(stats_1s.count)})\"\n        )\n    if stats_total and stats_total.count > 0:\n        avg_total = stats_total.sum / stats_total.count\n        parts.append(\n            f\"total: avg {avg_total:.1f} (min {stats_total.min:.1f}, max {stats_total.max:.1f}, count {format_full_number(stats_total.count)})\"\n        )\n    if len(parts) == 1:\n        return f\"{label} --\"\n    return \"; \".join(parts)\n\n\ndef _average_queue_fullness(\n    queue_metric: training_metrics_pb2.QueueMetricProto | None,\n) -> int | None:\n    if not queue_metric:\n        return None\n    if (\n        queue_metric.HasField(\"queue_fullness\")\n        and queue_metric.queue_fullness.count > 0\n    ):\n        return int(\n            queue_metric.queue_fullness.sum / queue_metric.queue_fullness.count\n        )\n    return None\n\n\ndef _canonical_stage_name(\n    stage_metric: training_metrics_pb2.StageMetricProto | None,\n    fallback: str | None,\n    stage_key: str | None,\n) -> str:\n    if stage_metric and stage_metric.name:\n        return stage_metric.name\n    if fallback:\n        return fallback\n    if stage_key:\n        return stage_key\n    return \"--\"\n\n\nclass BaseRowWidget(Widget):\n    \"\"\"Base class for pipeline rows that renders a name and content widgets.\"\"\"\n\n    MAX_GRID_ROWS = 8  # Supports up to 28 chips (4 per row, 7 content rows)\n\n    def __init__(\n        self,\n        stage_key: str | None = None,\n        fallback_name: str | None = None,\n        row_type: str = \"stage-row\",\n        **kwargs: Any,\n    ) -> None:\n        classes = f\"dataloader-row {row_type}\"\n        super().__init__(classes=classes, **kwargs)\n        self.stage_key = stage_key\n        self._fallback_name = fallback_name\n        
self._name_label = Static(\n            _canonical_stage_name(None, fallback_name, stage_key),\n            classes=\"row-label\",\n        )\n        self._content_widgets: list[Widget] = []\n        self._spacers: list[Static] = []\n\n    def compose(self) -> ComposeResult:\n        # Name label goes in first cell (row 0, col 0)\n        yield self._name_label\n\n    def on_mount(self) -> None:\n        if self._content_widgets:\n\n            async def _mount_initial() -> None:\n                # Mount chips with spacers interleaved at the start of each row\n                for i, widget in enumerate(self._content_widgets):\n                    # After every 4 chips, add a spacer for the next row's column 0\n                    if i > 0 and i % 4 == 0:\n                        spacer = Static(\"\", classes=\"grid-spacer\")\n                        self._spacers.append(spacer)\n                        await self.mount(spacer)\n                    await self.mount(widget)\n\n            self.call_later(_mount_initial)\n\n    def add_content_widget(self, widget: Widget) -> None:\n        if widget in self._content_widgets:\n            return\n        self._content_widgets.append(widget)\n\n    def _update_name(\n        self,\n        stage_metric: training_metrics_pb2.StageMetricProto | None,\n    ) -> None:\n        self._name_label.update(\n            _canonical_stage_name(\n                stage_metric, self._fallback_name, self.stage_key\n            )\n        )\n\n\nclass StageWidget(BaseRowWidget):\n    \"\"\"Row widget that renders all metrics exposed by a stage.\"\"\"\n\n    def __init__(\n        self,\n        stage_key: str,\n        fallback_name: str | None = None,\n        **kwargs: Any,\n    ) -> None:\n        super().__init__(\n            stage_key=stage_key,\n            fallback_name=fallback_name,\n            row_type=\"stage-row\",\n            **kwargs,\n        )\n        self._chips: Dict[str, Static] = {}\n\n    def _ensure_chip(self, 
key: str, default_text: str, classes: str) -> Static:\n        chip = self._chips.get(key)\n        if chip is None:\n            chip = Static(default_text, classes=f\"metric-chip {classes}\")\n            self._chips[key] = chip\n            self.add_content_widget(chip)\n        return chip\n\n    def _update_last_chunk_chip(\n        self,\n        stage_1s: training_metrics_pb2.StageMetricProto | None,\n        stage_total: training_metrics_pb2.StageMetricProto | None,\n    ) -> None:\n        stage = None\n        if stage_1s and stage_1s.HasField(\"last_chunk_key\"):\n            stage = stage_1s\n        elif stage_total and stage_total.HasField(\"last_chunk_key\"):\n            stage = stage_total\n        if not stage:\n            return\n        last_value = stage.last_chunk_key or \"--\"\n        chip = self._ensure_chip(\"info:last\", \"last --\", \"info-chip\")\n        chip.update(f\"last {last_value}\")\n\n    def _update_anchor_chip(\n        self,\n        stage_total: training_metrics_pb2.StageMetricProto | None,\n    ) -> None:\n        if not stage_total or not stage_total.HasField(\"anchor\"):\n            return\n        chip = self._ensure_chip(\"info:anchor\", \"anchor --\", \"info-chip\")\n        chip.update(f\"anchor {stage_total.anchor}\")\n\n    def update_metrics(\n        self,\n        dataloader_1_second: training_metrics_pb2.DataLoaderMetricsProto | None,\n        dataloader_total: training_metrics_pb2.DataLoaderMetricsProto | None,\n    ) -> None:\n        if self.stage_key is None:\n            return\n        stage_metric_1s = _find_stage_metric(\n            dataloader_1_second, self.stage_key\n        )\n        stage_metric_total = _find_stage_metric(\n            dataloader_total, self.stage_key\n        )\n\n        self._update_name(stage_metric_1s or stage_metric_total)\n\n        load_names = _collect_metric_names(\n            stage_metric_1s, stage_metric_total, \"load_metrics\"\n        )\n        for load_name in 
load_names:\n            label = load_name or \"load\"\n            load_metric = _find_load_metric(stage_metric_1s, load_name)\n            if load_metric is None:\n                load_metric = _find_load_metric(stage_metric_total, load_name)\n            chip = self._ensure_chip(\n                f\"load:{load_name}\", f\"{label} --\", \"load-chip\"\n            )\n            chip.update(_format_load(load_metric, label=label))\n\n        count_names = _collect_metric_names(\n            stage_metric_1s, stage_metric_total, \"count_metrics\"\n        )\n        for count_name in count_names:\n            label = count_name or \"count\"\n            count_metric_1s = _find_count_metric(stage_metric_1s, count_name)\n            count_metric_total = _find_count_metric(\n                stage_metric_total, count_name\n            )\n            chip = self._ensure_chip(\n                f\"count:{count_name}\", f\"{label} --\", \"info-chip\"\n            )\n            chip.update(\n                _format_count(count_metric_1s, count_metric_total, label=label)\n            )\n\n        gauge_names = _collect_metric_names(\n            stage_metric_1s, stage_metric_total, \"gauge_metrics\"\n        )\n        for gauge_name in gauge_names:\n            label = gauge_name or \"gauge\"\n            gauge_metric = _find_gauge_metric(stage_metric_total, gauge_name)\n            if gauge_metric is None:\n                gauge_metric = _find_gauge_metric(stage_metric_1s, gauge_name)\n            chip = self._ensure_chip(\n                f\"gauge:{gauge_name}\", f\"{label} --\", \"info-chip\"\n            )\n            chip.update(_format_gauge(gauge_metric, label=label))\n\n        self._update_last_chunk_chip(stage_metric_1s, stage_metric_total)\n        self._update_anchor_chip(stage_metric_total)\n\n\nclass StatisticsRowWidget(BaseRowWidget):\n    \"\"\"Full-width row for statistics metrics.\"\"\"\n\n    def __init__(\n        self,\n        stage_key: str,\n        
metric_name: str,\n        label: str,\n        **kwargs: Any,\n    ) -> None:\n        super().__init__(\n            stage_key=stage_key,\n            fallback_name=None,\n            row_type=\"statistics-row\",\n            **kwargs,\n        )\n        self._metric_name = metric_name\n        self._label = label\n        self._stats_label = Static(f\"{label} --\", classes=\"statistics-full\")\n        self.add_content_widget(self._stats_label)\n\n    def compose(self) -> ComposeResult:\n        return\n        yield\n\n    def on_mount(self) -> None:\n        if self._content_widgets:\n\n            async def _mount_initial() -> None:\n                await self.mount(*self._content_widgets)\n\n            self.call_later(_mount_initial)\n\n    def update_metrics(\n        self,\n        dataloader_1_second: training_metrics_pb2.DataLoaderMetricsProto | None,\n        dataloader_total: training_metrics_pb2.DataLoaderMetricsProto | None,\n    ) -> None:\n        if not self.stage_key:\n            self._stats_label.update(f\"{self._label} --\")\n            return\n        stage_1s = _find_stage_metric(dataloader_1_second, self.stage_key)\n        stage_total = _find_stage_metric(dataloader_total, self.stage_key)\n        stats_1s = _find_statistics_metric(stage_1s, self._metric_name)\n        stats_total = _find_statistics_metric(stage_total, self._metric_name)\n        self._stats_label.update(\n            _format_statistics(stats_1s, stats_total, self._label)\n        )\n\n\nclass QueueWidget(BaseRowWidget):\n    \"\"\"Row widget for queue metrics between stages.\"\"\"\n\n    def __init__(\n        self,\n        stage_key: str | None = None,\n        stage_name: str | None = None,\n        queue_name: str | None = None,\n        **kwargs: Any,\n    ) -> None:\n        super().__init__(\n            stage_key=stage_key,\n            fallback_name=stage_name,\n            row_type=\"queue-row\",\n            **kwargs,\n        )\n        self._queue_name = 
queue_name\n        self._queue_name_chip = Static(\n            \"queue --\", classes=\"metric-chip queue-name-chip\"\n        )\n        self._rate_chip = Static(\"rate --/s\", classes=\"metric-chip queue-rate\")\n        self._total_chip = Static(\"total --\", classes=\"metric-chip queue-total\")\n        self._drop_chip = Static(\"dropped --\", classes=\"metric-chip\")\n        self._fill_bar = ProgressBar(\n            classes=\"queue-fill\",\n            show_percentage=False,\n            show_eta=False,\n        )\n        self._fill_text = Static(\"--/--\", classes=\"metric-chip queue-fill-text\")\n        self.add_content_widget(self._queue_name_chip)\n        self.add_content_widget(self._rate_chip)\n        self.add_content_widget(self._total_chip)\n        self.add_content_widget(self._drop_chip)\n        self.add_content_widget(self._fill_bar)\n        self.add_content_widget(self._fill_text)\n\n    def compose(self) -> ComposeResult:\n        # Queue widgets don't show the name label - use all 6 columns\n        return\n        yield  # Make this a generator\n\n    def on_mount(self) -> None:\n        # Mount all content widgets without spacers (use all 6 columns per row)\n        if self._content_widgets:\n\n            async def _mount_initial() -> None:\n                await self.mount(*self._content_widgets)\n\n            self.call_later(_mount_initial)\n\n    def update_metrics(\n        self,\n        dataloader_1_second: training_metrics_pb2.DataLoaderMetricsProto | None,\n        dataloader_total: training_metrics_pb2.DataLoaderMetricsProto | None,\n    ) -> None:\n        if not self.stage_key:\n            self._queue_name_chip.update(\"queue --\")\n            self._rate_chip.update(\"rate --/s\")\n            self._total_chip.update(\"total --\")\n            self._drop_chip.update(\"dropped --\")\n            self._drop_chip.remove_class(\"warning-chip\")\n            self._fill_bar.total = 1\n            self._fill_bar.progress = 0\n  
          self._fill_text.update(\"--/--\")\n            return\n\n        stage_1sec = _find_stage_metric(dataloader_1_second, self.stage_key)\n        stage_total = _find_stage_metric(dataloader_total, self.stage_key)\n        self._update_name(stage_1sec or stage_total)\n\n        queue_1sec = _get_queue_metric(stage_1sec, self._queue_name)\n        queue_total = _get_queue_metric(stage_total, self._queue_name)\n\n        queue_name = None\n        if queue_1sec and queue_1sec.name:\n            queue_name = queue_1sec.name\n        elif queue_total and queue_total.name:\n            queue_name = queue_total.name\n        elif self._queue_name:\n            queue_name = self._queue_name\n        self._queue_name_chip.update(\n            f\"queue {queue_name}\" if queue_name else \"queue --\"\n        )\n\n        rate = queue_1sec.get_count if queue_1sec else 0\n        self._rate_chip.update(f\"rate {format_si(rate)}/s\")\n        if rate == 0:\n            self._rate_chip.add_class(\"queue-rate--zero\")\n        else:\n            self._rate_chip.remove_class(\"queue-rate--zero\")\n\n        if queue_total:\n            self._total_chip.update(\n                f\"total {format_full_number(queue_total.get_count)}\"\n            )\n        else:\n            self._total_chip.update(\"total --\")\n\n        has_drops = False\n        if queue_1sec and queue_1sec.drop_count:\n            self._drop_chip.update(\n                f\"dropped {format_si(queue_1sec.drop_count)}/s\"\n            )\n            has_drops = True\n        elif queue_total and queue_total.drop_count:\n            self._drop_chip.update(\n                f\"dropped {format_full_number(queue_total.drop_count)}\"\n            )\n            has_drops = True\n        else:\n            self._drop_chip.update(\"dropped --\")\n            has_drops = False\n\n        if has_drops:\n            self._drop_chip.add_class(\"warning-chip\")\n        else:\n            
self._drop_chip.remove_class(\"warning-chip\")\n\n        size = _average_queue_fullness(queue_1sec)\n        capacity: int | None = None\n        if queue_1sec and queue_1sec.queue_capacity > 0:\n            capacity = queue_1sec.queue_capacity\n        elif queue_total and queue_total.queue_capacity > 0:\n            capacity = queue_total.queue_capacity\n\n        if capacity and capacity > 0:\n            self._fill_bar.total = capacity\n            if size is not None:\n                self._fill_bar.progress = min(size, capacity)\n                fill_text = (\n                    f\"{format_full_number(size)}/{format_full_number(capacity)}\"\n                )\n            else:\n                self._fill_bar.progress = 0\n                fill_text = f\"--/{format_full_number(capacity)}\"\n        else:\n            self._fill_bar.total = 1\n            self._fill_bar.progress = 0\n            fill_text = \"--/--\"\n        self._fill_text.update(fill_text)\n"
  },
  {
    "path": "src/lczero_training/tui/log_pane.py",
    "content": "import datetime\nfrom pathlib import Path\nfrom typing import Any, Optional, TextIO\n\nfrom anyio.streams.text import TextReceiveStream\nfrom textual.widgets import RichLog\n\n\nclass StreamingLogPane(RichLog):\n    \"\"\"Log pane that streams output from an async text stream.\"\"\"\n\n    def __init__(\n        self,\n        stream: TextReceiveStream,\n        logfile_path: Optional[str] = None,\n        **kwargs: Any,\n    ) -> None:\n        super().__init__(\n            highlight=True, markup=True, max_lines=1000, wrap=True, **kwargs\n        )\n        self._stream = stream\n        self._logfile_path = logfile_path\n        self._logfile: Optional[Path] = None\n        self._logfile_handle: Optional[TextIO] = None\n        if logfile_path:\n            self._logfile = Path(logfile_path)\n\n    def on_mount(self) -> None:\n        \"\"\"Start the async reading task when the widget is mounted.\"\"\"\n        if self._logfile:\n            self._write_banner()\n        self.run_worker(self._read_stream())\n\n    def _write_banner(self) -> None:\n        \"\"\"Write a session banner to the logfile.\"\"\"\n        if not self._logfile:\n            return\n\n        self._logfile.parent.mkdir(parents=True, exist_ok=True)\n        self._logfile_handle = self._logfile.open(\"a\", encoding=\"utf-8\")\n        timestamp = datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n        banner = (\n            f\"\\n{'=' * 80}\\n\"\n            f\"LCZero Training TUI Session Started: {timestamp}\\n\"\n            f\"{'=' * 80}\\n\"\n        )\n        self._logfile_handle.write(banner)\n        self._logfile_handle.flush()\n\n    def _write_to_file(self, line: str) -> None:\n        \"\"\"Write a line to the logfile.\"\"\"\n        if not self._logfile_handle:\n            return\n\n        self._logfile_handle.write(f\"{line}\\n\")\n        self._logfile_handle.flush()\n\n    async def _read_stream(self) -> None:\n        \"\"\"Async function that 
reads lines from the text stream.\"\"\"\n        try:\n            async for line in self._stream:\n                line = line.strip()\n                if line:\n                    self.write(line)\n                    self._write_to_file(line)\n        except Exception:\n            # The stream ended or the widget is shutting down; stop reading.\n            pass\n"
  },
  {
    "path": "src/lczero_training/tui/training_widgets.py",
    "content": "from textual.app import ComposeResult\nfrom textual.widgets import ProgressBar, Static\n\nfrom ..daemon.protocol.messages import TrainingScheduleData\n\n\nclass TimeProgressWidget(Static):\n    \"\"\"A widget to display a label, progress bar, and time ratio.\"\"\"\n\n    def __init__(self, label: str, *, id: str | None = None) -> None:\n        super().__init__(id=id)\n        self._label = label\n\n    def compose(self) -> ComposeResult:\n        yield Static(self._label, classes=\"time-label\")\n        yield ProgressBar(show_eta=False, classes=\"time-progress-bar\")\n        yield Static(\"\", classes=\"time-ratio\")\n\n    def update_progress(\n        self,\n        current: float,\n        total: float,\n        current_formatted: str | None = None,\n        total_formatted: str | None = None,\n    ) -> None:\n        \"\"\"Update the progress bar and ratio text.\"\"\"\n        progress_bar = self.query_one(ProgressBar)\n\n        progress_bar.total = total or None\n        progress_bar.progress = current\n\n        current_str = (\n            current_formatted\n            if current_formatted is not None\n            else str(int(current))\n        )\n        total_str = (\n            total_formatted if total_formatted is not None else str(int(total))\n        )\n\n        self.query_one(\".time-ratio\", Static).update(\n            f\"{current_str}/{total_str}\"\n        )\n\n\ndef format_time_duration(seconds: float) -> str:\n    \"\"\"Format time duration in seconds to human readable format with days support.\"\"\"\n    if seconds <= 0:\n        return \"--\"\n\n    total_seconds = int(seconds)\n    days, remainder = divmod(total_seconds, 86400)\n    hours, remainder = divmod(remainder, 3600)\n    minutes, secs = divmod(remainder, 60)\n\n    if days > 0:\n        return f\"{days}d {hours:02d}:{minutes:02d}:{secs:02d}\"\n    if hours > 0:\n        return f\"{hours:02d}:{minutes:02d}:{secs:02d}\"\n    return 
f\"{minutes:02d}:{secs:02d}\"\n\n\nclass TrainingScheduleWidget(Static):\n    \"\"\"A widget to display training schedule information.\"\"\"\n\n    def compose(self) -> ComposeResult:\n        yield Static(\"Uptime: --   Stage: --\", id=\"uptime-stage-display\")\n        yield Static(\"Completed epochs: 0\", id=\"epochs-display\")\n        yield TimeProgressWidget(\"New Chunks:\", id=\"chunks-progress\")\n        yield TimeProgressWidget(\"Training time\", id=\"training-time-progress\")\n        yield TimeProgressWidget(\"Cycle time\", id=\"cycle-time-progress\")\n\n    def update_training_schedule(\n        self, data: TrainingScheduleData | None\n    ) -> None:\n        \"\"\"Update the widget with new training schedule data.\"\"\"\n        if not data:\n            return\n\n        uptime_str = format_time_duration(data.total_uptime_seconds)\n        self.query_one(\"#uptime-stage-display\", Static).update(\n            f\"Uptime: {uptime_str}   Stage: {data.current_stage.value}\"\n        )\n\n        self.query_one(\"#epochs-display\", Static).update(\n            f\"Completed epochs: {data.completed_epochs_since_start}\"\n        )\n\n        self.query_one(\"#chunks-progress\", TimeProgressWidget).update_progress(\n            current=data.new_chunks_since_training_start,\n            total=data.chunks_to_wait,\n        )\n\n        self.query_one(\n            \"#training-time-progress\", TimeProgressWidget\n        ).update_progress(\n            current=data.current_training_time_seconds,\n            total=data.previous_training_time_seconds,\n            current_formatted=format_time_duration(\n                data.current_training_time_seconds\n            ),\n            total_formatted=format_time_duration(\n                data.previous_training_time_seconds\n            ),\n        )\n\n        self.query_one(\n            \"#cycle-time-progress\", TimeProgressWidget\n        ).update_progress(\n            
current=data.current_cycle_time_seconds,\n            total=data.previous_cycle_time_seconds,\n            current_formatted=format_time_duration(\n                data.current_cycle_time_seconds\n            ),\n            total_formatted=format_time_duration(\n                data.previous_cycle_time_seconds\n            ),\n        )\n"
  },
  {
    "path": "src/proto/__init__.py",
    "content": ""
  },
  {
    "path": "tf/attention_policy_map.py",
    "content": "import numpy as np\n\n\nmove = np.arange(1, 8)\n\ndiag = np.array([\n    move    + move*8,\n    move    - move*8,\n    move*-1 - move*8,\n    move*-1 + move*8\n])\n\northog = np.array([\n    move,\n    move*-8,\n    move*-1,\n    move*8\n])\n\nknight = np.array([\n    [2 + 1*8],\n    [2 - 1*8],\n    [1 - 2*8],\n    [-1 - 2*8],\n    [-2 - 1*8],\n    [-2 + 1*8],\n    [-1 + 2*8],\n    [1 + 2*8]\n])\n\npromos = np.array([2*8, 3*8, 4*8])\npawn_promotion = np.array([\n    -1 + promos,\n    0 + promos,\n    1 + promos\n])\n\n\ndef make_map():\n    \"\"\"theoretically possible put-down squares (numpy array) for each pick-up square (list element).\n    squares are [0, 1, ..., 63] for [a1, b1, ..., h8]. squares after 63 are promotion squares.\n    each successive \"row\" beyond 63 (ie. 64:72, 72:80, 80:88) are for over-promotions to queen, rook, and bishop;\n    respectively. a pawn traverse to row 56:64 signifies a \"default\" promotion to a knight.\"\"\"\n    traversable = []\n    for i in range(8):\n        for j in range(8):\n            sq = (8*i + j)\n            traversable.append(\n                sq +\n                np.sort(\n                    np.int32(\n                        np.concatenate((\n                            orthog[0][:7-j], orthog[2][:j], orthog[1][:i], orthog[3][:7-i],\n                            diag[0][:np.min((7-i, 7-j))], diag[3][:np.min((7-i, j))],\n                            diag[1][:np.min((i, 7-j))], diag[2][:np.min((i, j))],\n                            knight[0] if i < 7 and j < 6 else [], knight[1] if i > 0 and j < 6 else [],\n                            knight[2] if i > 1 and j < 7 else [], knight[3] if i > 1 and j > 0 else [],\n                            knight[4] if i > 0 and j > 1 else [], knight[5] if i < 7 and j > 1 else [],\n                            knight[6] if i < 6 and j > 0 else [], knight[7] if i < 6 and j < 7 else [],\n                            pawn_promotion[0] if i == 6 and j > 0 else [],\n       
                     pawn_promotion[1] if i == 6           else [],\n                            pawn_promotion[2] if i == 6 and j < 7 else [],\n                        ))\n                    )\n                )\n            )\n    z = np.zeros((64*64+8*24, 1858), dtype=np.int32)\n    # first loop for standard moves (for i in 0:1858, stride by 1)\n    i = 0\n    for pickup_index, putdown_indices in enumerate(traversable):\n        for putdown_index in putdown_indices:\n            if putdown_index < 64:\n                z[putdown_index + (64*pickup_index), i] = 1\n                i += 1\n    # second loop for promotions (for i in 1792:1858, stride by ls[j])\n    j = 0\n    j1 = np.array([3, -2, 3, -2, 3])\n    j2 = np.array([3, 3, -5, 3, 3, -5, 3, 3, 1])\n    ls = np.append(j1, 1)\n    for k in range(6):\n        ls = np.append(ls, j2)\n    ls = np.append(ls, j1)\n    ls = np.append(ls, 0)\n    for pickup_index, putdown_indices in enumerate(traversable):\n        for putdown_index in putdown_indices:\n            if putdown_index >= 64:\n                pickup_file = pickup_index % 8\n                promotion_file = putdown_index % 8\n                promotion_rank = (putdown_index // 8) - 8\n                z[4096 + pickup_file*24 + (promotion_file*3+promotion_rank), i] = 1\n                i += ls[j]\n                j += 1\n\n    return z\n\ndef make_pos_enc():\n    traversable = []\n    for i in range(8):\n        for j in range(8):\n            sq = (8*i + j)\n            traversable.append(\n                sq +\n                np.sort(\n                    np.int32(\n                        np.concatenate((\n                            orthog[0][:7-j], orthog[2][:j], orthog[1][:i], orthog[3][:7-i],\n                            diag[0][:np.min((7-i, 7-j))], diag[3][:np.min((7-i, j))],\n                            diag[1][:np.min((i, 7-j))], diag[2][:np.min((i, j))],\n                            knight[0] if i < 7 and j < 6 else [], knight[1] if i > 0 and 
j < 6 else [],\n                            knight[2] if i > 1 and j < 7 else [], knight[3] if i > 1 and j > 0 else [],\n                            knight[4] if i > 0 and j > 1 else [], knight[5] if i < 7 and j > 1 else [],\n                            knight[6] if i < 6 and j > 0 else [], knight[7] if i < 6 and j < 7 else [],\n                            # pawn_promotion[0] if i == 6 and j > 0 else [],\n                            # pawn_promotion[1] if i == 6           else [],\n                            # pawn_promotion[2] if i == 6 and j < 7 else [],\n                        ))\n                    )\n                )\n            )\n\n    # pos_enc = np.zeros((1, 64, 88), dtype=np.float32)\n    pos_enc = np.zeros((1, 64, 64), dtype=np.float32)\n    for i, k in enumerate(traversable):\n        pos_enc[0][i][i] = -1.\n        for j in k:\n            pos_enc[0][i][j] = 1.\n\n    return pos_enc\n"
  },
  {
    "path": "tf/chunkparsefunc.py",
    "content": "#!/usr/bin/env python3\n#\n#    This file is part of Leela Chess.\n#    Copyright (C) 2021 Leela Chess Authors\n#\n#    Leela Chess is free software: you can redistribute it and/or modify\n#    it under the terms of the GNU General Public License as published by\n#    the Free Software Foundation, either version 3 of the License, or\n#    (at your option) any later version.\n#\n#    Leela Chess is distributed in the hope that it will be useful,\n#    but WITHOUT ANY WARRANTY; without even the implied warranty of\n#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n#    GNU General Public License for more details.\n#\n#    You should have received a copy of the GNU General Public License\n#    along with Leela Chess.  If not, see <http://www.gnu.org/licenses/>.\nimport tensorflow as tf\n\n\ndef parse_function(planes, probs, winner, q, plies_left):\n    \"\"\"\n    Convert unpacked record batches to tensors for tensorflow training\n    \"\"\"\n    planes = tf.io.decode_raw(planes, tf.float32)\n    probs = tf.io.decode_raw(probs, tf.float32)\n    winner = tf.io.decode_raw(winner, tf.float32)\n    q = tf.io.decode_raw(q, tf.float32)\n    plies_left = tf.io.decode_raw(plies_left, tf.float32)\n\n    planes = tf.reshape(planes, (-1, 112, 8, 8))\n    probs = tf.reshape(probs, (-1, 1858))\n    winner = tf.reshape(winner, (-1, 3))\n    q = tf.reshape(q, (-1, 3))\n    plies_left = tf.reshape(plies_left, (-1, 1))\n\n    return (planes, probs, winner, q, plies_left)\n"
  },
  {
    "path": "tf/chunkparser.py",
    "content": "#!/usr/bin/env python3\n#\n#    This file is part of Leela Chess.\n#    Copyright (C) 2018 Folkert Huizinga\n#    Copyright (C) 2017-2018 Gian-Carlo Pascutto\n#\n#    Leela Chess is free software: you can redistribute it and/or modify\n#    it under the terms of the GNU General Public License as published by\n#    the Free Software Foundation, either version 3 of the License, or\n#    (at your option) any later version.\n#\n#    Leela Chess is distributed in the hope that it will be useful,\n#    but WITHOUT ANY WARRANTY; without even the implied warranty of\n#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n#    GNU General Public License for more details.\n#\n#    You should have received a copy of the GNU General Public License\n#    along with Leela Chess.  If not, see <http://www.gnu.org/licenses/>.\n\"\"\"\nGeneral comments on how chunkparser works.\n\nA \"training record\" or just \"record\" is a fixed-length packed byte array. Typically\nrecords are generated during training and are stored together by game, one record for \neach position in the game, but this arrangement is not required.\nOver dev time additional fields have been added to the training record, most of which \njust put additional information after the end of the byte array used in the previous \nversion. Currently supported training record versions are V3, V4, V5, and V6.\n\nshufflebuffer.ShuffleBuffer is a simple structure holding an array of training\nrecords that are efficiently randomized and replaced as needed. All records in\nShuffleBuffer are adjusted to be the same number of bytes by appending unused \nbytes *before* being put in the shuffler. 
\nByte padding is done in chunkparser.ChunkParser.sample_record().\nsample_record() also skips most training records to avoid sampling over-correlated\npositions, since they typically come from sequential positions in a game.\n\nThe current implementation of \"diff focus\" is also in sample_record() and works by\nprobabilistically skipping records according to how accurate the no-search\neval ('orig_q') is compared to the eval after search ('best_q'), as well as the\nrecorded policy_kld (a measure of the difference between the no-search policy\nand the final policy distribution). It does not use draw values at this point.\nPutting diff focus here is efficient because it runs in parallel workers and\npeeks at the records without requiring any unpacking.\n\nThe constructor for chunkparser.ChunkParser() sets a bunch of class constants\nand creates a fixed number of parallel Python multiprocessing.Pipe objects,\nwhich consist of a \"reader\" and a \"writer\". The writer(s) get data directly\nfrom training data files and write it into the pipe using the writer.send_bytes()\nmethod. The reader(s) get data out of the pipe using the reader.recv_bytes()\nmethod and feed it to the ShuffleBuffer using its insert_or_replace() method,\nwhich also handles the shuffling itself.\n\nRecords come back out of the ShuffleBuffer (already a fixed number of bytes\nregardless of training version) using the multiplexed generators specified in\nthe ChunkParser.parse() method. 
They are first recovered as raw byte records\nin the vX_gen() method (currently v6_gen), then converted to tuples of more\ninterpretable data in the convert_vX_to_tuple() method and finally sent on\nto TensorFlow in training batches by the batch_gen() method.\n\"\"\"\n\nimport itertools\nimport multiprocessing as mp\nimport numpy as np\nimport random\nimport shufflebuffer as sb\nimport struct\nimport unittest\nimport gzip\nfrom select import select\n\nV6_VERSION = struct.pack('i', 6)\nV5_VERSION = struct.pack('i', 5)\nCLASSICAL_INPUT = struct.pack('i', 1)\nV4_VERSION = struct.pack('i', 4)\nV3_VERSION = struct.pack('i', 3)\nV6_STRUCT_STRING = '4si7432s832sBBBBBBBbfffffffffffffffIHH4H'\nV5_STRUCT_STRING = '4si7432s832sBBBBBBBbfffffff'\nV4_STRUCT_STRING = '4s7432s832sBBBBBBBbffff'\nV3_STRUCT_STRING = '4s7432s832sBBBBBBBb'\n\n\ndef reverse_expand_bits(plane):\n    return np.unpackbits(np.array([plane], dtype=np.uint8))[::-1].astype(\n        np.float32).tobytes()\n\n\n# Interface for a chunk data source.\nclass ChunkDataSrc:\n    def __init__(self, items):\n        self.items = items\n\n    def next(self):\n        if not self.items:\n            return None\n        return self.items.pop()\n\n\ndef chunk_reader(chunk_filenames, chunk_filename_queue):\n    \"\"\"\n    Reads chunk filenames from a list and puts them in shuffled\n    order onto chunk_filename_queue.\n    \"\"\"\n    chunks = []\n    done = chunk_filenames\n\n    while True:\n        if not chunks:\n            chunks, done = done, chunks\n            random.shuffle(chunks)\n        if not chunks:\n            print(\"chunk_reader didn't find any chunks.\")\n            return None\n        while len(chunks):\n            filename = chunks.pop()\n            done.append(filename)\n            chunk_filename_queue.put(filename)\n    print(\"chunk_reader exiting.\")\n    return None\n\n\nclass ChunkParser:\n\n    def __init__(self,\n                 chunks,\n                 expected_input_format,\n                 
shuffle_size=1,\n                 sample=1,\n                 buffer_size=1,\n                 batch_size=256,\n                 diff_focus_min=1,\n                 diff_focus_slope=0,\n                 diff_focus_q_weight=6.0,\n                 diff_focus_pol_scale=3.5,\n                 workers=None):\n        self.inner = ChunkParserInner(self, chunks, expected_input_format,\n                                      shuffle_size, sample, buffer_size,\n                                      batch_size, diff_focus_min,\n                                      diff_focus_slope, diff_focus_q_weight,\n                                      diff_focus_pol_scale, workers)\n\n    def shutdown(self):\n        \"\"\"\n        Terminates all the workers.\n        \"\"\"\n        for i in range(len(self.processes)):\n            self.processes[i].terminate()\n            self.processes[i].join()\n            self.inner.readers[i].close()\n            self.inner.writers[i].close()\n        self.chunk_process.terminate()\n        self.chunk_process.join()\n\n    def parse(self):\n        return self.inner.parse()\n\n    def sequential(self):\n        return self.inner.sequential()\n\n\nclass ChunkParserInner:\n    def __init__(self, parent, chunks, expected_input_format, shuffle_size,\n                 sample, buffer_size, batch_size, diff_focus_min,\n                 diff_focus_slope, diff_focus_q_weight, diff_focus_pol_scale,\n                 workers):\n        \"\"\"\n        Read data and yield batches of raw tensors.\n\n        'parent' is the outer chunk parser that stores the processes. 
Must not be stored by self directly or indirectly.\n        'chunks' list of chunk filenames.\n        'shuffle_size' is the size of the shuffle buffer.\n        'sample' is the rate to down-sample.\n        'diff_focus_min', 'diff_focus_slope', 'diff_focus_q_weight' and 'diff_focus_pol_scale' control diff focus.\n        'workers' is the number of child workers to use.\n\n        The data is represented in a number of formats through this dataflow\n        pipeline. In order, they are:\n\n        chunk: The name of a file containing chunkdata.\n\n        chunkdata: type Bytes. Multiple records of v6 format where each record\n        consists of (state, policy, result, q).\n\n        raw: A byte string holding raw tensors concatenated together. This is\n        used to pass data from the workers to the parent. Exists because\n        TensorFlow doesn't have a fast way to unpack bit vectors. 7950 bytes\n        long.\n        \"\"\"\n\n        self.expected_input_format = expected_input_format\n\n        # Build 2 flat float32 planes with values 0,1\n        self.flat_planes = []\n        for i in range(2):\n            self.flat_planes.append(\n                (np.zeros(64, dtype=np.float32) + i).tobytes())\n\n        # set the down-sampling rate\n        self.sample = sample\n        # set the details for diff focus, defaults accept all positions\n        self.diff_focus_min = diff_focus_min\n        self.diff_focus_slope = diff_focus_slope\n        self.diff_focus_q_weight = diff_focus_q_weight\n        self.diff_focus_pol_scale = diff_focus_pol_scale\n        # set the mini-batch size\n        self.batch_size = batch_size\n        # set number of elements in the shuffle buffer.\n        self.shuffle_size = shuffle_size\n        # Start worker processes, leave 2 for TensorFlow\n        if workers is None:\n            workers = max(1, mp.cpu_count() - 2)\n\n        if workers > 0:\n            print(\"Using {} worker processes.\".format(workers))\n\n            # 
Start the child workers running\n            self.readers = []\n            self.writers = []\n            parent.processes = []\n            self.chunk_filename_queue = mp.Queue(maxsize=4096)\n            for _ in range(workers):\n                read, write = mp.Pipe(duplex=False)\n                p = mp.Process(target=self.task,\n                               args=(self.chunk_filename_queue, write))\n                p.daemon = True\n                parent.processes.append(p)\n                p.start()\n                self.readers.append(read)\n                self.writers.append(write)\n\n            parent.chunk_process = mp.Process(target=chunk_reader,\n                                              args=(chunks,\n                                                    self.chunk_filename_queue))\n            parent.chunk_process.daemon = True\n            parent.chunk_process.start()\n        else:\n            self.chunks = chunks\n\n        self.init_structs()\n\n    def init_structs(self):\n        \"\"\"\n        struct.Struct doesn't pickle, so it needs to be separately\n        constructed in workers.\n        \"\"\"\n        self.v6_struct = struct.Struct(V6_STRUCT_STRING)\n        self.v5_struct = struct.Struct(V5_STRUCT_STRING)\n        self.v4_struct = struct.Struct(V4_STRUCT_STRING)\n        self.v3_struct = struct.Struct(V3_STRUCT_STRING)\n\n    def convert_v6_to_tuple(self, content):\n        \"\"\"\n        Unpack a v6 binary record to 5-tuple (state, policy pi, result, q, m)\n\n        v6 struct format is (8356 bytes total):\n                                  size         1st byte index\n        uint32_t version;                               0\n        uint32_t input_format;                          4\n        float probabilities[1858];  7432 bytes          8\n        uint64_t planes[104];        832 bytes       7440\n        uint8_t castling_us_ooo;                     8272\n        uint8_t castling_us_oo;                      8273\n        
uint8_t castling_them_ooo;                   8274\n        uint8_t castling_them_oo;                    8275\n        uint8_t side_to_move_or_enpassant;           8276\n        uint8_t rule50_count;                        8277\n        // Bitfield with the following allocation:\n        //  bit 7: side to move (input type 3)\n        //  bit 6: position marked for deletion by the rescorer (never set by lc0)\n        //  bit 5: game adjudicated (v6)\n        //  bit 4: max game length exceeded (v6)\n        //  bit 3: best_q is for proven best move (v6)\n        //  bit 2: transpose transform (input type 3)\n        //  bit 1: mirror transform (input type 3)\n        //  bit 0: flip transform (input type 3)\n        uint8_t invariance_info;                     8278\n        uint8_t dep_result;                          8279\n        float root_q;                                8280\n        float best_q;                                8284\n        float root_d;                                8288\n        float best_d;                                8292\n        float root_m;      // In plies.              8296\n        float best_m;      // In plies.              8300\n        float plies_left;                            8304\n        float result_q;                              8308\n        float result_d;                              8312\n        float played_q;                              8316\n        float played_d;                              8320\n        float played_m;                              8324\n        // The following may be NaN if not found in cache.\n        float orig_q;      // For value repair.      
8328\n        float orig_d;                                8332\n        float orig_m;                                8336\n        uint32_t visits;                             8340\n        // Indices in the probabilities array.\n        uint16_t played_idx;                         8344\n        uint16_t best_idx;                           8346\n        uint64_t reserved;                           8348\n        \"\"\"\n        # unpack the V6 content from raw byte array, arbitrarily chose 4 2-byte values\n        # for the 8 \"reserved\" bytes\n        (ver, input_format, probs, planes, us_ooo, us_oo, them_ooo, them_oo,\n         stm, rule50_count, invariance_info, dep_result, root_q, best_q,\n         root_d, best_d, root_m, best_m, plies_left, result_q, result_d,\n         played_q, played_d, played_m, orig_q, orig_d, orig_m, visits,\n         played_idx, best_idx, reserved1, reserved2, reserved3,\n         reserved4) = self.v6_struct.unpack(content)\n        \"\"\"\n        v5 struct format was (8308 bytes total)\n            int32 version (4 bytes)\n            int32 input_format (4 bytes)\n            1858 float32 probabilities (7432 bytes)\n            104 (13*8) packed bit planes of 8 bytes each (832 bytes)\n            uint8 castling us_ooo (1 byte)\n            uint8 castling us_oo (1 byte)\n            uint8 castling them_ooo (1 byte)\n            uint8 castling them_oo (1 byte)\n            uint8 side_to_move (1 byte)\n            uint8 rule50_count (1 byte)\n            uint8 dep_ply_count (1 byte) (unused)\n            int8 result (1 byte)\n            float32 root_q (4 bytes)\n            float32 best_q (4 bytes)\n            float32 root_d (4 bytes)\n            float32 best_d (4 bytes)\n            float32 root_m (4 bytes)\n            float32 best_m (4 bytes)\n            float32 plies_left (4 bytes)\n        \"\"\"\n        # v3/4 data sometimes has a useful value in dep_ply_count (now invariance_info),\n        # so copy that over if the new 
ply_count is not populated.\n        if plies_left == 0:\n            plies_left = invariance_info\n        plies_left = struct.pack('f', plies_left)\n\n        assert input_format == self.expected_input_format\n\n        # Unpack bit planes and cast to 32 bit float\n        planes = np.unpackbits(np.frombuffer(planes, dtype=np.uint8)).astype(\n            np.float32)\n        rule50_divisor = 99.0\n        if input_format > 3:\n            rule50_divisor = 100.0\n        rule50_plane = struct.pack('f', rule50_count / rule50_divisor) * 64\n\n        if input_format == 1:\n            middle_planes = self.flat_planes[us_ooo] + \\\n                            self.flat_planes[us_oo] + \\\n                            self.flat_planes[them_ooo] + \\\n                            self.flat_planes[them_oo] + \\\n                            self.flat_planes[stm]\n        elif input_format == 2:\n            # Each inner array has to be reversed as these fields are in opposite endian to the planes data.\n            them_ooo_bytes = reverse_expand_bits(them_ooo)\n            us_ooo_bytes = reverse_expand_bits(us_ooo)\n            them_oo_bytes = reverse_expand_bits(them_oo)\n            us_oo_bytes = reverse_expand_bits(us_oo)\n            middle_planes = us_ooo_bytes + (6*8*4) * b'\\x00' + them_ooo_bytes + \\\n                            us_oo_bytes + (6*8*4) * b'\\x00' + them_oo_bytes + \\\n                            self.flat_planes[0] + \\\n                            self.flat_planes[0] + \\\n                            self.flat_planes[stm]\n        elif input_format == 3 or input_format == 4 or input_format == 132 or input_format == 5 or input_format == 133:\n            # Each inner array has to be reversed as these fields are in opposite endian to the planes data.\n            them_ooo_bytes = reverse_expand_bits(them_ooo)\n            us_ooo_bytes = reverse_expand_bits(us_ooo)\n            them_oo_bytes = reverse_expand_bits(them_oo)\n            us_oo_bytes = 
reverse_expand_bits(us_oo)\n            enpassant_bytes = reverse_expand_bits(stm)\n            middle_planes = us_ooo_bytes + (6*8*4) * b'\\x00' + them_ooo_bytes + \\\n                            us_oo_bytes + (6*8*4) * b'\\x00' + them_oo_bytes + \\\n                            self.flat_planes[0] + \\\n                            self.flat_planes[0] + \\\n                            (7*8*4) * b'\\x00' + enpassant_bytes\n\n        # Concatenate all byteplanes. Make the last plane all 1's so the NN can\n        # detect edges of the board more easily\n        aux_plus_6_plane = self.flat_planes[0]\n        if (input_format == 132\n                or input_format == 133) and invariance_info >= 128:\n            aux_plus_6_plane = self.flat_planes[1]\n        planes = planes.tobytes() + \\\n                 middle_planes + \\\n                 rule50_plane + \\\n                 aux_plus_6_plane + \\\n                 self.flat_planes[1]\n\n        assert len(planes) == ((8 * 13 * 1 + 8 * 1 * 1) * 8 * 8 * 4)\n\n        if ver == V6_VERSION:\n            winner = struct.pack('fff', 0.5 * (1.0 - result_d + result_q),\n                                 result_d, 0.5 * (1.0 - result_d - result_q))\n        else:\n            dep_result = float(dep_result)\n            assert dep_result == 1.0 or dep_result == -1.0 or dep_result == 0.0\n            winner = struct.pack('fff', dep_result == 1.0, dep_result == 0.0,\n                                 dep_result == -1.0)\n\n        best_q_w = 0.5 * (1.0 - best_d + best_q)\n        best_q_l = 0.5 * (1.0 - best_d - best_q)\n        assert -1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0\n        best_q = struct.pack('fff', best_q_w, best_d, best_q_l)\n\n        return (planes, probs, winner, best_q, plies_left)\n\n    def sample_record(self, chunkdata):\n        \"\"\"\n        Randomly sample through the v3/4/5/6 chunk data and select records in v6 format\n        Downsampling to avoid highly correlated positions skips most 
records, and \n        diff focus may also skip some records.\n        \"\"\"\n        version = chunkdata[0:4]\n        if version == V6_VERSION:\n            record_size = self.v6_struct.size\n        elif version == V5_VERSION:\n            record_size = self.v5_struct.size\n        elif version == V4_VERSION:\n            record_size = self.v4_struct.size\n        elif version == V3_VERSION:\n            record_size = self.v3_struct.size\n        else:\n            return\n\n        for i in range(0, len(chunkdata), record_size):\n            if self.sample > 1:\n                # Downsample, using only 1/Nth of the items.\n                if random.randint(0, self.sample - 1) != 0:\n                    continue  # Skip this record.\n\n            record = chunkdata[i:i + record_size]\n            # for earlier versions, append fake bytes to record to maintain size\n            if version == V3_VERSION:\n                # add 16 bytes of fake root_q, best_q, root_d, best_d to match V4 format\n                record += 16 * b'\\x00'\n            if version == V3_VERSION or version == V4_VERSION:\n                # add 12 bytes of fake root_m, best_m, plies_left to match V5 format\n                record += 12 * b'\\x00'\n                # insert 4 bytes of classical input format tag to match v5 format\n                record = record[:4] + CLASSICAL_INPUT + record[4:]\n            if version == V3_VERSION or version == V4_VERSION or version == V5_VERSION:\n                # add 48 bytes of fake result_q, result_d etc.\n                record += 48 * b'\\x00'\n\n            if version == V6_VERSION:\n                # diff focus code, peek at best_q, orig_q and pol_kld from record (unpacks as tuple with one item)\n                best_q = struct.unpack('f', record[8284:8288])[0]\n                orig_q = struct.unpack('f', record[8328:8332])[0]\n                pol_kld = struct.unpack('f', record[8348:8352])[0]\n\n                # if orig_q is NaN or pol_kld is 0, 
accept, else accept based on diff focus\n                if not np.isnan(orig_q) and pol_kld > 0:\n                    diff_q = abs(best_q - orig_q)\n                    q_weight = self.diff_focus_q_weight\n                    pol_scale = self.diff_focus_pol_scale\n                    total = (q_weight * diff_q + pol_kld) / (q_weight +\n                                                             pol_scale)\n                    thresh_p = self.diff_focus_min + self.diff_focus_slope * total\n                    if thresh_p < 1.0 and random.random() > thresh_p:\n                        continue\n\n            yield record\n\n    def single_file_gen(self, filename):\n        try:\n            with gzip.open(filename, 'rb') as chunk_file:\n                version = chunk_file.read(4)\n                chunk_file.seek(0)\n                if version == V6_VERSION:\n                    record_size = self.v6_struct.size\n                elif version == V5_VERSION:\n                    record_size = self.v5_struct.size\n                elif version == V4_VERSION:\n                    record_size = self.v4_struct.size\n                elif version == V3_VERSION:\n                    record_size = self.v3_struct.size\n                else:\n                    print('Unknown version {} in file {}'.format(\n                        version, filename))\n                    return\n                while True:\n                    chunkdata = chunk_file.read(256 * record_size)\n                    if len(chunkdata) == 0:\n                        break\n                    for item in self.sample_record(chunkdata):\n                        yield item\n\n        except Exception as e:\n            print(\"failed to parse {}: {}\".format(filename, e))\n\n    def sequential_gen(self):\n        for filename in self.chunks:\n            for item in self.single_file_gen(filename):\n                yield item\n\n    def sequential(self):\n        gen = self.sequential_gen()  # read from all files in order in 
this process.\n        gen = self.tuple_gen(gen)  # convert v6->tuple\n        gen = self.batch_gen(gen, allow_partial=False)  # assemble into batches\n        for b in gen:\n            yield b\n\n    def task(self, chunk_filename_queue, writer):\n        \"\"\"\n        Run in a forked process: read chunk filenames from the queue, parse\n        and sample their records, and send v6 data through the pipe back to\n        the main process.\n        \"\"\"\n        self.init_structs()\n        while True:\n            filename = chunk_filename_queue.get()\n            for item in self.single_file_gen(filename):\n                writer.send_bytes(item)\n\n    def v6_gen(self):\n        \"\"\"\n        Read v6 records from child workers, shuffle, and yield\n        records.\n        \"\"\"\n        sbuff = sb.ShuffleBuffer(self.v6_struct.size, self.shuffle_size)\n        while len(self.readers):\n            for r in list(self.readers):  # copy, as readers are removed on EOF\n                try:\n                    s = r.recv_bytes()\n                    s = sbuff.insert_or_replace(s)\n                    if s is None:\n                        continue  # shuffle buffer not yet full\n                    yield s\n                except EOFError:\n                    print(\"Reader EOF\")\n                    self.readers.remove(r)\n        # drain the shuffle buffer.\n        while True:\n            s = sbuff.extract()\n            if s is None:\n                return\n            yield s\n\n    def tuple_gen(self, gen):\n        \"\"\"\n        Take a generator producing v6 records and convert them to tuples of\n        raw tensors.\n        \"\"\"\n        for r in gen:\n            yield self.convert_v6_to_tuple(r)\n\n    def batch_gen(self, gen, allow_partial=True):\n        \"\"\"\n        Pack multiple records into a single batch.\n        \"\"\"\n        # Get N records. 
We flatten the returned generator to\n        # a list because we need to reuse it.\n        while True:\n            s = list(itertools.islice(gen, self.batch_size))\n            if not len(s) or (not allow_partial and len(s) != self.batch_size):\n                return\n            yield (b''.join([x[0] for x in s]), b''.join([x[1] for x in s]),\n                   b''.join([x[2] for x in s]), b''.join([x[3] for x in s]),\n                   b''.join([x[4] for x in s]))\n\n    def parse(self):\n        \"\"\"\n        Read data from child workers and yield batches of unpacked records.\n        \"\"\"\n        gen = self.v6_gen()  # read from workers\n        gen = self.tuple_gen(gen)  # convert v6->tuple\n        gen = self.batch_gen(gen)  # assemble into batches\n        for b in gen:\n            yield b\n\n\n# Tests to check that records parse correctly\nclass ChunkParserTest(unittest.TestCase):\n    def setUp(self):\n        self.v4_struct = struct.Struct(V4_STRUCT_STRING)\n\n    def generate_fake_pos(self):\n        \"\"\"\n        Generate a random game position.\n        Result is a 6-tuple (planes, integers, probs, winner, best_q, best_d)\n        with shapes ([64] * 104, [7], [1858], scalar, scalar, scalar).\n        \"\"\"\n        # 0. 104 binary planes of length 64\n        planes = [\n            np.random.randint(2, size=64).tolist() for plane in range(104)\n        ]\n\n        # 1. generate the other integer data\n        integer = np.zeros(7, dtype=np.int32)\n        for i in range(5):\n            integer[i] = np.random.randint(2)\n        integer[5] = np.random.randint(100)\n\n        # 2. 1858 probs\n        probs = np.random.randint(9, size=1858, dtype=np.int32)\n\n        # 3. And a winner: 1, 0, -1\n        winner = np.random.randint(3) - 1\n\n        # 4. 
evaluation after search\n        best_q = np.random.uniform(-1, 1)\n        best_d = np.random.uniform(0, 1 - np.abs(best_q))\n        return (planes, integer, probs, winner, best_q, best_d)\n\n    def v4_record(self, planes, i, probs, winner, best_q, best_d):\n        pl = []\n        for plane in planes:\n            pl.append(np.packbits(plane))\n        pl = np.array(pl).flatten().tobytes()\n        pi = probs.tobytes()\n        root_q, root_d = 0.0, 0.0\n        return self.v4_struct.pack(V4_VERSION, pi, pl, i[0], i[1], i[2], i[3],\n                                   i[4], i[5], i[6], winner, root_q, best_q,\n                                   root_d, best_d)\n\n    def test_structsize(self):\n        \"\"\"\n        Test struct size\n        \"\"\"\n        self.assertEqual(self.v4_struct.size, 8292)\n\n    def test_parsing(self):\n        \"\"\"\n        Test game position decoding pipeline.\n        \"\"\"\n        truth = self.generate_fake_pos()\n        batch_size = 4\n        records = []\n        for i in range(batch_size):\n            record = b''\n            for j in range(2):\n                record += self.v4_record(*truth)\n            records.append(record)\n\n        parser = ChunkParser(ChunkDataSrc(records),\n                             shuffle_size=1,\n                             workers=1,\n                             batch_size=batch_size)\n        batchgen = parser.parse()\n        data = next(batchgen)\n\n        batch = (np.reshape(np.frombuffer(data[0], dtype=np.float32),\n                            (batch_size, 112, 64)),\n                 np.reshape(np.frombuffer(data[1], dtype=np.int32),\n                            (batch_size, 1858)),\n                 np.reshape(np.frombuffer(data[2], dtype=np.float32),\n                            (batch_size, 3)),\n                 np.reshape(np.frombuffer(data[3], dtype=np.float32),\n                            (batch_size, 3)))\n\n        fltplanes = truth[1].astype(np.float32)\n        
fltplanes[5] /= 99\n        for i in range(batch_size):\n            data = (batch[0][i][:104],\n                    np.array([batch[0][i][j][0] for j in range(104, 111)]),\n                    batch[1][i], batch[2][i], batch[3][i])\n            self.assertTrue((data[0] == truth[0]).all())\n            self.assertTrue((data[1] == fltplanes).all())\n            self.assertTrue((data[2] == truth[2]).all())\n            scalar_win = data[3][0] - data[3][-1]\n            self.assertTrue(np.abs(scalar_win - truth[3]) < 1e-6)\n            scalar_q = data[4][0] - data[4][-1]\n            self.assertTrue(np.abs(scalar_q - truth[4]) < 1e-6)\n\n        parser.shutdown()\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tf/configs/example.yaml",
    "content": "%YAML 1.2\n---\nname: 'kb1-64x6'                       # ideally no spaces\ngpu: 0                                 # gpu id to process on\n\ndataset: \n  num_chunks: 100000                   # newest nof chunks to parse\n  train_ratio: 0.90                    # trainingset ratio\n  # For separated test and train data.\n  input_train: '/path/to/chunks/*/draw/' # supports glob\n  input_test: '/path/to/chunks/*/draw/'  # supports glob\n  # For a one-shot run with all data in one directory.\n  # input: '/path/to/chunks/*/draw/'\n\ntraining:\n    batch_size: 2048                   # training batch\n    test_steps: 2000                   # eval test set values after this many steps\n    train_avg_report_steps: 200        # training reports its average values after this many steps.\n    total_steps: 140000                # terminate after these steps\n    warmup_steps: 250                  # if global step is less than this, scale the current LR by ratio of global step to this value\n    # checkpoint_steps: 10000          # optional frequency for checkpointing before finish\n    shuffle_size: 524288               # size of the shuffle buffer\n    lr_values:                         # list of learning rates\n        - 0.02\n        - 0.002\n        - 0.0005\n    lr_boundaries:                     # list of boundaries\n        - 100000\n        - 130000\n    policy_loss_weight: 1.0            # weight of policy loss\n    value_loss_weight: 1.0             # weight of value loss\n    moves_left_loss_weight: 1.0        # weight of moves-left loss\n    path: '/path/to/store/networks'    # network storage dir\n\nmodel:\n  filters: 64\n  residual_blocks: 6\n  se_ratio: 2                          # Squeeze Excite structural network architecture.\n  policy: 'attention'                  # attention policy fields:\n  pol_embedding_size: 64               # embedding vector size\n  pol_encoder_layers: 1                # number of intermediate attention layers in the 
policy head\n  pol_encoder_heads: 4                 # number of attention heads in encoder layers\n  pol_encoder_d_model: 64              # size of the Q, K, & V vectors in encoder layers -- divisible by encoder_heads\n  pol_encoder_dff: 128                 # size of the largest dense layer in encoder block feed-forward network\n  policy_d_model: 64                   # size of the query and key vectors in final attention layer\n  value: 'wdl'\n  moves_left: 'v1'\n...\n"
  },
  {
    "path": "tf/decode_training.py",
    "content": "#!/usr/bin/env python3\n#\n#    This file is part of Leela Chess.\n#    Copyright (C) 2018 Folkert Huizinga\n#    Copyright (C) 2017-2018 Gian-Carlo Pascutto\n#\n#    Leela Chess is free software: you can redistribute it and/or modify\n#    it under the terms of the GNU General Public License as published by\n#    the Free Software Foundation, either version 3 of the License, or\n#    (at your option) any later version.\n#\n#    Leela Chess is distributed in the hope that it will be useful,\n#    but WITHOUT ANY WARRANTY; without even the implied warranty of\n#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n#    GNU General Public License for more details.\n#\n#    You should have received a copy of the GNU General Public License\n#    along with Leela Chess.  If not, see <http://www.gnu.org/licenses/>.\n\nimport array\nimport binascii\nimport chunkparser\nimport glob\nimport gzip\nimport itertools\nimport math\nimport numpy as np\nimport random\nimport re\nimport os\nimport shutil\nimport struct\nimport sys\nimport threading\nimport time\nimport unittest\nimport argparse\nfrom collections import defaultdict\n\n# VERSION of the training data file format\n#     1 - Text, oldflip\n#     2 - Binary, oldflip\n#     3 - Binary, newflip\n#     b'\\1\\0\\0\\0' - Invalid, see issue #119\n#\n# Note: VERSION1 does not include a version in the header, it starts with\n# text hex characters. This means any VERSION that is also a valid ASCII\n# hex string could potentially be a training file. 
Most of the time it\n# will be \"00ff\", but maybe games could get split and then there are more\n# \"00??\" possibilities.\n#\n# Also note \"0002\" is actually b'\\0x30\\0x30\\0x30\\0x32' (or maybe reversed?)\n# so it doesn't collide with VERSION2.\n#\nVERSION3 = chunkparser.V3_VERSION\nVERSION4 = chunkparser.V4_VERSION\n\nV3_BYTES = 8276\nV4_BYTES = 8292\n\n# Us   -- uppercase\n# Them -- lowercase\nPIECES = \"PNBRQKpnbrqk\"\n\nMOVES = [\n    \"a1b1\", \"a1c1\", \"a1d1\", \"a1e1\", \"a1f1\", \"a1g1\", \"a1h1\", \"a1a2\", \"a1b2\",\n    \"a1c2\", \"a1a3\", \"a1b3\", \"a1c3\", \"a1a4\", \"a1d4\", \"a1a5\", \"a1e5\", \"a1a6\",\n    \"a1f6\", \"a1a7\", \"a1g7\", \"a1a8\", \"a1h8\", \"b1a1\", \"b1c1\", \"b1d1\", \"b1e1\",\n    \"b1f1\", \"b1g1\", \"b1h1\", \"b1a2\", \"b1b2\", \"b1c2\", \"b1d2\", \"b1a3\", \"b1b3\",\n    \"b1c3\", \"b1d3\", \"b1b4\", \"b1e4\", \"b1b5\", \"b1f5\", \"b1b6\", \"b1g6\", \"b1b7\",\n    \"b1h7\", \"b1b8\", \"c1a1\", \"c1b1\", \"c1d1\", \"c1e1\", \"c1f1\", \"c1g1\", \"c1h1\",\n    \"c1a2\", \"c1b2\", \"c1c2\", \"c1d2\", \"c1e2\", \"c1a3\", \"c1b3\", \"c1c3\", \"c1d3\",\n    \"c1e3\", \"c1c4\", \"c1f4\", \"c1c5\", \"c1g5\", \"c1c6\", \"c1h6\", \"c1c7\", \"c1c8\",\n    \"d1a1\", \"d1b1\", \"d1c1\", \"d1e1\", \"d1f1\", \"d1g1\", \"d1h1\", \"d1b2\", \"d1c2\",\n    \"d1d2\", \"d1e2\", \"d1f2\", \"d1b3\", \"d1c3\", \"d1d3\", \"d1e3\", \"d1f3\", \"d1a4\",\n    \"d1d4\", \"d1g4\", \"d1d5\", \"d1h5\", \"d1d6\", \"d1d7\", \"d1d8\", \"e1a1\", \"e1b1\",\n    \"e1c1\", \"e1d1\", \"e1f1\", \"e1g1\", \"e1h1\", \"e1c2\", \"e1d2\", \"e1e2\", \"e1f2\",\n    \"e1g2\", \"e1c3\", \"e1d3\", \"e1e3\", \"e1f3\", \"e1g3\", \"e1b4\", \"e1e4\", \"e1h4\",\n    \"e1a5\", \"e1e5\", \"e1e6\", \"e1e7\", \"e1e8\", \"f1a1\", \"f1b1\", \"f1c1\", \"f1d1\",\n    \"f1e1\", \"f1g1\", \"f1h1\", \"f1d2\", \"f1e2\", \"f1f2\", \"f1g2\", \"f1h2\", \"f1d3\",\n    \"f1e3\", \"f1f3\", \"f1g3\", \"f1h3\", \"f1c4\", \"f1f4\", \"f1b5\", \"f1f5\", \"f1a6\",\n    \"f1f6\", \"f1f7\", \"f1f8\", 
\"g1a1\", \"g1b1\", \"g1c1\", \"g1d1\", \"g1e1\", \"g1f1\",\n    \"g1h1\", \"g1e2\", \"g1f2\", \"g1g2\", \"g1h2\", \"g1e3\", \"g1f3\", \"g1g3\", \"g1h3\",\n    \"g1d4\", \"g1g4\", \"g1c5\", \"g1g5\", \"g1b6\", \"g1g6\", \"g1a7\", \"g1g7\", \"g1g8\",\n    \"h1a1\", \"h1b1\", \"h1c1\", \"h1d1\", \"h1e1\", \"h1f1\", \"h1g1\", \"h1f2\", \"h1g2\",\n    \"h1h2\", \"h1f3\", \"h1g3\", \"h1h3\", \"h1e4\", \"h1h4\", \"h1d5\", \"h1h5\", \"h1c6\",\n    \"h1h6\", \"h1b7\", \"h1h7\", \"h1a8\", \"h1h8\", \"a2a1\", \"a2b1\", \"a2c1\", \"a2b2\",\n    \"a2c2\", \"a2d2\", \"a2e2\", \"a2f2\", \"a2g2\", \"a2h2\", \"a2a3\", \"a2b3\", \"a2c3\",\n    \"a2a4\", \"a2b4\", \"a2c4\", \"a2a5\", \"a2d5\", \"a2a6\", \"a2e6\", \"a2a7\", \"a2f7\",\n    \"a2a8\", \"a2g8\", \"b2a1\", \"b2b1\", \"b2c1\", \"b2d1\", \"b2a2\", \"b2c2\", \"b2d2\",\n    \"b2e2\", \"b2f2\", \"b2g2\", \"b2h2\", \"b2a3\", \"b2b3\", \"b2c3\", \"b2d3\", \"b2a4\",\n    \"b2b4\", \"b2c4\", \"b2d4\", \"b2b5\", \"b2e5\", \"b2b6\", \"b2f6\", \"b2b7\", \"b2g7\",\n    \"b2b8\", \"b2h8\", \"c2a1\", \"c2b1\", \"c2c1\", \"c2d1\", \"c2e1\", \"c2a2\", \"c2b2\",\n    \"c2d2\", \"c2e2\", \"c2f2\", \"c2g2\", \"c2h2\", \"c2a3\", \"c2b3\", \"c2c3\", \"c2d3\",\n    \"c2e3\", \"c2a4\", \"c2b4\", \"c2c4\", \"c2d4\", \"c2e4\", \"c2c5\", \"c2f5\", \"c2c6\",\n    \"c2g6\", \"c2c7\", \"c2h7\", \"c2c8\", \"d2b1\", \"d2c1\", \"d2d1\", \"d2e1\", \"d2f1\",\n    \"d2a2\", \"d2b2\", \"d2c2\", \"d2e2\", \"d2f2\", \"d2g2\", \"d2h2\", \"d2b3\", \"d2c3\",\n    \"d2d3\", \"d2e3\", \"d2f3\", \"d2b4\", \"d2c4\", \"d2d4\", \"d2e4\", \"d2f4\", \"d2a5\",\n    \"d2d5\", \"d2g5\", \"d2d6\", \"d2h6\", \"d2d7\", \"d2d8\", \"e2c1\", \"e2d1\", \"e2e1\",\n    \"e2f1\", \"e2g1\", \"e2a2\", \"e2b2\", \"e2c2\", \"e2d2\", \"e2f2\", \"e2g2\", \"e2h2\",\n    \"e2c3\", \"e2d3\", \"e2e3\", \"e2f3\", \"e2g3\", \"e2c4\", \"e2d4\", \"e2e4\", \"e2f4\",\n    \"e2g4\", \"e2b5\", \"e2e5\", \"e2h5\", \"e2a6\", \"e2e6\", \"e2e7\", \"e2e8\", \"f2d1\",\n    \"f2e1\", \"f2f1\", \"f2g1\", 
\"f2h1\", \"f2a2\", \"f2b2\", \"f2c2\", \"f2d2\", \"f2e2\",\n    \"f2g2\", \"f2h2\", \"f2d3\", \"f2e3\", \"f2f3\", \"f2g3\", \"f2h3\", \"f2d4\", \"f2e4\",\n    \"f2f4\", \"f2g4\", \"f2h4\", \"f2c5\", \"f2f5\", \"f2b6\", \"f2f6\", \"f2a7\", \"f2f7\",\n    \"f2f8\", \"g2e1\", \"g2f1\", \"g2g1\", \"g2h1\", \"g2a2\", \"g2b2\", \"g2c2\", \"g2d2\",\n    \"g2e2\", \"g2f2\", \"g2h2\", \"g2e3\", \"g2f3\", \"g2g3\", \"g2h3\", \"g2e4\", \"g2f4\",\n    \"g2g4\", \"g2h4\", \"g2d5\", \"g2g5\", \"g2c6\", \"g2g6\", \"g2b7\", \"g2g7\", \"g2a8\",\n    \"g2g8\", \"h2f1\", \"h2g1\", \"h2h1\", \"h2a2\", \"h2b2\", \"h2c2\", \"h2d2\", \"h2e2\",\n    \"h2f2\", \"h2g2\", \"h2f3\", \"h2g3\", \"h2h3\", \"h2f4\", \"h2g4\", \"h2h4\", \"h2e5\",\n    \"h2h5\", \"h2d6\", \"h2h6\", \"h2c7\", \"h2h7\", \"h2b8\", \"h2h8\", \"a3a1\", \"a3b1\",\n    \"a3c1\", \"a3a2\", \"a3b2\", \"a3c2\", \"a3b3\", \"a3c3\", \"a3d3\", \"a3e3\", \"a3f3\",\n    \"a3g3\", \"a3h3\", \"a3a4\", \"a3b4\", \"a3c4\", \"a3a5\", \"a3b5\", \"a3c5\", \"a3a6\",\n    \"a3d6\", \"a3a7\", \"a3e7\", \"a3a8\", \"a3f8\", \"b3a1\", \"b3b1\", \"b3c1\", \"b3d1\",\n    \"b3a2\", \"b3b2\", \"b3c2\", \"b3d2\", \"b3a3\", \"b3c3\", \"b3d3\", \"b3e3\", \"b3f3\",\n    \"b3g3\", \"b3h3\", \"b3a4\", \"b3b4\", \"b3c4\", \"b3d4\", \"b3a5\", \"b3b5\", \"b3c5\",\n    \"b3d5\", \"b3b6\", \"b3e6\", \"b3b7\", \"b3f7\", \"b3b8\", \"b3g8\", \"c3a1\", \"c3b1\",\n    \"c3c1\", \"c3d1\", \"c3e1\", \"c3a2\", \"c3b2\", \"c3c2\", \"c3d2\", \"c3e2\", \"c3a3\",\n    \"c3b3\", \"c3d3\", \"c3e3\", \"c3f3\", \"c3g3\", \"c3h3\", \"c3a4\", \"c3b4\", \"c3c4\",\n    \"c3d4\", \"c3e4\", \"c3a5\", \"c3b5\", \"c3c5\", \"c3d5\", \"c3e5\", \"c3c6\", \"c3f6\",\n    \"c3c7\", \"c3g7\", \"c3c8\", \"c3h8\", \"d3b1\", \"d3c1\", \"d3d1\", \"d3e1\", \"d3f1\",\n    \"d3b2\", \"d3c2\", \"d3d2\", \"d3e2\", \"d3f2\", \"d3a3\", \"d3b3\", \"d3c3\", \"d3e3\",\n    \"d3f3\", \"d3g3\", \"d3h3\", \"d3b4\", \"d3c4\", \"d3d4\", \"d3e4\", \"d3f4\", \"d3b5\",\n    \"d3c5\", \"d3d5\", \"d3e5\", 
\"d3f5\", \"d3a6\", \"d3d6\", \"d3g6\", \"d3d7\", \"d3h7\",\n    \"d3d8\", \"e3c1\", \"e3d1\", \"e3e1\", \"e3f1\", \"e3g1\", \"e3c2\", \"e3d2\", \"e3e2\",\n    \"e3f2\", \"e3g2\", \"e3a3\", \"e3b3\", \"e3c3\", \"e3d3\", \"e3f3\", \"e3g3\", \"e3h3\",\n    \"e3c4\", \"e3d4\", \"e3e4\", \"e3f4\", \"e3g4\", \"e3c5\", \"e3d5\", \"e3e5\", \"e3f5\",\n    \"e3g5\", \"e3b6\", \"e3e6\", \"e3h6\", \"e3a7\", \"e3e7\", \"e3e8\", \"f3d1\", \"f3e1\",\n    \"f3f1\", \"f3g1\", \"f3h1\", \"f3d2\", \"f3e2\", \"f3f2\", \"f3g2\", \"f3h2\", \"f3a3\",\n    \"f3b3\", \"f3c3\", \"f3d3\", \"f3e3\", \"f3g3\", \"f3h3\", \"f3d4\", \"f3e4\", \"f3f4\",\n    \"f3g4\", \"f3h4\", \"f3d5\", \"f3e5\", \"f3f5\", \"f3g5\", \"f3h5\", \"f3c6\", \"f3f6\",\n    \"f3b7\", \"f3f7\", \"f3a8\", \"f3f8\", \"g3e1\", \"g3f1\", \"g3g1\", \"g3h1\", \"g3e2\",\n    \"g3f2\", \"g3g2\", \"g3h2\", \"g3a3\", \"g3b3\", \"g3c3\", \"g3d3\", \"g3e3\", \"g3f3\",\n    \"g3h3\", \"g3e4\", \"g3f4\", \"g3g4\", \"g3h4\", \"g3e5\", \"g3f5\", \"g3g5\", \"g3h5\",\n    \"g3d6\", \"g3g6\", \"g3c7\", \"g3g7\", \"g3b8\", \"g3g8\", \"h3f1\", \"h3g1\", \"h3h1\",\n    \"h3f2\", \"h3g2\", \"h3h2\", \"h3a3\", \"h3b3\", \"h3c3\", \"h3d3\", \"h3e3\", \"h3f3\",\n    \"h3g3\", \"h3f4\", \"h3g4\", \"h3h4\", \"h3f5\", \"h3g5\", \"h3h5\", \"h3e6\", \"h3h6\",\n    \"h3d7\", \"h3h7\", \"h3c8\", \"h3h8\", \"a4a1\", \"a4d1\", \"a4a2\", \"a4b2\", \"a4c2\",\n    \"a4a3\", \"a4b3\", \"a4c3\", \"a4b4\", \"a4c4\", \"a4d4\", \"a4e4\", \"a4f4\", \"a4g4\",\n    \"a4h4\", \"a4a5\", \"a4b5\", \"a4c5\", \"a4a6\", \"a4b6\", \"a4c6\", \"a4a7\", \"a4d7\",\n    \"a4a8\", \"a4e8\", \"b4b1\", \"b4e1\", \"b4a2\", \"b4b2\", \"b4c2\", \"b4d2\", \"b4a3\",\n    \"b4b3\", \"b4c3\", \"b4d3\", \"b4a4\", \"b4c4\", \"b4d4\", \"b4e4\", \"b4f4\", \"b4g4\",\n    \"b4h4\", \"b4a5\", \"b4b5\", \"b4c5\", \"b4d5\", \"b4a6\", \"b4b6\", \"b4c6\", \"b4d6\",\n    \"b4b7\", \"b4e7\", \"b4b8\", \"b4f8\", \"c4c1\", \"c4f1\", \"c4a2\", \"c4b2\", \"c4c2\",\n    \"c4d2\", \"c4e2\", \"c4a3\", 
\"c4b3\", \"c4c3\", \"c4d3\", \"c4e3\", \"c4a4\", \"c4b4\",\n    \"c4d4\", \"c4e4\", \"c4f4\", \"c4g4\", \"c4h4\", \"c4a5\", \"c4b5\", \"c4c5\", \"c4d5\",\n    \"c4e5\", \"c4a6\", \"c4b6\", \"c4c6\", \"c4d6\", \"c4e6\", \"c4c7\", \"c4f7\", \"c4c8\",\n    \"c4g8\", \"d4a1\", \"d4d1\", \"d4g1\", \"d4b2\", \"d4c2\", \"d4d2\", \"d4e2\", \"d4f2\",\n    \"d4b3\", \"d4c3\", \"d4d3\", \"d4e3\", \"d4f3\", \"d4a4\", \"d4b4\", \"d4c4\", \"d4e4\",\n    \"d4f4\", \"d4g4\", \"d4h4\", \"d4b5\", \"d4c5\", \"d4d5\", \"d4e5\", \"d4f5\", \"d4b6\",\n    \"d4c6\", \"d4d6\", \"d4e6\", \"d4f6\", \"d4a7\", \"d4d7\", \"d4g7\", \"d4d8\", \"d4h8\",\n    \"e4b1\", \"e4e1\", \"e4h1\", \"e4c2\", \"e4d2\", \"e4e2\", \"e4f2\", \"e4g2\", \"e4c3\",\n    \"e4d3\", \"e4e3\", \"e4f3\", \"e4g3\", \"e4a4\", \"e4b4\", \"e4c4\", \"e4d4\", \"e4f4\",\n    \"e4g4\", \"e4h4\", \"e4c5\", \"e4d5\", \"e4e5\", \"e4f5\", \"e4g5\", \"e4c6\", \"e4d6\",\n    \"e4e6\", \"e4f6\", \"e4g6\", \"e4b7\", \"e4e7\", \"e4h7\", \"e4a8\", \"e4e8\", \"f4c1\",\n    \"f4f1\", \"f4d2\", \"f4e2\", \"f4f2\", \"f4g2\", \"f4h2\", \"f4d3\", \"f4e3\", \"f4f3\",\n    \"f4g3\", \"f4h3\", \"f4a4\", \"f4b4\", \"f4c4\", \"f4d4\", \"f4e4\", \"f4g4\", \"f4h4\",\n    \"f4d5\", \"f4e5\", \"f4f5\", \"f4g5\", \"f4h5\", \"f4d6\", \"f4e6\", \"f4f6\", \"f4g6\",\n    \"f4h6\", \"f4c7\", \"f4f7\", \"f4b8\", \"f4f8\", \"g4d1\", \"g4g1\", \"g4e2\", \"g4f2\",\n    \"g4g2\", \"g4h2\", \"g4e3\", \"g4f3\", \"g4g3\", \"g4h3\", \"g4a4\", \"g4b4\", \"g4c4\",\n    \"g4d4\", \"g4e4\", \"g4f4\", \"g4h4\", \"g4e5\", \"g4f5\", \"g4g5\", \"g4h5\", \"g4e6\",\n    \"g4f6\", \"g4g6\", \"g4h6\", \"g4d7\", \"g4g7\", \"g4c8\", \"g4g8\", \"h4e1\", \"h4h1\",\n    \"h4f2\", \"h4g2\", \"h4h2\", \"h4f3\", \"h4g3\", \"h4h3\", \"h4a4\", \"h4b4\", \"h4c4\",\n    \"h4d4\", \"h4e4\", \"h4f4\", \"h4g4\", \"h4f5\", \"h4g5\", \"h4h5\", \"h4f6\", \"h4g6\",\n    \"h4h6\", \"h4e7\", \"h4h7\", \"h4d8\", \"h4h8\", \"a5a1\", \"a5e1\", \"a5a2\", \"a5d2\",\n    \"a5a3\", \"a5b3\", \"a5c3\", 
\"a5a4\", \"a5b4\", \"a5c4\", \"a5b5\", \"a5c5\", \"a5d5\",\n    \"a5e5\", \"a5f5\", \"a5g5\", \"a5h5\", \"a5a6\", \"a5b6\", \"a5c6\", \"a5a7\", \"a5b7\",\n    \"a5c7\", \"a5a8\", \"a5d8\", \"b5b1\", \"b5f1\", \"b5b2\", \"b5e2\", \"b5a3\", \"b5b3\",\n    \"b5c3\", \"b5d3\", \"b5a4\", \"b5b4\", \"b5c4\", \"b5d4\", \"b5a5\", \"b5c5\", \"b5d5\",\n    \"b5e5\", \"b5f5\", \"b5g5\", \"b5h5\", \"b5a6\", \"b5b6\", \"b5c6\", \"b5d6\", \"b5a7\",\n    \"b5b7\", \"b5c7\", \"b5d7\", \"b5b8\", \"b5e8\", \"c5c1\", \"c5g1\", \"c5c2\", \"c5f2\",\n    \"c5a3\", \"c5b3\", \"c5c3\", \"c5d3\", \"c5e3\", \"c5a4\", \"c5b4\", \"c5c4\", \"c5d4\",\n    \"c5e4\", \"c5a5\", \"c5b5\", \"c5d5\", \"c5e5\", \"c5f5\", \"c5g5\", \"c5h5\", \"c5a6\",\n    \"c5b6\", \"c5c6\", \"c5d6\", \"c5e6\", \"c5a7\", \"c5b7\", \"c5c7\", \"c5d7\", \"c5e7\",\n    \"c5c8\", \"c5f8\", \"d5d1\", \"d5h1\", \"d5a2\", \"d5d2\", \"d5g2\", \"d5b3\", \"d5c3\",\n    \"d5d3\", \"d5e3\", \"d5f3\", \"d5b4\", \"d5c4\", \"d5d4\", \"d5e4\", \"d5f4\", \"d5a5\",\n    \"d5b5\", \"d5c5\", \"d5e5\", \"d5f5\", \"d5g5\", \"d5h5\", \"d5b6\", \"d5c6\", \"d5d6\",\n    \"d5e6\", \"d5f6\", \"d5b7\", \"d5c7\", \"d5d7\", \"d5e7\", \"d5f7\", \"d5a8\", \"d5d8\",\n    \"d5g8\", \"e5a1\", \"e5e1\", \"e5b2\", \"e5e2\", \"e5h2\", \"e5c3\", \"e5d3\", \"e5e3\",\n    \"e5f3\", \"e5g3\", \"e5c4\", \"e5d4\", \"e5e4\", \"e5f4\", \"e5g4\", \"e5a5\", \"e5b5\",\n    \"e5c5\", \"e5d5\", \"e5f5\", \"e5g5\", \"e5h5\", \"e5c6\", \"e5d6\", \"e5e6\", \"e5f6\",\n    \"e5g6\", \"e5c7\", \"e5d7\", \"e5e7\", \"e5f7\", \"e5g7\", \"e5b8\", \"e5e8\", \"e5h8\",\n    \"f5b1\", \"f5f1\", \"f5c2\", \"f5f2\", \"f5d3\", \"f5e3\", \"f5f3\", \"f5g3\", \"f5h3\",\n    \"f5d4\", \"f5e4\", \"f5f4\", \"f5g4\", \"f5h4\", \"f5a5\", \"f5b5\", \"f5c5\", \"f5d5\",\n    \"f5e5\", \"f5g5\", \"f5h5\", \"f5d6\", \"f5e6\", \"f5f6\", \"f5g6\", \"f5h6\", \"f5d7\",\n    \"f5e7\", \"f5f7\", \"f5g7\", \"f5h7\", \"f5c8\", \"f5f8\", \"g5c1\", \"g5g1\", \"g5d2\",\n    \"g5g2\", \"g5e3\", \"g5f3\", 
\"g5g3\", \"g5h3\", \"g5e4\", \"g5f4\", \"g5g4\", \"g5h4\",\n    \"g5a5\", \"g5b5\", \"g5c5\", \"g5d5\", \"g5e5\", \"g5f5\", \"g5h5\", \"g5e6\", \"g5f6\",\n    \"g5g6\", \"g5h6\", \"g5e7\", \"g5f7\", \"g5g7\", \"g5h7\", \"g5d8\", \"g5g8\", \"h5d1\",\n    \"h5h1\", \"h5e2\", \"h5h2\", \"h5f3\", \"h5g3\", \"h5h3\", \"h5f4\", \"h5g4\", \"h5h4\",\n    \"h5a5\", \"h5b5\", \"h5c5\", \"h5d5\", \"h5e5\", \"h5f5\", \"h5g5\", \"h5f6\", \"h5g6\",\n    \"h5h6\", \"h5f7\", \"h5g7\", \"h5h7\", \"h5e8\", \"h5h8\", \"a6a1\", \"a6f1\", \"a6a2\",\n    \"a6e2\", \"a6a3\", \"a6d3\", \"a6a4\", \"a6b4\", \"a6c4\", \"a6a5\", \"a6b5\", \"a6c5\",\n    \"a6b6\", \"a6c6\", \"a6d6\", \"a6e6\", \"a6f6\", \"a6g6\", \"a6h6\", \"a6a7\", \"a6b7\",\n    \"a6c7\", \"a6a8\", \"a6b8\", \"a6c8\", \"b6b1\", \"b6g1\", \"b6b2\", \"b6f2\", \"b6b3\",\n    \"b6e3\", \"b6a4\", \"b6b4\", \"b6c4\", \"b6d4\", \"b6a5\", \"b6b5\", \"b6c5\", \"b6d5\",\n    \"b6a6\", \"b6c6\", \"b6d6\", \"b6e6\", \"b6f6\", \"b6g6\", \"b6h6\", \"b6a7\", \"b6b7\",\n    \"b6c7\", \"b6d7\", \"b6a8\", \"b6b8\", \"b6c8\", \"b6d8\", \"c6c1\", \"c6h1\", \"c6c2\",\n    \"c6g2\", \"c6c3\", \"c6f3\", \"c6a4\", \"c6b4\", \"c6c4\", \"c6d4\", \"c6e4\", \"c6a5\",\n    \"c6b5\", \"c6c5\", \"c6d5\", \"c6e5\", \"c6a6\", \"c6b6\", \"c6d6\", \"c6e6\", \"c6f6\",\n    \"c6g6\", \"c6h6\", \"c6a7\", \"c6b7\", \"c6c7\", \"c6d7\", \"c6e7\", \"c6a8\", \"c6b8\",\n    \"c6c8\", \"c6d8\", \"c6e8\", \"d6d1\", \"d6d2\", \"d6h2\", \"d6a3\", \"d6d3\", \"d6g3\",\n    \"d6b4\", \"d6c4\", \"d6d4\", \"d6e4\", \"d6f4\", \"d6b5\", \"d6c5\", \"d6d5\", \"d6e5\",\n    \"d6f5\", \"d6a6\", \"d6b6\", \"d6c6\", \"d6e6\", \"d6f6\", \"d6g6\", \"d6h6\", \"d6b7\",\n    \"d6c7\", \"d6d7\", \"d6e7\", \"d6f7\", \"d6b8\", \"d6c8\", \"d6d8\", \"d6e8\", \"d6f8\",\n    \"e6e1\", \"e6a2\", \"e6e2\", \"e6b3\", \"e6e3\", \"e6h3\", \"e6c4\", \"e6d4\", \"e6e4\",\n    \"e6f4\", \"e6g4\", \"e6c5\", \"e6d5\", \"e6e5\", \"e6f5\", \"e6g5\", \"e6a6\", \"e6b6\",\n    \"e6c6\", \"e6d6\", \"e6f6\", 
\"e6g6\", \"e6h6\", \"e6c7\", \"e6d7\", \"e6e7\", \"e6f7\",\n    \"e6g7\", \"e6c8\", \"e6d8\", \"e6e8\", \"e6f8\", \"e6g8\", \"f6a1\", \"f6f1\", \"f6b2\",\n    \"f6f2\", \"f6c3\", \"f6f3\", \"f6d4\", \"f6e4\", \"f6f4\", \"f6g4\", \"f6h4\", \"f6d5\",\n    \"f6e5\", \"f6f5\", \"f6g5\", \"f6h5\", \"f6a6\", \"f6b6\", \"f6c6\", \"f6d6\", \"f6e6\",\n    \"f6g6\", \"f6h6\", \"f6d7\", \"f6e7\", \"f6f7\", \"f6g7\", \"f6h7\", \"f6d8\", \"f6e8\",\n    \"f6f8\", \"f6g8\", \"f6h8\", \"g6b1\", \"g6g1\", \"g6c2\", \"g6g2\", \"g6d3\", \"g6g3\",\n    \"g6e4\", \"g6f4\", \"g6g4\", \"g6h4\", \"g6e5\", \"g6f5\", \"g6g5\", \"g6h5\", \"g6a6\",\n    \"g6b6\", \"g6c6\", \"g6d6\", \"g6e6\", \"g6f6\", \"g6h6\", \"g6e7\", \"g6f7\", \"g6g7\",\n    \"g6h7\", \"g6e8\", \"g6f8\", \"g6g8\", \"g6h8\", \"h6c1\", \"h6h1\", \"h6d2\", \"h6h2\",\n    \"h6e3\", \"h6h3\", \"h6f4\", \"h6g4\", \"h6h4\", \"h6f5\", \"h6g5\", \"h6h5\", \"h6a6\",\n    \"h6b6\", \"h6c6\", \"h6d6\", \"h6e6\", \"h6f6\", \"h6g6\", \"h6f7\", \"h6g7\", \"h6h7\",\n    \"h6f8\", \"h6g8\", \"h6h8\", \"a7a1\", \"a7g1\", \"a7a2\", \"a7f2\", \"a7a3\", \"a7e3\",\n    \"a7a4\", \"a7d4\", \"a7a5\", \"a7b5\", \"a7c5\", \"a7a6\", \"a7b6\", \"a7c6\", \"a7b7\",\n    \"a7c7\", \"a7d7\", \"a7e7\", \"a7f7\", \"a7g7\", \"a7h7\", \"a7a8\", \"a7b8\", \"a7c8\",\n    \"b7b1\", \"b7h1\", \"b7b2\", \"b7g2\", \"b7b3\", \"b7f3\", \"b7b4\", \"b7e4\", \"b7a5\",\n    \"b7b5\", \"b7c5\", \"b7d5\", \"b7a6\", \"b7b6\", \"b7c6\", \"b7d6\", \"b7a7\", \"b7c7\",\n    \"b7d7\", \"b7e7\", \"b7f7\", \"b7g7\", \"b7h7\", \"b7a8\", \"b7b8\", \"b7c8\", \"b7d8\",\n    \"c7c1\", \"c7c2\", \"c7h2\", \"c7c3\", \"c7g3\", \"c7c4\", \"c7f4\", \"c7a5\", \"c7b5\",\n    \"c7c5\", \"c7d5\", \"c7e5\", \"c7a6\", \"c7b6\", \"c7c6\", \"c7d6\", \"c7e6\", \"c7a7\",\n    \"c7b7\", \"c7d7\", \"c7e7\", \"c7f7\", \"c7g7\", \"c7h7\", \"c7a8\", \"c7b8\", \"c7c8\",\n    \"c7d8\", \"c7e8\", \"d7d1\", \"d7d2\", \"d7d3\", \"d7h3\", \"d7a4\", \"d7d4\", \"d7g4\",\n    \"d7b5\", \"d7c5\", \"d7d5\", 
\"d7e5\", \"d7f5\", \"d7b6\", \"d7c6\", \"d7d6\", \"d7e6\",\n    \"d7f6\", \"d7a7\", \"d7b7\", \"d7c7\", \"d7e7\", \"d7f7\", \"d7g7\", \"d7h7\", \"d7b8\",\n    \"d7c8\", \"d7d8\", \"d7e8\", \"d7f8\", \"e7e1\", \"e7e2\", \"e7a3\", \"e7e3\", \"e7b4\",\n    \"e7e4\", \"e7h4\", \"e7c5\", \"e7d5\", \"e7e5\", \"e7f5\", \"e7g5\", \"e7c6\", \"e7d6\",\n    \"e7e6\", \"e7f6\", \"e7g6\", \"e7a7\", \"e7b7\", \"e7c7\", \"e7d7\", \"e7f7\", \"e7g7\",\n    \"e7h7\", \"e7c8\", \"e7d8\", \"e7e8\", \"e7f8\", \"e7g8\", \"f7f1\", \"f7a2\", \"f7f2\",\n    \"f7b3\", \"f7f3\", \"f7c4\", \"f7f4\", \"f7d5\", \"f7e5\", \"f7f5\", \"f7g5\", \"f7h5\",\n    \"f7d6\", \"f7e6\", \"f7f6\", \"f7g6\", \"f7h6\", \"f7a7\", \"f7b7\", \"f7c7\", \"f7d7\",\n    \"f7e7\", \"f7g7\", \"f7h7\", \"f7d8\", \"f7e8\", \"f7f8\", \"f7g8\", \"f7h8\", \"g7a1\",\n    \"g7g1\", \"g7b2\", \"g7g2\", \"g7c3\", \"g7g3\", \"g7d4\", \"g7g4\", \"g7e5\", \"g7f5\",\n    \"g7g5\", \"g7h5\", \"g7e6\", \"g7f6\", \"g7g6\", \"g7h6\", \"g7a7\", \"g7b7\", \"g7c7\",\n    \"g7d7\", \"g7e7\", \"g7f7\", \"g7h7\", \"g7e8\", \"g7f8\", \"g7g8\", \"g7h8\", \"h7b1\",\n    \"h7h1\", \"h7c2\", \"h7h2\", \"h7d3\", \"h7h3\", \"h7e4\", \"h7h4\", \"h7f5\", \"h7g5\",\n    \"h7h5\", \"h7f6\", \"h7g6\", \"h7h6\", \"h7a7\", \"h7b7\", \"h7c7\", \"h7d7\", \"h7e7\",\n    \"h7f7\", \"h7g7\", \"h7f8\", \"h7g8\", \"h7h8\", \"a8a1\", \"a8h1\", \"a8a2\", \"a8g2\",\n    \"a8a3\", \"a8f3\", \"a8a4\", \"a8e4\", \"a8a5\", \"a8d5\", \"a8a6\", \"a8b6\", \"a8c6\",\n    \"a8a7\", \"a8b7\", \"a8c7\", \"a8b8\", \"a8c8\", \"a8d8\", \"a8e8\", \"a8f8\", \"a8g8\",\n    \"a8h8\", \"b8b1\", \"b8b2\", \"b8h2\", \"b8b3\", \"b8g3\", \"b8b4\", \"b8f4\", \"b8b5\",\n    \"b8e5\", \"b8a6\", \"b8b6\", \"b8c6\", \"b8d6\", \"b8a7\", \"b8b7\", \"b8c7\", \"b8d7\",\n    \"b8a8\", \"b8c8\", \"b8d8\", \"b8e8\", \"b8f8\", \"b8g8\", \"b8h8\", \"c8c1\", \"c8c2\",\n    \"c8c3\", \"c8h3\", \"c8c4\", \"c8g4\", \"c8c5\", \"c8f5\", \"c8a6\", \"c8b6\", \"c8c6\",\n    \"c8d6\", \"c8e6\", \"c8a7\", 
\"c8b7\", \"c8c7\", \"c8d7\", \"c8e7\", \"c8a8\", \"c8b8\",\n    \"c8d8\", \"c8e8\", \"c8f8\", \"c8g8\", \"c8h8\", \"d8d1\", \"d8d2\", \"d8d3\", \"d8d4\",\n    \"d8h4\", \"d8a5\", \"d8d5\", \"d8g5\", \"d8b6\", \"d8c6\", \"d8d6\", \"d8e6\", \"d8f6\",\n    \"d8b7\", \"d8c7\", \"d8d7\", \"d8e7\", \"d8f7\", \"d8a8\", \"d8b8\", \"d8c8\", \"d8e8\",\n    \"d8f8\", \"d8g8\", \"d8h8\", \"e8e1\", \"e8e2\", \"e8e3\", \"e8a4\", \"e8e4\", \"e8b5\",\n    \"e8e5\", \"e8h5\", \"e8c6\", \"e8d6\", \"e8e6\", \"e8f6\", \"e8g6\", \"e8c7\", \"e8d7\",\n    \"e8e7\", \"e8f7\", \"e8g7\", \"e8a8\", \"e8b8\", \"e8c8\", \"e8d8\", \"e8f8\", \"e8g8\",\n    \"e8h8\", \"f8f1\", \"f8f2\", \"f8a3\", \"f8f3\", \"f8b4\", \"f8f4\", \"f8c5\", \"f8f5\",\n    \"f8d6\", \"f8e6\", \"f8f6\", \"f8g6\", \"f8h6\", \"f8d7\", \"f8e7\", \"f8f7\", \"f8g7\",\n    \"f8h7\", \"f8a8\", \"f8b8\", \"f8c8\", \"f8d8\", \"f8e8\", \"f8g8\", \"f8h8\", \"g8g1\",\n    \"g8a2\", \"g8g2\", \"g8b3\", \"g8g3\", \"g8c4\", \"g8g4\", \"g8d5\", \"g8g5\", \"g8e6\",\n    \"g8f6\", \"g8g6\", \"g8h6\", \"g8e7\", \"g8f7\", \"g8g7\", \"g8h7\", \"g8a8\", \"g8b8\",\n    \"g8c8\", \"g8d8\", \"g8e8\", \"g8f8\", \"g8h8\", \"h8a1\", \"h8h1\", \"h8b2\", \"h8h2\",\n    \"h8c3\", \"h8h3\", \"h8d4\", \"h8h4\", \"h8e5\", \"h8h5\", \"h8f6\", \"h8g6\", \"h8h6\",\n    \"h8f7\", \"h8g7\", \"h8h7\", \"h8a8\", \"h8b8\", \"h8c8\", \"h8d8\", \"h8e8\", \"h8f8\",\n    \"h8g8\", \"a7a8q\", \"a7a8r\", \"a7a8b\", \"a7b8q\", \"a7b8r\", \"a7b8b\", \"b7a8q\",\n    \"b7a8r\", \"b7a8b\", \"b7b8q\", \"b7b8r\", \"b7b8b\", \"b7c8q\", \"b7c8r\", \"b7c8b\",\n    \"c7b8q\", \"c7b8r\", \"c7b8b\", \"c7c8q\", \"c7c8r\", \"c7c8b\", \"c7d8q\", \"c7d8r\",\n    \"c7d8b\", \"d7c8q\", \"d7c8r\", \"d7c8b\", \"d7d8q\", \"d7d8r\", \"d7d8b\", \"d7e8q\",\n    \"d7e8r\", \"d7e8b\", \"e7d8q\", \"e7d8r\", \"e7d8b\", \"e7e8q\", \"e7e8r\", \"e7e8b\",\n    \"e7f8q\", \"e7f8r\", \"e7f8b\", \"f7e8q\", \"f7e8r\", \"f7e8b\", \"f7f8q\", \"f7f8r\",\n    \"f7f8b\", \"f7g8q\", \"f7g8r\", \"f7g8b\", 
\"g7f8q\", \"g7f8r\", \"g7f8b\", \"g7g8q\",\n    \"g7g8r\", \"g7g8b\", \"g7h8q\", \"g7h8r\", \"g7h8b\", \"h7g8q\", \"h7g8r\", \"h7g8b\",\n    \"h7h8q\", \"h7h8r\", \"h7h8b\"\n]\n\n\nclass Board:\n    def __init__(self):\n        self.clear_board()\n\n    def clear_board(self):\n        self.board = []\n        for rank in range(8):\n            self.board.append(list(\".\" * 8))\n        self.reps = 0\n\n    def describe(self):\n        s = []\n        for rank in range(8):\n            s.append(\"\".join(self.board[rank]))\n        s.append(\"reps {}  \".format(self.reps))\n        return s\n\n\nclass TrainingStep:\n    def __init__(self, version):\n        self.version = version\n        # Construct a fake parser just to get access to it's variables\n        self.parser = chunkparser.ChunkParser(chunkparser.ChunkDataSrc([]),\n                                              workers=1)\n        self.NUM_HIST = 8\n        self.NUM_PIECE_TYPES = 6\n        self.V3_NUM_PLANES = self.NUM_PIECE_TYPES * 2 + 1  # = 13 (6*2 us/them pieces, rep1 (no rep2))\n        self.NUM_PLANES = self.V3_NUM_PLANES\n        self.NUM_REALS = 7  # 4 castling, 1 color, 1 50rule, 1 movecount\n        self.NUM_OUTPUTS = 2  # policy, value\n        self.NUM_PLANES_BYTES = self.NUM_PLANES * 4\n        self.NUM_PLANES_BYTES = self.NUM_PLANES * 4\n        self.NUM_PLANES_BYTES = self.NUM_PLANES * 4\n\n        self.V3_NUM_POLICY_MOVES = 1858  # (7432 bytes)\n        self.NUM_POLICY_MOVES = self.V3_NUM_POLICY_MOVES\n\n        self.init_structs()\n        self.init_move_map()\n        self.history = []\n        self.probs = []\n        for history in range(self.NUM_HIST):\n            self.history.append(Board())\n        self.us_ooo = 0\n        self.us_oo = 0\n        self.them_ooo = 0\n        self.them_oo = 0\n        self.us_black = 0\n        self.rule50_count = 0\n        self.winner = None\n        self.q = None\n\n    def init_structs(self):\n        self.v4_struct = self.parser.v4_struct\n   
     self.this_struct = self.v4_struct\n\n    def init_move_map(self):\n        self.new_white_move_map = defaultdict(lambda: -1)\n        self.new_black_move_map = defaultdict(lambda: -1)\n        self.old_rev_move_map = {}\n        self.new_rev_white_move_map = {}\n        self.new_rev_black_move_map = {}\n\n        for idx, m in enumerate(MOVES):\n            self.new_white_move_map[m] = idx\n            self.new_rev_white_move_map[idx] = m\n            m_black = m.translate(str.maketrans(\"12345678\", \"87654321\"))\n            self.new_black_move_map[m_black] = idx\n            self.new_rev_black_move_map[idx] = m_black\n\n    def clear_hist(self):\n        for hist in range(self.NUM_HIST):\n            self.history.clear_board()\n\n    def update_board(self, hist, piece, bit_board):\n        \"\"\"\n            Update the ASCII board representation\n        \"\"\"\n        for r in range(8):\n            for f in range(8):\n                # Note: Using 8-1-f because both the text and binary have the\n                # column bits reversed fhom what this code expects\n                if bit_board & (1 << (r * 8 + (8 - 1 - f))):\n                    assert (self.history[hist].board[r][f] == \".\")\n                    self.history[hist].board[r][f] = piece\n\n    def describe(self):\n        s = \"\"\n        if self.us_black:\n            s += \"us = Black\"\n        else:\n            s += \"us = White\"\n        if self.winner == 1:\n            s += \" won\\n\"\n        elif self.winner == -1:\n            s += \" lost\\n\"\n        elif self.winner == 0:\n            s += \" draw\\n\"\n        else:\n            raise Exception(\"Invalid winner: {}\".format(self.winner))\n        s += \"Root Q = {} (diff to result: {}) \\n\".format(\n            self.root_q, abs(self.winner - self.root_q))\n        s += \"Best Q = {} (diff to result: {}) \\n\".format(\n            self.best_q, abs(self.winner - self.best_q))\n        if self.us_black:\n            s += 
\"(Note the black pieces are CAPS, black moves up, but A1 is in lower left)\\n\"\n        s += \"rule50_count {} b_ooo b_oo, w_ooo, w_oo {} {} {} {}\\n\".format(\n            self.rule50_count, self.us_ooo, self.us_oo, self.them_ooo,\n            self.them_oo)\n        s += \"  abcdefgh\\n\"\n        rank_strings = [[]]\n        for rank in reversed(range(8)):\n            rank_strings[0].append(\"{}\".format(rank + 1))\n        rank_strings[0].append(\"  \")\n        for hist in range(self.NUM_HIST):\n            rank_strings.append(self.history[hist].describe())\n        for hist in range(self.NUM_HIST + 1):\n            for rank in range(8 + 1):\n                #if hist == 8 and rank == 0: continue\n                s += rank_strings[rank][hist] + \" \"\n            s += \"\\n\"\n        sum = 0.0\n        top_moves = {}\n        for idx, prob in enumerate(self.probs):\n            # Include all moves with at least 1 visit.\n            condition = prob > 0 if self.version == 3 else prob >= 0\n            if condition:\n                top_moves[idx] = prob\n            sum += prob\n        for idx, prob in sorted(top_moves.items(), key=lambda x: -x[1]):\n            s += \"{} {:4.1f}%\\n\".format(self.new_rev_white_move_map[idx],\n                                        prob * 100)\n        #print(\"debug prob sum\", sum, \"cnt\", len(self.probs))\n        return s\n\n    def update_reals(self, text_item):\n        self.us_ooo = int(text_item[self.NUM_HIST * self.NUM_PLANES + 0])\n        self.us_oo = int(text_item[self.NUM_HIST * self.NUM_PLANES + 1])\n        self.them_ooo = int(text_item[self.NUM_HIST * self.NUM_PLANES + 2])\n        self.them_oo = int(text_item[self.NUM_HIST * self.NUM_PLANES + 3])\n        self.us_black = int(text_item[self.NUM_HIST * self.NUM_PLANES + 4])\n        self.rule50_count = min(\n            int(text_item[self.NUM_HIST * self.NUM_PLANES + 5]), 255)\n        # should be around 99-102ish\n        assert self.rule50_count < 105\n\n 
   def flip_single_v1_plane(self, plane):\n        # Split hexstring into bytes (2 ascii chars), reverse, rejoin\n        # This causes a vertical flip\n        return \"\".join(\n            [plane[x:x + 2] for x in reversed(range(0, len(plane), 2))])\n\n    def display_v4(self, ply, content):\n        (ver, probs, planes, us_ooo, us_oo, them_ooo, them_oo, us_black,\n         rule50_count, move_count, winner, root_q, best_q, root_d,\n         best_d) = self.this_struct.unpack(content)\n        assert self.version == int.from_bytes(ver, byteorder=\"little\")\n        # Enforce move_count to 0\n        move_count = 0\n        # Unpack planes.\n        for hist in range(self.NUM_HIST):\n            for idx, piece in enumerate(PIECES):\n                start = hist * self.NUM_PLANES * 8 + idx * 8\n                end = start + 8\n                self.update_board(\n                    hist, piece,\n                    int.from_bytes(planes[start:end], byteorder=\"big\"))\n            if planes[hist * self.NUM_PLANES * 8 +\n                      12 * 8:hist * self.NUM_PLANES * 8 + 12 * 8 +\n                      8] != struct.pack('II', 0, 0):\n                self.history[hist].reps = 1\n                assert planes[hist * self.NUM_PLANES * 8 +\n                              12 * 8:hist * self.NUM_PLANES * 8 + 12 * 8 +\n                              8] == struct.pack('II', 0xffffffff, 0xffffffff)\n        self.us_ooo = us_ooo\n        self.us_oo = us_oo\n        self.them_ooo = them_ooo\n        self.them_oo = them_oo\n        self.us_black = us_black\n        self.rule50_count = rule50_count\n        self.winner = winner\n        self.root_q = root_q\n        self.best_q = best_q\n        for idx in range(0, len(probs), 4):\n            self.probs.append(struct.unpack(\"f\", probs[idx:idx + 4])[0])\n        print(\"ply {} move {} (Not actually part of training data)\".format(\n            ply + 1, (ply + 2) // 2))\n        print(self.describe())\n\n\ndef 
main(args):\n    for filename in args.files:\n        #print(\"Parsing {}\".format(filename))\n        with gzip.open(filename, 'rb') as f:\n            chunkdata = f.read()\n            version = chunkdata[0:4]\n            if version in {VERSION4, VERSION3}:\n                if version == VERSION3:\n                    record_size = V3_BYTES\n                else:\n                    record_size = V4_BYTES\n                for i in range(0, len(chunkdata), record_size):\n                    ts = TrainingStep(4 if version == VERSION4 else 3)\n                    record = chunkdata[i:i + record_size]\n                    if chunkdata[0:4] == VERSION3:\n                        record += 16 * b'\\x00'\n                    ts.display_v4(i // record_size, record)\n            else:\n                print(\"Invalid version\")\n\n\nif __name__ == '__main__':\n    usage_str = \"\"\"\nParse training files and display them.\"\"\"\n\n    parser = argparse.ArgumentParser(\n        formatter_class=argparse.RawDescriptionHelpFormatter,\n        description=usage_str)\n    parser.add_argument(\"files\", type=str, nargs=\"+\", help=\"training*.gz\")\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "tf/lc0_az_policy_map.py",
    "content": "#!/usr/bin/env python3\nimport sys\nimport numpy as np\nfrom policy_index import policy_index\n\ncolumns = 'abcdefgh'\nrows = '12345678'\npromotions = 'rbq'  # N is encoded as normal move\n\ncol_index = {columns[i]: i for i in range(len(columns))}\nrow_index = {rows[i]: i for i in range(len(rows))}\n\n\ndef index_to_position(x):\n    return columns[x[0]] + rows[x[1]]\n\n\ndef position_to_index(p):\n    return col_index[p[0]], row_index[p[1]]\n\n\ndef valid_index(i):\n    if i[0] > 7 or i[0] < 0:\n        return False\n    if i[1] > 7 or i[1] < 0:\n        return False\n    return True\n\n\ndef queen_move(start, direction, steps):\n    i = position_to_index(start)\n    dir_vectors = {\n        'N': (0, 1),\n        'NE': (1, 1),\n        'E': (1, 0),\n        'SE': (1, -1),\n        'S': (0, -1),\n        'SW': (-1, -1),\n        'W': (-1, 0),\n        'NW': (-1, 1)\n    }\n    v = dir_vectors[direction]\n    i = i[0] + v[0] * steps, i[1] + v[1] * steps\n    if not valid_index(i):\n        return None\n    return index_to_position(i)\n\n\ndef knight_move(start, direction, steps):\n    i = position_to_index(start)\n    dir_vectors = {\n        'N': (1, 2),\n        'NE': (2, 1),\n        'E': (2, -1),\n        'SE': (1, -2),\n        'S': (-1, -2),\n        'SW': (-2, -1),\n        'W': (-2, 1),\n        'NW': (-1, 2)\n    }\n    v = dir_vectors[direction]\n    i = i[0] + v[0] * steps, i[1] + v[1] * steps\n    if not valid_index(i):\n        return None\n    return index_to_position(i)\n\n\ndef make_map(kind='matrix'):\n    # 56 planes of queen moves\n    moves = []\n    for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:\n        for steps in range(1, 8):\n            for r0 in rows:\n                for c0 in columns:\n                    start = c0 + r0\n                    end = queen_move(start, direction, steps)\n                    if end == None:\n                        moves.append('illegal')\n                    else:\n           
             moves.append(start + end)\n\n    # 8 planes of knight moves\n    for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:\n        for r0 in rows:\n            for c0 in columns:\n                start = c0 + r0\n                end = knight_move(start, direction, 1)\n                if end == None:\n                    moves.append('illegal')\n                else:\n                    moves.append(start + end)\n\n    # 9 promotion planes (3 directions x 3 promotion pieces)\n    for direction in ['NW', 'N', 'NE']:\n        for promotion in promotions:\n            for r0 in rows:\n                for c0 in columns:\n                    # Promotion only in the second last rank\n                    if r0 != '7':\n                        moves.append('illegal')\n                        continue\n                    start = c0 + r0\n                    end = queen_move(start, direction, 1)\n                    if end == None:\n                        moves.append('illegal')\n                    else:\n                        moves.append(start + end + promotion)\n\n    for m in policy_index:\n        if m not in moves:\n            raise ValueError('Missing move: {}'.format(m))\n\n    az_to_lc0 = np.zeros((80 * 8 * 8, len(policy_index)), dtype=np.float32)\n    indices = []\n    legal_moves = 0\n    for e, m in enumerate(moves):\n        if m == 'illegal':\n            indices.append(-1)\n            continue\n        legal_moves += 1\n        # Check for missing moves\n        if m not in policy_index:\n            raise ValueError('Missing move: {}'.format(m))\n        i = policy_index.index(m)\n        indices.append(i)\n        az_to_lc0[e][i] = 1\n\n    assert legal_moves == len(policy_index)\n    assert np.sum(az_to_lc0) == legal_moves\n    if kind == 'matrix':\n        return az_to_lc0\n    elif kind == 'index':\n        return indices\n\n\nif __name__ == \"__main__\":\n    # Generate policy map include file for lc0\n    if len(sys.argv) != 2:\n        raise ValueError(\n            \"Output filename is needed as a command line argument\")\n\n    az_to_lc0 = np.ravel(make_map('index'))\n    header = \\\n\"\"\"/*\n This file is part of Leela Chess Zero.\n Copyright (C) 2019 The LCZero Authors\n\n Leela Chess is free software: you can redistribute it and/or modify\n it under the terms of the GNU General Public License as published by\n the Free Software Foundation, either version 3 of the License, or\n (at your option) any later version.\n\n Leela Chess is distributed in the hope that it will be useful,\n but WITHOUT ANY WARRANTY; without even the implied warranty of\n MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n GNU General Public License for more details.\n\n You should have received a copy of the GNU General Public License\n along with Leela Chess.  If not, see <http://www.gnu.org/licenses/>.\n */\n\n#pragma once\n\nnamespace lczero {\n\"\"\"\n    line_length = 12\n    with open(sys.argv[1], 'w') as f:\n        f.write(header + '\\n')\n        f.write('const short kConvPolicyMap[] = {\\\\\\n')\n        for e, i in enumerate(az_to_lc0):\n            if e % line_length == 0 and e > 0:\n                f.write('\\n')\n            f.write(str(i).rjust(5))\n            if e != len(az_to_lc0) - 1:\n                f.write(',')\n        f.write('};\\n\\n')\n        f.write('}  // namespace lczero')\n"
  },
  {
    "path": "tf/make_model.py",
    "content": "#!/usr/bin/env python3\nimport argparse\nimport os\nimport yaml\nimport tfprocess\n\nargparser = argparse.ArgumentParser(description='Convert net to model.')\nargparser.add_argument('--start',\n                       type=int,\n                       default=0,\n                       help='Offset to set global_step to.')\nargparser.add_argument('--cfg',\n                       type=argparse.FileType('r'),\n                       help='yaml configuration with training parameters')\nargs = argparser.parse_args()\ncfg = yaml.safe_load(args.cfg.read())\nprint(yaml.dump(cfg, default_flow_style=False))\nSTART_FROM = args.start\n\ntfp = tfprocess.TFProcess(cfg)\ntfp.init_net()\ntfp.global_step.assign(START_FROM)\n\nroot_dir = os.path.join(cfg['training']['path'], cfg['name'])\nif not os.path.exists(root_dir):\n    os.makedirs(root_dir)\ntfp.manager.save(checkpoint_number=START_FROM)\nprint(\"Wrote model to {}\".format(tfp.manager.latest_checkpoint))\npath = os.path.join(tfp.root_dir, tfp.cfg['name'])\nleela_path = path + \"-\" + str(START_FROM)\nswa_path = path + \"-swa-\" + str(START_FROM)\ntfp.net.pb.training_params.training_steps = START_FROM\ntfp.save_leelaz_weights(leela_path)\nif tfp.swa_enabled:\n    tfp.save_swa_weights(swa_path)\n\n"
  },
  {
    "path": "tf/model_to_net.py",
    "content": "#!/usr/bin/env python3\nimport argparse\nimport os\nimport yaml\nimport tfprocess\n\nargparser = argparse.ArgumentParser(description='Convert model to net.')\nargparser.add_argument('--cfg',\n                       type=argparse.FileType('r'),\n                       help='yaml configuration with training parameters')\nargs = argparser.parse_args()\ncfg = yaml.safe_load(args.cfg.read())\nprint(yaml.dump(cfg, default_flow_style=False))\n\ntfp = tfprocess.TFProcess(cfg)\ntfp.init_net()\n\ntfp.restore()\n\nroot_dir = os.path.join(cfg['training']['path'], cfg['name'])\nif not os.path.exists(root_dir):\n    os.makedirs(root_dir)\npath = os.path.join(tfp.root_dir, tfp.cfg['name'])\nsteps = tfp.global_step.read_value().numpy()\nleela_path = path + \"-\" + str(steps)\nswa_path = path + \"-swa-\" + str(steps)\ntfp.net.pb.training_params.training_steps = steps\ntfp.save_leelaz_weights(leela_path)\nif tfp.swa_enabled:\n    tfp.save_swa_weights(swa_path)\n\n"
  },
  {
    "path": "tf/net.py",
    "content": "#!/usr/bin/env python3\n\nimport argparse\nimport gzip\nimport os\nimport numpy as np\nimport proto.net_pb2 as pb\n\nLC0_MAJOR = 0\nLC0_MINOR = 21\nLC0_MINOR_WITH_INPUT_TYPE_3 = 25\nLC0_MINOR_WITH_INPUT_TYPE_4 = 26\nLC0_MINOR_WITH_INPUT_TYPE_5 = 27\nLC0_MINOR_WITH_MISH = 29\nLC0_MINOR_WITH_ATTN_BODY = 30\nLC0_PATCH = 0\nWEIGHTS_MAGIC = 0x1c0\n\n\ndef nested_getattr(obj, attr):\n    attributes = attr.split(\".\")\n    for a in attributes:\n        obj = getattr(obj, a)\n    return obj\n\n\nclass Net:\n\n    def __init__(self,\n                 net=pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT,\n                 input=pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE,\n                 value=pb.NetworkFormat.VALUE_CLASSICAL,\n                 policy=pb.NetworkFormat.POLICY_CLASSICAL,\n                 moves_left=pb.NetworkFormat.MOVES_LEFT_V1):\n\n        if net == pb.NetworkFormat.NETWORK_SE:\n            net = pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT\n        if net == pb.NetworkFormat.NETWORK_CLASSICAL:\n            net = pb.NetworkFormat.NETWORK_CLASSICAL_WITH_HEADFORMAT\n\n        self.pb = pb.Net()\n        self.pb.magic = WEIGHTS_MAGIC\n        self.pb.min_version.major = LC0_MAJOR\n        self.pb.min_version.minor = LC0_MINOR\n        self.pb.min_version.patch = LC0_PATCH\n        self.pb.format.weights_encoding = pb.Format.LINEAR16\n\n        self.weights = []\n\n        self.set_networkformat(net)\n        self.pb.format.network_format.input = input\n        self.set_policyformat(policy)\n        self.set_valueformat(value)\n        self.set_movesleftformat(moves_left)\n        self.set_defaultactivation(pb.NetworkFormat.DEFAULT_ACTIVATION_RELU)\n\n    def set_networkformat(self, net):\n        self.pb.format.network_format.network = net\n        if net == pb.NetworkFormat.NETWORK_ATTENTIONBODY_WITH_HEADFORMAT \\\n                and self.pb.min_version.minor < LC0_MINOR_WITH_ATTN_BODY:\n            self.pb.min_version.minor = 
LC0_MINOR_WITH_ATTN_BODY\n\n    def set_policyformat(self, policy):\n        self.pb.format.network_format.policy = policy\n\n    def set_headcount(self, headcount):\n        self.pb.weights.headcount = headcount\n\n    def set_pol_headcount(self, headcount):\n        self.pb.weights.pol_headcount = headcount\n\n    def set_valueformat(self, value):\n        self.pb.format.network_format.value = value\n\n        # OutputFormat is for search to know which kind of value the net returns.\n        if value == pb.NetworkFormat.VALUE_WDL:\n            self.pb.format.network_format.output = pb.NetworkFormat.OUTPUT_WDL\n        else:\n            self.pb.format.network_format.output = pb.NetworkFormat.OUTPUT_CLASSICAL\n\n    def set_movesleftformat(self, moves_left):\n        self.pb.format.network_format.moves_left = moves_left\n\n    def set_input(self, input_format):\n        self.pb.format.network_format.input = input_format\n        if input_format == pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2 or input_format == pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON:\n            self.pb.min_version.minor = LC0_MINOR_WITH_INPUT_TYPE_5\n        elif input_format >= pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES:\n            self.pb.min_version.minor = LC0_MINOR_WITH_INPUT_TYPE_4\n        # Input type 2 was available before 3, but it was buggy, so also limit it to same version as 3.\n        elif input_format != pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE:\n            self.pb.min_version.minor = LC0_MINOR_WITH_INPUT_TYPE_3\n\n    def set_defaultactivation(self, activation):\n        self.pb.format.network_format.default_activation = activation\n        if activation == pb.NetworkFormat.DEFAULT_ACTIVATION_MISH:\n            if self.pb.min_version.minor < LC0_MINOR_WITH_MISH:\n                self.pb.min_version.minor = LC0_MINOR_WITH_MISH\n\n    def set_smolgen_activation(self, activation):\n        
self.pb.format.network_format.smolgen_activation = activation\n        if self.pb.min_version.minor < LC0_MINOR_WITH_ATTN_BODY:\n            self.pb.min_version.minor = LC0_MINOR_WITH_ATTN_BODY\n        return None\n\n    def set_ffn_activation(self, activation):\n        self.pb.format.network_format.ffn_activation = activation\n        if self.pb.min_version.minor < LC0_MINOR_WITH_ATTN_BODY:\n            self.pb.min_version.minor = LC0_MINOR_WITH_ATTN_BODY\n        return None\n\n    def activation(self, name):\n        if name == \"relu\":\n            return pb.NetworkFormat.ACTIVATION_RELU\n        elif name == \"tanh\":\n            return pb.NetworkFormat.ACTIVATION_TANH\n        elif name == \"sigmoid\":\n            return pb.NetworkFormat.ACTIVATION_SIGMOID\n        elif name == \"softmax\":\n            return pb.NetworkFormat.ACTIVATION_SOFTMAX\n        elif name == \"selu\":\n            return pb.NetworkFormat.ACTIVATION_SELU\n        elif name == \"mish\":\n            return pb.NetworkFormat.ACTIVATION_MISH\n        elif name == \"swish\":\n            return pb.NetworkFormat.ACTIVATION_SWISH\n        elif name == \"relu_2\" or name == \"sqrrelu\":\n            return pb.NetworkFormat.ACTIVATION_RELU_2\n        elif name == \"default\":\n            return pb.NetworkFormat.ACTIVATION_DEFAULT\n        else:\n            return pb.NetworkFormat.ACTIVATION_NONE\n\n    def get_weight_amounts(self):\n        value_weights = 8\n        policy_weights = 6\n        head_weights = value_weights + policy_weights\n        if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:\n            # Batch norm gammas in head convolutions.\n            head_weights += 2\n        if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:\n            return {\"input\": 5, \"residual\": 14, \"head\": head_weights}\n        else:\n            return {\"input\": 4, \"residual\": 8, \"head\": head_weights}\n\n   
 def fill_layer_v2(self, layer, params):\n        \"\"\"Normalize and populate 16bit layer in protobuf\"\"\"\n        params = params.flatten().astype(np.float32)\n        layer.min_val = 0 if len(params) == 1 else float(np.min(params))\n        layer.max_val = 1 if len(params) == 1 and np.max(\n            params) == 0 else float(np.max(params))\n        if layer.max_val == layer.min_val:\n            # Avoid division by zero if max == min.\n            params = (params - layer.min_val)\n        else:\n            params = (params - layer.min_val) / (layer.max_val - layer.min_val)\n        params *= 0xffff\n        params = np.round(params)\n        layer.params = params.astype(np.uint16).tobytes()\n\n    def fill_layer(self, layer, weights):\n        \"\"\"Normalize and populate 16bit layer in protobuf\"\"\"\n        self.fill_layer_v2(layer, np.array(weights.pop(), dtype=np.float32))\n\n    def fill_conv_block(self, convblock, weights, gammas):\n        \"\"\"Normalize and populate 16bit convblock in protobuf\"\"\"\n        if gammas:\n            self.fill_layer(convblock.bn_stddivs, weights)\n            self.fill_layer(convblock.bn_means, weights)\n            self.fill_layer(convblock.bn_betas, weights)\n            self.fill_layer(convblock.bn_gammas, weights)\n            self.fill_layer(convblock.weights, weights)\n        else:\n            self.fill_layer(convblock.bn_stddivs, weights)\n            self.fill_layer(convblock.bn_means, 
weights)\n            self.fill_layer(convblock.biases, weights)\n            self.fill_layer(convblock.weights, weights)\n\n    def fill_plain_conv(self, convblock, weights):\n        \"\"\"Normalize and populate 16bit convblock in protobuf\"\"\"\n        self.fill_layer(convblock.biases, weights)\n        self.fill_layer(convblock.weights, weights)\n\n    def fill_se_unit(self, se_unit, weights):\n        self.fill_layer(se_unit.b2, weights)\n        self.fill_layer(se_unit.w2, weights)\n        self.fill_layer(se_unit.b1, weights)\n        self.fill_layer(se_unit.w1, weights)\n\n    def denorm_layer_v2(self, layer):\n        \"\"\"Denormalize a layer from protobuf\"\"\"\n        params = np.frombuffer(layer.params, np.uint16).astype(np.float32)\n        params /= 0xffff\n        return params * (layer.max_val - layer.min_val) + layer.min_val\n\n    def denorm_layer(self, layer, weights):\n        weights.insert(0, self.denorm_layer_v2(layer))\n\n    def denorm_conv_block(self, convblock, weights):\n        \"\"\"Denormalize a convblock from protobuf\"\"\"\n        se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT\n\n        if se:\n            self.denorm_layer(convblock.bn_stddivs, weights)\n            self.denorm_layer(convblock.bn_means, weights)\n            self.denorm_layer(convblock.bn_betas, weights)\n            self.denorm_layer(convblock.bn_gammas, weights)\n            self.denorm_layer(convblock.weights, weights)\n        else:\n            self.denorm_layer(convblock.bn_stddivs, weights)\n            self.denorm_layer(convblock.bn_means, weights)\n            self.denorm_layer(convblock.biases, weights)\n            self.denorm_layer(convblock.weights, weights)\n\n    def denorm_plain_conv(self, convblock, weights):\n        \"\"\"Denormalize a plain convolution from protobuf\"\"\"\n        self.denorm_layer(convblock.biases, weights)\n        self.denorm_layer(convblock.weights, weights)\n\n    def 
denorm_se_unit(self, convblock, weights):\n        \"\"\"Denormalize SE-unit from protobuf\"\"\"\n        se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT\n\n        assert se\n\n        self.denorm_layer(convblock.b2, weights)\n        self.denorm_layer(convblock.w2, weights)\n        self.denorm_layer(convblock.b1, weights)\n        self.denorm_layer(convblock.w1, weights)\n\n    def save_txt(self, filename):\n        \"\"\"Save weights as txt file\"\"\"\n        weights = self.get_weights()\n\n        if len(filename.split('.')) == 1:\n            filename += \".txt.gz\"\n\n        # Legacy .txt files are version 2, SE is version 3.\n\n        version = 2\n        if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:\n            version = 3\n\n        if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION:\n            version = 4\n\n        with gzip.open(filename, 'wb') as f:\n            f.write(\"{}\\n\".format(version).encode('utf-8'))\n            for row in weights:\n                f.write(\n                    (\" \".join(map(str, row.tolist())) + \"\\n\").encode('utf-8'))\n\n        size = os.path.getsize(filename) / 1024**2\n        print(\"saved as '{}' {}M\".format(filename, round(size, 2)))\n\n    def save_proto(self, filename):\n        \"\"\"Save weights gzipped protobuf file\"\"\"\n        if len(filename.split('.')) == 1:\n            filename += \".pb.gz\"\n\n        with gzip.open(filename, 'wb') as f:\n            data = self.pb.SerializeToString()\n            f.write(data)\n\n        size = os.path.getsize(filename) / 1024**2\n        print(\"Weights saved as '{}' {}M\".format(filename, round(size, 2)))\n\n    def tf_name_to_pb_name(self, name):\n        \"\"\"Given Tensorflow variable name returns the protobuf name and index\n        of residual block if weight belong in a residual block.\"\"\"\n\n        def convblock_to_bp(w):\n          
  w = w.split(':')[0]\n            d = {\n                'kernel': 'weights',\n                'gamma': 'bn_gammas',\n                'beta': 'bn_betas',\n                'moving_mean': 'bn_means',\n                'moving_variance': 'bn_stddivs',\n                'bias': 'biases'\n            }\n\n            return d[w]\n\n        def se_to_bp(l, w):\n            if l == 'dense1':\n                n = 1\n            elif l == 'dense2':\n                n = 2\n            else:\n                raise ValueError('Unable to decode SE-weight {}/{}'.format(\n                    l, w))\n            w = w.split(':')[0]\n            d = {'kernel': 'w', 'bias': 'b'}\n\n            return d[w] + str(n)\n\n        def value_to_bp(l, w):\n            if l == 'embedding':\n                n = ''\n            elif l == 'dense1':\n                n = 1\n            elif l == 'dense2':\n                n = 2\n            else:\n                raise ValueError('Unable to decode value weight {}/{}'.format(\n                    l, w))\n            w = w.split(':')[0]\n            d = {'kernel': 'ip{}_val_w', 'bias': 'ip{}_val_b'}\n\n            return d[w].format(n)\n\n        def conv_policy_to_bp(w):\n            w = w.split(':')[0]\n            d = {'kernel': 'ip_pol_w', 'bias': 'ip_pol_b'}\n\n            return d[w]\n\n        def attn_pol_to_bp(l, w):\n            if l == 'wq':\n                n = 2\n            elif l == 'wk':\n                n = 3\n            elif l == 'ppo':\n                n = 4\n            else:\n                raise ValueError(\n                    'Unable to decode attn_policy weight {}/{}'.format(l, w))\n            w = w.split(':')[0]\n            d = {'kernel': 'ip{}_pol_w', 'bias': 'ip{}_pol_b'}\n\n            return d[w].format(n)\n\n        def encoder_to_bp(l, w):\n            if l == 'ln1':\n                n = 1\n            elif l == 'ln2':\n                n = 2\n            else:\n                raise ValueError(\n                   
 'Unable to decode encoder weight {}/{}'.format(l, w))\n            w = w.split(':')[0]\n            d = {'gamma': 'ln{}_gammas', 'beta': 'ln{}_betas'}\n\n            return d[w].format(n)\n\n        def mha_to_bp(l, w):\n            s = ''\n            if l.startswith('dense'):\n                s = 'dense'\n            elif l.startswith('w'):\n                s = l[1]\n            else:\n                raise ValueError('Unable to decode mha weight {}/{}'.format(\n                    l, w))\n            w = w.split(':')[0]\n            d = {'kernel': '{}_w', 'bias': '{}_b'}\n\n            return d[w].format(s)\n\n        def mha_smolgen_to_bp(l, w):\n            s = {\n                'compress': 'compress',\n                'hidden1_dense': 'dense1_{}',\n                'hidden1_ln': 'ln1_{}',\n                'gen_from': 'dense2_{}',\n                'gen_from_ln': 'ln2_{}'\n            }\n            if l not in s:\n                raise ValueError(\n                    'Unable to decode mha smolgen weight {}/{}'.format(l, w))\n            w = w.split(':')[0]\n            d = {\n                'kernel': 'w',\n                'bias': 'b',\n                'gamma': 'gammas',\n                'beta': 'betas'\n            }\n\n            return s[l].format(d[w])\n\n        def ffn_to_bp(l, w):\n            w = w.split(':')[0]\n            d = {'kernel': '{}_w', 'bias': '{}_b'}\n\n            return d[w].format(l)\n\n        def moves_left_to_bp(l, w):\n            if l == 'embedding':\n                n = ''\n            elif l == 'dense1':\n                n = 1\n            elif l == 'dense2':\n                n = 2\n            else:\n                raise ValueError(\n                    'Unable to decode moves_left weight {}/{}'.format(l, w))\n            w = w.split(':')[0]\n            d = {'kernel': 'ip{}_mov_w', 'bias': 'ip{}_mov_b'}\n\n            return d[w].format(n)\n\n        layers = name.split('/')\n        base_layer = layers[0]\n        weights_name = layers[-1]\n        pb_name = None\n        block = None\n        encoder_block = None\n        pol_encoder_block = None\n\n        if base_layer == 'input':\n            pb_name = 'input.' + convblock_to_bp(weights_name)\n        elif base_layer == 'policy1':\n            pb_name = 'policy1.' + convblock_to_bp(weights_name)\n        elif base_layer == 'policy':\n            if 'dense' in layers[1]:\n                pb_name = conv_policy_to_bp(weights_name)\n            elif layers[1] == 'embedding':\n                if layers[2].split(':')[0] == 'kernel':\n                    pb_name = 'ip_pol_w'\n                else:\n                    pb_name = 'ip_pol_b'\n            elif layers[1] == 'attention':\n                pb_name = attn_pol_to_bp(layers[2], weights_name)\n            elif layers[1].startswith('enc_layer_'):\n                pol_encoder_block = int(layers[1].split('_')[2]) - 1\n                if layers[2] == 'mha':\n                    pb_name = 'mha.' + mha_to_bp(layers[3], weights_name)\n                elif layers[2] == 'ffn':\n                    pb_name = 'ffn.' + ffn_to_bp(layers[3], weights_name)\n                else:\n                    pb_name = encoder_to_bp(layers[2], weights_name)\n            else:\n                pb_name = 'policy.' + convblock_to_bp(weights_name)\n        elif base_layer == 'value':\n            if 'dense' in layers[1] or 'embedding' in layers[1]:\n                pb_name = value_to_bp(layers[1], weights_name)\n            else:\n                pb_name = 'value.' + convblock_to_bp(weights_name)\n        elif base_layer == 'moves_left':\n            if 'dense' in layers[1] or 'embedding' in layers[1]:\n                pb_name = moves_left_to_bp(layers[1], weights_name)\n            else:\n                pb_name = 'moves_left.' 
+ convblock_to_bp(weights_name)\n        elif base_layer.startswith('residual'):\n            block = int(base_layer.split('_')[1]) - 1  # 1 indexed\n            if layers[1] == '1':\n                pb_name = 'conv1.' + convblock_to_bp(weights_name)\n            elif layers[1] == '2':\n                pb_name = 'conv2.' + convblock_to_bp(weights_name)\n            elif layers[1] == 'se':\n                pb_name = 'se.' + se_to_bp(layers[-2], weights_name)\n        elif base_layer.startswith('encoder'):\n            encoder_block = int(base_layer.split('_')[1]) - 1\n            if layers[1] == 'mha':\n                if layers[2] == 'smolgen':\n                    pb_name = 'mha.smolgen.' + mha_smolgen_to_bp(\n                        layers[3], weights_name)\n                else:\n                    pb_name = 'mha.' + mha_to_bp(layers[2], weights_name)\n            elif layers[1] == 'ffn':\n                pb_name = 'ffn.' + ffn_to_bp(layers[2], weights_name)\n            else:\n                pb_name = encoder_to_bp(layers[1], weights_name)\n        elif base_layer == 'embedding':\n            if layers[1] == 'mult_gate' or layers[1] == 'add_gate':\n                if layers[2].split(':')[0] == 'gate':\n                    pb_name = 'ip_{}'.format(layers[1])\n            elif layers[1].split(':')[0] == 'kernel':\n                pb_name = 'ip_emb_w'\n            elif layers[1].split(':')[0] == 'bias':\n                pb_name = 'ip_emb_b'\n        elif base_layer == 'smol_weight_gen':\n            if layers[1].split(':')[0] == 'kernel':\n                pb_name = 'smolgen_w'\n            else:\n                pb_name = 'smolgen_b'\n\n        return (pb_name, block, pol_encoder_block, encoder_block)\n\n    def get_weights_v2(self, names):\n        # `names` is a list of Tensorflow tensor names to get from the protobuf.\n        # Returns list of [Tensor name, Tensor weights].\n        tensors = {}\n\n        for tf_name in names:\n            name = tf_name\n  
          if 'stddev' in name:\n                # Get variance instead of stddev.\n                name = name.replace('stddev', 'variance')\n            if 'renorm' in name:\n                # Renorm variables are not populated.\n                continue\n            if 'headcount' in tf_name:\n                # headcount is set with set_headcount()\n                continue\n\n            pb_name, block, pol_encoder_block, encoder_block = self.tf_name_to_pb_name(\n                name)\n\n            if pb_name is None:\n                raise ValueError(\n                    \"Don't know where to store weight in protobuf: {}\".format(\n                        name))\n\n            if block is None:\n                if pol_encoder_block is not None:\n                    pb_weights = self.pb.weights.pol_encoder[pol_encoder_block]\n                elif encoder_block is not None:\n                    pb_weights = self.pb.weights.encoder[encoder_block]\n                else:\n                    pb_weights = self.pb.weights\n            else:\n                pb_weights = self.pb.weights.residual[block]\n\n            w = self.denorm_layer_v2(nested_getattr(pb_weights, pb_name))\n\n            # Only variance is stored in the protobuf.\n            if 'stddev' in tf_name:\n                w = np.sqrt(w + 1e-5)\n\n            tensors[tf_name] = w\n        return tensors\n\n    def get_weights(self):\n        \"\"\"Returns the weights as floats per layer\"\"\"\n        se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT\n        if self.weights == []:\n            self.denorm_layer(self.pb.weights.ip2_val_b, self.weights)\n            self.denorm_layer(self.pb.weights.ip2_val_w, self.weights)\n            self.denorm_layer(self.pb.weights.ip1_val_b, self.weights)\n            self.denorm_layer(self.pb.weights.ip1_val_w, self.weights)\n            self.denorm_conv_block(self.pb.weights.value, self.weights)\n\n            if 
self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION:\n                self.denorm_plain_conv(self.pb.weights.policy, self.weights)\n                self.denorm_conv_block(self.pb.weights.policy1, self.weights)\n            else:\n                self.denorm_layer(self.pb.weights.ip_pol_b, self.weights)\n                self.denorm_layer(self.pb.weights.ip_pol_w, self.weights)\n                self.denorm_conv_block(self.pb.weights.policy, self.weights)\n\n            for res in reversed(self.pb.weights.residual):\n                if se:\n                    self.denorm_se_unit(res.se, self.weights)\n                self.denorm_conv_block(res.conv2, self.weights)\n                self.denorm_conv_block(res.conv1, self.weights)\n\n            self.denorm_conv_block(self.pb.weights.input, self.weights)\n\n        return self.weights\n\n    def filters(self):\n        layer = self.pb.weights.input.bn_means\n        params = np.frombuffer(layer.params, np.uint16).astype(np.float32)\n        return len(params)\n\n    def blocks(self):\n        return len(self.pb.weights.residual)\n\n    def print_stats(self):\n        print(\"Blocks: {}\".format(self.blocks()))\n        print(\"Filters: {}\".format(self.filters()))\n        print_pb_stats(self.pb)\n        print()\n\n    def parse_proto(self, filename):\n        with gzip.open(filename, 'rb') as f:\n            self.pb = self.pb.FromString(f.read())\n        # Populate policyFormat and valueFormat fields in old protobufs\n        # without these fields.\n        if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE:\n            self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT)\n            self.set_valueformat(pb.NetworkFormat.VALUE_CLASSICAL)\n            self.set_policyformat(pb.NetworkFormat.POLICY_CLASSICAL)\n            self.set_movesleftformat(pb.NetworkFormat.MOVES_LEFT_NONE)\n        elif self.pb.format.network_format.network == 
pb.NetworkFormat.NETWORK_CLASSICAL:\n            self.set_networkformat(\n                pb.NetworkFormat.NETWORK_CLASSICAL_WITH_HEADFORMAT)\n            self.set_valueformat(pb.NetworkFormat.VALUE_CLASSICAL)\n            self.set_policyformat(pb.NetworkFormat.POLICY_CLASSICAL)\n            self.set_movesleftformat(pb.NetworkFormat.MOVES_LEFT_NONE)\n\n    def parse_txt(self, filename):\n        weights = []\n\n        with open(filename, 'r') as f:\n            try:\n                version = int(f.readline()[0])\n            except (ValueError, IndexError):\n                raise ValueError('Unable to read version.')\n            for line in f:\n                weights.append(list(map(float, line.split(' '))))\n\n        if version == 3:\n            self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT)\n\n        if version == 4:\n            self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT)\n            self.set_policyformat(pb.NetworkFormat.POLICY_CONVOLUTION)\n\n        self.fill_net(weights)\n\n    def fill_net_v2(self, all_weights):\n        # all_weights is array of [name of weight, numpy array of weights].\n        self.pb.format.weights_encoding = pb.Format.LINEAR16\n\n        has_renorm = any('renorm' in w[0] for w in all_weights)\n        weight_names = [w[0] for w in all_weights]\n\n        del self.pb.weights.residual[:]\n\n        for name, weights in all_weights:\n            layers = name.split('/')\n            weights_name = layers[-1]\n            if weights.ndim == 4:\n                # Convolution weights need a transpose\n                #\n                # TF\n                # [filter_height, filter_width, in_channels, out_channels]\n                #\n                # Leela\n                # [output, input, filter_size, filter_size]\n                weights = np.transpose(weights, axes=[3, 2, 0, 1])\n            elif weights.ndim == 2:\n                # Fully connected layers are [in, out] in TF\n                
#\n                # [out, in] in Leela\n                #\n                weights = np.transpose(weights, axes=[1, 0])\n\n            if 'renorm' in name:\n                # Batch renorm has extra weights, but we don't know what to do with them.\n                continue\n            if has_renorm:\n                if 'variance:' in weights_name:\n                    # Renorm has variance, but it is not the primary source of truth.\n                    continue\n                # Renorm has moving stddev not variance, undo the transform to make it compatible.\n                if 'stddev:' in weights_name:\n                    weights = np.square(weights) - 1e-5\n                    name = name.replace('stddev', 'variance')\n\n            if self.pb.format.network_format.input < pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES:\n                if name == 'input/conv2d/kernel:0':\n                    # 50 move rule is the 110th input, or 109 starting from 0.\n                    weights[:, 109, :, :] /= 99\n                elif name == 'embedding/kernel:0':\n                    weights[:, 109] /= 99\n\n            pb_name, block, pol_encoder_block, encoder_block = self.tf_name_to_pb_name(\n                name)\n\n            if pb_name is None:\n                raise ValueError(\n                    \"Don't know where to store weight in protobuf: {}\".format(\n                        name))\n\n            if block is None:\n                if pol_encoder_block is not None:\n                    assert pol_encoder_block >= 0\n                    while pol_encoder_block >= len(\n                            self.pb.weights.pol_encoder):\n                        self.pb.weights.pol_encoder.add()\n                    pb_weights = self.pb.weights.pol_encoder[pol_encoder_block]\n                elif encoder_block is not None:\n                    assert encoder_block >= 0\n                    while encoder_block >= len(self.pb.weights.encoder):\n               
         self.pb.weights.encoder.add()\n                    pb_weights = self.pb.weights.encoder[encoder_block]\n                else:\n                    pb_weights = self.pb.weights\n            else:\n                assert block >= 0\n                while block >= len(self.pb.weights.residual):\n                    self.pb.weights.residual.add()\n                pb_weights = self.pb.weights.residual[block]\n\n            self.fill_layer_v2(nested_getattr(pb_weights, pb_name), weights)\n\n            if pb_name.endswith('bn_betas'):\n                # Check if we need to add constant one gammas.\n                gamma_name = name.replace('beta', 'gamma')\n                if gamma_name in weight_names:\n                    continue\n                gamma = np.ones(weights.shape)\n                pb_gamma = pb_name.replace('bn_betas', 'bn_gammas')\n                self.fill_layer_v2(nested_getattr(pb_weights, pb_gamma), gamma)\n\n    def fill_net(self, weights):\n        self.weights = []\n        # Batchnorm gammas in ConvBlock?\n        se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT\n        gammas = se\n\n        ws = self.get_weight_amounts()\n\n        blocks = len(weights) - (ws['input'] + ws['head'])\n\n        if blocks % ws['residual'] != 0:\n            raise ValueError(\"Inconsistent number of weights in the file\")\n        blocks //= ws['residual']\n\n        self.pb.format.weights_encoding = pb.Format.LINEAR16\n        self.fill_layer(self.pb.weights.ip2_val_b, weights)\n        self.fill_layer(self.pb.weights.ip2_val_w, weights)\n        self.fill_layer(self.pb.weights.ip1_val_b, weights)\n        self.fill_layer(self.pb.weights.ip1_val_w, weights)\n        self.fill_conv_block(self.pb.weights.value, weights, gammas)\n\n        if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION:\n            self.fill_plain_conv(self.pb.weights.policy, weights)\n            
self.fill_conv_block(self.pb.weights.policy1, weights, gammas)\n        else:\n            self.fill_layer(self.pb.weights.ip_pol_b, weights)\n            self.fill_layer(self.pb.weights.ip_pol_w, weights)\n            self.fill_conv_block(self.pb.weights.policy, weights, gammas)\n\n        del self.pb.weights.residual[:]\n        tower = []\n        for i in range(blocks):\n            tower.append(self.pb.weights.residual.add())\n\n        for res in reversed(tower):\n            if se:\n                self.fill_se_unit(res.se, weights)\n            self.fill_conv_block(res.conv2, weights, gammas)\n            self.fill_conv_block(res.conv1, weights, gammas)\n\n        self.fill_conv_block(self.pb.weights.input, weights, gammas)\n\n\ndef print_pb_stats(obj, parent=None):\n    for descriptor in obj.DESCRIPTOR.fields:\n        value = getattr(obj, descriptor.name)\n        if descriptor.name == \"weights\":\n            return\n        if descriptor.type == descriptor.TYPE_MESSAGE:\n            if descriptor.label == descriptor.LABEL_REPEATED:\n                # Recurse into each repeated message (map() is lazy in\n                # Python 3, so an explicit loop is required).\n                for v in value:\n                    print_pb_stats(v, obj)\n            else:\n                print_pb_stats(value, obj)\n        elif descriptor.type == descriptor.TYPE_ENUM:\n            enum_name = descriptor.enum_type.values[value].name\n            print(\"%s: %s\" % (descriptor.full_name, enum_name))\n        else:\n            print(\"%s: %s\" % (descriptor.full_name, value))\n\n\ndef main(argv):\n    net = Net()\n\n    if argv.input.endswith(\".txt\"):\n        print('Found .txt network')\n        net.parse_txt(argv.input)\n        net.print_stats()\n        if argv.output is None:\n            argv.output = argv.input.replace('.txt', '.pb.gz')\n            assert argv.output.endswith('.pb.gz')\n            print('Writing output to: {}'.format(argv.output))\n        net.save_proto(argv.output)\n    elif argv.input.endswith(\".pb.gz\"):\n        print('Found .pb.gz network')\n        net.parse_proto(argv.input)\n        net.print_stats()\n        if argv.output is None:\n            argv.output = argv.input.replace('.pb.gz', '.txt.gz')\n            print('Writing output to: {}'.format(argv.output))\n            assert argv.output.endswith('.txt.gz')\n        if argv.output.endswith(\".pb.gz\"):\n            net.save_proto(argv.output)\n        else:\n            net.save_txt(argv.output)\n    else:\n        print('Unable to detect the network format. '\n              'Filename should end in \".txt\" or \".pb.gz\"')\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(\n        description='Convert network textfile to proto.')\n    argparser.add_argument('-i',\n                           '--input',\n                           type=str,\n                           help='input network weight text file')\n    argparser.add_argument('-o',\n                           '--output',\n                           type=str,\n                           help='output filepath without extension')\n    main(argparser.parse_args())\n"
  },
  {
    "path": "tf/net_to_model.py",
    "content": "#!/usr/bin/env python3\nimport argparse\nimport os\nimport yaml\nimport tfprocess\n\nargparser = argparse.ArgumentParser(description='Convert net to model.')\nargparser.add_argument('net',\n                       type=str,\n                       help='Net file to be converted to a model checkpoint.')\nargparser.add_argument('--start',\n                       type=int,\n                       default=0,\n                       help='Offset to set global_step to.')\nargparser.add_argument('--cfg',\n                       type=argparse.FileType('r'),\n                       help='yaml configuration with training parameters')\nargparser.add_argument('-e',\n                       '--ignore-errors',\n                       action='store_true',\n                       help='Ignore missing and wrong sized values.')\nargs = argparser.parse_args()\ncfg = yaml.safe_load(args.cfg.read())\nprint(yaml.dump(cfg, default_flow_style=False))\nSTART_FROM = args.start\n\ntfp = tfprocess.TFProcess(cfg)\ntfp.init_net()\ntfp.replace_weights(args.net, args.ignore_errors)\ntfp.global_step.assign(START_FROM)\n\nroot_dir = os.path.join(cfg['training']['path'], cfg['name'])\nif not os.path.exists(root_dir):\n    os.makedirs(root_dir)\ntfp.manager.save(checkpoint_number=START_FROM)\nprint(\"Wrote model to {}\".format(tfp.manager.latest_checkpoint))\n"
  },
  {
    "path": "tf/policy_index.py",
    "content": "policy_index = [\n    \"a1b1\", \"a1c1\", \"a1d1\", \"a1e1\", \"a1f1\", \"a1g1\", \"a1h1\", \"a1a2\", \"a1b2\",\n    \"a1c2\", \"a1a3\", \"a1b3\", \"a1c3\", \"a1a4\", \"a1d4\", \"a1a5\", \"a1e5\", \"a1a6\",\n    \"a1f6\", \"a1a7\", \"a1g7\", \"a1a8\", \"a1h8\", \"b1a1\", \"b1c1\", \"b1d1\", \"b1e1\",\n    \"b1f1\", \"b1g1\", \"b1h1\", \"b1a2\", \"b1b2\", \"b1c2\", \"b1d2\", \"b1a3\", \"b1b3\",\n    \"b1c3\", \"b1d3\", \"b1b4\", \"b1e4\", \"b1b5\", \"b1f5\", \"b1b6\", \"b1g6\", \"b1b7\",\n    \"b1h7\", \"b1b8\", \"c1a1\", \"c1b1\", \"c1d1\", \"c1e1\", \"c1f1\", \"c1g1\", \"c1h1\",\n    \"c1a2\", \"c1b2\", \"c1c2\", \"c1d2\", \"c1e2\", \"c1a3\", \"c1b3\", \"c1c3\", \"c1d3\",\n    \"c1e3\", \"c1c4\", \"c1f4\", \"c1c5\", \"c1g5\", \"c1c6\", \"c1h6\", \"c1c7\", \"c1c8\",\n    \"d1a1\", \"d1b1\", \"d1c1\", \"d1e1\", \"d1f1\", \"d1g1\", \"d1h1\", \"d1b2\", \"d1c2\",\n    \"d1d2\", \"d1e2\", \"d1f2\", \"d1b3\", \"d1c3\", \"d1d3\", \"d1e3\", \"d1f3\", \"d1a4\",\n    \"d1d4\", \"d1g4\", \"d1d5\", \"d1h5\", \"d1d6\", \"d1d7\", \"d1d8\", \"e1a1\", \"e1b1\",\n    \"e1c1\", \"e1d1\", \"e1f1\", \"e1g1\", \"e1h1\", \"e1c2\", \"e1d2\", \"e1e2\", \"e1f2\",\n    \"e1g2\", \"e1c3\", \"e1d3\", \"e1e3\", \"e1f3\", \"e1g3\", \"e1b4\", \"e1e4\", \"e1h4\",\n    \"e1a5\", \"e1e5\", \"e1e6\", \"e1e7\", \"e1e8\", \"f1a1\", \"f1b1\", \"f1c1\", \"f1d1\",\n    \"f1e1\", \"f1g1\", \"f1h1\", \"f1d2\", \"f1e2\", \"f1f2\", \"f1g2\", \"f1h2\", \"f1d3\",\n    \"f1e3\", \"f1f3\", \"f1g3\", \"f1h3\", \"f1c4\", \"f1f4\", \"f1b5\", \"f1f5\", \"f1a6\",\n    \"f1f6\", \"f1f7\", \"f1f8\", \"g1a1\", \"g1b1\", \"g1c1\", \"g1d1\", \"g1e1\", \"g1f1\",\n    \"g1h1\", \"g1e2\", \"g1f2\", \"g1g2\", \"g1h2\", \"g1e3\", \"g1f3\", \"g1g3\", \"g1h3\",\n    \"g1d4\", \"g1g4\", \"g1c5\", \"g1g5\", \"g1b6\", \"g1g6\", \"g1a7\", \"g1g7\", \"g1g8\",\n    \"h1a1\", \"h1b1\", \"h1c1\", \"h1d1\", \"h1e1\", \"h1f1\", \"h1g1\", \"h1f2\", \"h1g2\",\n    \"h1h2\", \"h1f3\", \"h1g3\", \"h1h3\", \"h1e4\", \"h1h4\", 
\"h1d5\", \"h1h5\", \"h1c6\",\n    \"h1h6\", \"h1b7\", \"h1h7\", \"h1a8\", \"h1h8\", \"a2a1\", \"a2b1\", \"a2c1\", \"a2b2\",\n    \"a2c2\", \"a2d2\", \"a2e2\", \"a2f2\", \"a2g2\", \"a2h2\", \"a2a3\", \"a2b3\", \"a2c3\",\n    \"a2a4\", \"a2b4\", \"a2c4\", \"a2a5\", \"a2d5\", \"a2a6\", \"a2e6\", \"a2a7\", \"a2f7\",\n    \"a2a8\", \"a2g8\", \"b2a1\", \"b2b1\", \"b2c1\", \"b2d1\", \"b2a2\", \"b2c2\", \"b2d2\",\n    \"b2e2\", \"b2f2\", \"b2g2\", \"b2h2\", \"b2a3\", \"b2b3\", \"b2c3\", \"b2d3\", \"b2a4\",\n    \"b2b4\", \"b2c4\", \"b2d4\", \"b2b5\", \"b2e5\", \"b2b6\", \"b2f6\", \"b2b7\", \"b2g7\",\n    \"b2b8\", \"b2h8\", \"c2a1\", \"c2b1\", \"c2c1\", \"c2d1\", \"c2e1\", \"c2a2\", \"c2b2\",\n    \"c2d2\", \"c2e2\", \"c2f2\", \"c2g2\", \"c2h2\", \"c2a3\", \"c2b3\", \"c2c3\", \"c2d3\",\n    \"c2e3\", \"c2a4\", \"c2b4\", \"c2c4\", \"c2d4\", \"c2e4\", \"c2c5\", \"c2f5\", \"c2c6\",\n    \"c2g6\", \"c2c7\", \"c2h7\", \"c2c8\", \"d2b1\", \"d2c1\", \"d2d1\", \"d2e1\", \"d2f1\",\n    \"d2a2\", \"d2b2\", \"d2c2\", \"d2e2\", \"d2f2\", \"d2g2\", \"d2h2\", \"d2b3\", \"d2c3\",\n    \"d2d3\", \"d2e3\", \"d2f3\", \"d2b4\", \"d2c4\", \"d2d4\", \"d2e4\", \"d2f4\", \"d2a5\",\n    \"d2d5\", \"d2g5\", \"d2d6\", \"d2h6\", \"d2d7\", \"d2d8\", \"e2c1\", \"e2d1\", \"e2e1\",\n    \"e2f1\", \"e2g1\", \"e2a2\", \"e2b2\", \"e2c2\", \"e2d2\", \"e2f2\", \"e2g2\", \"e2h2\",\n    \"e2c3\", \"e2d3\", \"e2e3\", \"e2f3\", \"e2g3\", \"e2c4\", \"e2d4\", \"e2e4\", \"e2f4\",\n    \"e2g4\", \"e2b5\", \"e2e5\", \"e2h5\", \"e2a6\", \"e2e6\", \"e2e7\", \"e2e8\", \"f2d1\",\n    \"f2e1\", \"f2f1\", \"f2g1\", \"f2h1\", \"f2a2\", \"f2b2\", \"f2c2\", \"f2d2\", \"f2e2\",\n    \"f2g2\", \"f2h2\", \"f2d3\", \"f2e3\", \"f2f3\", \"f2g3\", \"f2h3\", \"f2d4\", \"f2e4\",\n    \"f2f4\", \"f2g4\", \"f2h4\", \"f2c5\", \"f2f5\", \"f2b6\", \"f2f6\", \"f2a7\", \"f2f7\",\n    \"f2f8\", \"g2e1\", \"g2f1\", \"g2g1\", \"g2h1\", \"g2a2\", \"g2b2\", \"g2c2\", \"g2d2\",\n    \"g2e2\", \"g2f2\", \"g2h2\", \"g2e3\", \"g2f3\", \"g2g3\", 
\"g2h3\", \"g2e4\", \"g2f4\",\n    \"g2g4\", \"g2h4\", \"g2d5\", \"g2g5\", \"g2c6\", \"g2g6\", \"g2b7\", \"g2g7\", \"g2a8\",\n    \"g2g8\", \"h2f1\", \"h2g1\", \"h2h1\", \"h2a2\", \"h2b2\", \"h2c2\", \"h2d2\", \"h2e2\",\n    \"h2f2\", \"h2g2\", \"h2f3\", \"h2g3\", \"h2h3\", \"h2f4\", \"h2g4\", \"h2h4\", \"h2e5\",\n    \"h2h5\", \"h2d6\", \"h2h6\", \"h2c7\", \"h2h7\", \"h2b8\", \"h2h8\", \"a3a1\", \"a3b1\",\n    \"a3c1\", \"a3a2\", \"a3b2\", \"a3c2\", \"a3b3\", \"a3c3\", \"a3d3\", \"a3e3\", \"a3f3\",\n    \"a3g3\", \"a3h3\", \"a3a4\", \"a3b4\", \"a3c4\", \"a3a5\", \"a3b5\", \"a3c5\", \"a3a6\",\n    \"a3d6\", \"a3a7\", \"a3e7\", \"a3a8\", \"a3f8\", \"b3a1\", \"b3b1\", \"b3c1\", \"b3d1\",\n    \"b3a2\", \"b3b2\", \"b3c2\", \"b3d2\", \"b3a3\", \"b3c3\", \"b3d3\", \"b3e3\", \"b3f3\",\n    \"b3g3\", \"b3h3\", \"b3a4\", \"b3b4\", \"b3c4\", \"b3d4\", \"b3a5\", \"b3b5\", \"b3c5\",\n    \"b3d5\", \"b3b6\", \"b3e6\", \"b3b7\", \"b3f7\", \"b3b8\", \"b3g8\", \"c3a1\", \"c3b1\",\n    \"c3c1\", \"c3d1\", \"c3e1\", \"c3a2\", \"c3b2\", \"c3c2\", \"c3d2\", \"c3e2\", \"c3a3\",\n    \"c3b3\", \"c3d3\", \"c3e3\", \"c3f3\", \"c3g3\", \"c3h3\", \"c3a4\", \"c3b4\", \"c3c4\",\n    \"c3d4\", \"c3e4\", \"c3a5\", \"c3b5\", \"c3c5\", \"c3d5\", \"c3e5\", \"c3c6\", \"c3f6\",\n    \"c3c7\", \"c3g7\", \"c3c8\", \"c3h8\", \"d3b1\", \"d3c1\", \"d3d1\", \"d3e1\", \"d3f1\",\n    \"d3b2\", \"d3c2\", \"d3d2\", \"d3e2\", \"d3f2\", \"d3a3\", \"d3b3\", \"d3c3\", \"d3e3\",\n    \"d3f3\", \"d3g3\", \"d3h3\", \"d3b4\", \"d3c4\", \"d3d4\", \"d3e4\", \"d3f4\", \"d3b5\",\n    \"d3c5\", \"d3d5\", \"d3e5\", \"d3f5\", \"d3a6\", \"d3d6\", \"d3g6\", \"d3d7\", \"d3h7\",\n    \"d3d8\", \"e3c1\", \"e3d1\", \"e3e1\", \"e3f1\", \"e3g1\", \"e3c2\", \"e3d2\", \"e3e2\",\n    \"e3f2\", \"e3g2\", \"e3a3\", \"e3b3\", \"e3c3\", \"e3d3\", \"e3f3\", \"e3g3\", \"e3h3\",\n    \"e3c4\", \"e3d4\", \"e3e4\", \"e3f4\", \"e3g4\", \"e3c5\", \"e3d5\", \"e3e5\", \"e3f5\",\n    \"e3g5\", \"e3b6\", \"e3e6\", \"e3h6\", \"e3a7\", \"e3e7\", 
\"e3e8\", \"f3d1\", \"f3e1\",\n    \"f3f1\", \"f3g1\", \"f3h1\", \"f3d2\", \"f3e2\", \"f3f2\", \"f3g2\", \"f3h2\", \"f3a3\",\n    \"f3b3\", \"f3c3\", \"f3d3\", \"f3e3\", \"f3g3\", \"f3h3\", \"f3d4\", \"f3e4\", \"f3f4\",\n    \"f3g4\", \"f3h4\", \"f3d5\", \"f3e5\", \"f3f5\", \"f3g5\", \"f3h5\", \"f3c6\", \"f3f6\",\n    \"f3b7\", \"f3f7\", \"f3a8\", \"f3f8\", \"g3e1\", \"g3f1\", \"g3g1\", \"g3h1\", \"g3e2\",\n    \"g3f2\", \"g3g2\", \"g3h2\", \"g3a3\", \"g3b3\", \"g3c3\", \"g3d3\", \"g3e3\", \"g3f3\",\n    \"g3h3\", \"g3e4\", \"g3f4\", \"g3g4\", \"g3h4\", \"g3e5\", \"g3f5\", \"g3g5\", \"g3h5\",\n    \"g3d6\", \"g3g6\", \"g3c7\", \"g3g7\", \"g3b8\", \"g3g8\", \"h3f1\", \"h3g1\", \"h3h1\",\n    \"h3f2\", \"h3g2\", \"h3h2\", \"h3a3\", \"h3b3\", \"h3c3\", \"h3d3\", \"h3e3\", \"h3f3\",\n    \"h3g3\", \"h3f4\", \"h3g4\", \"h3h4\", \"h3f5\", \"h3g5\", \"h3h5\", \"h3e6\", \"h3h6\",\n    \"h3d7\", \"h3h7\", \"h3c8\", \"h3h8\", \"a4a1\", \"a4d1\", \"a4a2\", \"a4b2\", \"a4c2\",\n    \"a4a3\", \"a4b3\", \"a4c3\", \"a4b4\", \"a4c4\", \"a4d4\", \"a4e4\", \"a4f4\", \"a4g4\",\n    \"a4h4\", \"a4a5\", \"a4b5\", \"a4c5\", \"a4a6\", \"a4b6\", \"a4c6\", \"a4a7\", \"a4d7\",\n    \"a4a8\", \"a4e8\", \"b4b1\", \"b4e1\", \"b4a2\", \"b4b2\", \"b4c2\", \"b4d2\", \"b4a3\",\n    \"b4b3\", \"b4c3\", \"b4d3\", \"b4a4\", \"b4c4\", \"b4d4\", \"b4e4\", \"b4f4\", \"b4g4\",\n    \"b4h4\", \"b4a5\", \"b4b5\", \"b4c5\", \"b4d5\", \"b4a6\", \"b4b6\", \"b4c6\", \"b4d6\",\n    \"b4b7\", \"b4e7\", \"b4b8\", \"b4f8\", \"c4c1\", \"c4f1\", \"c4a2\", \"c4b2\", \"c4c2\",\n    \"c4d2\", \"c4e2\", \"c4a3\", \"c4b3\", \"c4c3\", \"c4d3\", \"c4e3\", \"c4a4\", \"c4b4\",\n    \"c4d4\", \"c4e4\", \"c4f4\", \"c4g4\", \"c4h4\", \"c4a5\", \"c4b5\", \"c4c5\", \"c4d5\",\n    \"c4e5\", \"c4a6\", \"c4b6\", \"c4c6\", \"c4d6\", \"c4e6\", \"c4c7\", \"c4f7\", \"c4c8\",\n    \"c4g8\", \"d4a1\", \"d4d1\", \"d4g1\", \"d4b2\", \"d4c2\", \"d4d2\", \"d4e2\", \"d4f2\",\n    \"d4b3\", \"d4c3\", \"d4d3\", \"d4e3\", \"d4f3\", \"d4a4\", 
\"d4b4\", \"d4c4\", \"d4e4\",\n    \"d4f4\", \"d4g4\", \"d4h4\", \"d4b5\", \"d4c5\", \"d4d5\", \"d4e5\", \"d4f5\", \"d4b6\",\n    \"d4c6\", \"d4d6\", \"d4e6\", \"d4f6\", \"d4a7\", \"d4d7\", \"d4g7\", \"d4d8\", \"d4h8\",\n    \"e4b1\", \"e4e1\", \"e4h1\", \"e4c2\", \"e4d2\", \"e4e2\", \"e4f2\", \"e4g2\", \"e4c3\",\n    \"e4d3\", \"e4e3\", \"e4f3\", \"e4g3\", \"e4a4\", \"e4b4\", \"e4c4\", \"e4d4\", \"e4f4\",\n    \"e4g4\", \"e4h4\", \"e4c5\", \"e4d5\", \"e4e5\", \"e4f5\", \"e4g5\", \"e4c6\", \"e4d6\",\n    \"e4e6\", \"e4f6\", \"e4g6\", \"e4b7\", \"e4e7\", \"e4h7\", \"e4a8\", \"e4e8\", \"f4c1\",\n    \"f4f1\", \"f4d2\", \"f4e2\", \"f4f2\", \"f4g2\", \"f4h2\", \"f4d3\", \"f4e3\", \"f4f3\",\n    \"f4g3\", \"f4h3\", \"f4a4\", \"f4b4\", \"f4c4\", \"f4d4\", \"f4e4\", \"f4g4\", \"f4h4\",\n    \"f4d5\", \"f4e5\", \"f4f5\", \"f4g5\", \"f4h5\", \"f4d6\", \"f4e6\", \"f4f6\", \"f4g6\",\n    \"f4h6\", \"f4c7\", \"f4f7\", \"f4b8\", \"f4f8\", \"g4d1\", \"g4g1\", \"g4e2\", \"g4f2\",\n    \"g4g2\", \"g4h2\", \"g4e3\", \"g4f3\", \"g4g3\", \"g4h3\", \"g4a4\", \"g4b4\", \"g4c4\",\n    \"g4d4\", \"g4e4\", \"g4f4\", \"g4h4\", \"g4e5\", \"g4f5\", \"g4g5\", \"g4h5\", \"g4e6\",\n    \"g4f6\", \"g4g6\", \"g4h6\", \"g4d7\", \"g4g7\", \"g4c8\", \"g4g8\", \"h4e1\", \"h4h1\",\n    \"h4f2\", \"h4g2\", \"h4h2\", \"h4f3\", \"h4g3\", \"h4h3\", \"h4a4\", \"h4b4\", \"h4c4\",\n    \"h4d4\", \"h4e4\", \"h4f4\", \"h4g4\", \"h4f5\", \"h4g5\", \"h4h5\", \"h4f6\", \"h4g6\",\n    \"h4h6\", \"h4e7\", \"h4h7\", \"h4d8\", \"h4h8\", \"a5a1\", \"a5e1\", \"a5a2\", \"a5d2\",\n    \"a5a3\", \"a5b3\", \"a5c3\", \"a5a4\", \"a5b4\", \"a5c4\", \"a5b5\", \"a5c5\", \"a5d5\",\n    \"a5e5\", \"a5f5\", \"a5g5\", \"a5h5\", \"a5a6\", \"a5b6\", \"a5c6\", \"a5a7\", \"a5b7\",\n    \"a5c7\", \"a5a8\", \"a5d8\", \"b5b1\", \"b5f1\", \"b5b2\", \"b5e2\", \"b5a3\", \"b5b3\",\n    \"b5c3\", \"b5d3\", \"b5a4\", \"b5b4\", \"b5c4\", \"b5d4\", \"b5a5\", \"b5c5\", \"b5d5\",\n    \"b5e5\", \"b5f5\", \"b5g5\", \"b5h5\", \"b5a6\", \"b5b6\", 
\"b5c6\", \"b5d6\", \"b5a7\",\n    \"b5b7\", \"b5c7\", \"b5d7\", \"b5b8\", \"b5e8\", \"c5c1\", \"c5g1\", \"c5c2\", \"c5f2\",\n    \"c5a3\", \"c5b3\", \"c5c3\", \"c5d3\", \"c5e3\", \"c5a4\", \"c5b4\", \"c5c4\", \"c5d4\",\n    \"c5e4\", \"c5a5\", \"c5b5\", \"c5d5\", \"c5e5\", \"c5f5\", \"c5g5\", \"c5h5\", \"c5a6\",\n    \"c5b6\", \"c5c6\", \"c5d6\", \"c5e6\", \"c5a7\", \"c5b7\", \"c5c7\", \"c5d7\", \"c5e7\",\n    \"c5c8\", \"c5f8\", \"d5d1\", \"d5h1\", \"d5a2\", \"d5d2\", \"d5g2\", \"d5b3\", \"d5c3\",\n    \"d5d3\", \"d5e3\", \"d5f3\", \"d5b4\", \"d5c4\", \"d5d4\", \"d5e4\", \"d5f4\", \"d5a5\",\n    \"d5b5\", \"d5c5\", \"d5e5\", \"d5f5\", \"d5g5\", \"d5h5\", \"d5b6\", \"d5c6\", \"d5d6\",\n    \"d5e6\", \"d5f6\", \"d5b7\", \"d5c7\", \"d5d7\", \"d5e7\", \"d5f7\", \"d5a8\", \"d5d8\",\n    \"d5g8\", \"e5a1\", \"e5e1\", \"e5b2\", \"e5e2\", \"e5h2\", \"e5c3\", \"e5d3\", \"e5e3\",\n    \"e5f3\", \"e5g3\", \"e5c4\", \"e5d4\", \"e5e4\", \"e5f4\", \"e5g4\", \"e5a5\", \"e5b5\",\n    \"e5c5\", \"e5d5\", \"e5f5\", \"e5g5\", \"e5h5\", \"e5c6\", \"e5d6\", \"e5e6\", \"e5f6\",\n    \"e5g6\", \"e5c7\", \"e5d7\", \"e5e7\", \"e5f7\", \"e5g7\", \"e5b8\", \"e5e8\", \"e5h8\",\n    \"f5b1\", \"f5f1\", \"f5c2\", \"f5f2\", \"f5d3\", \"f5e3\", \"f5f3\", \"f5g3\", \"f5h3\",\n    \"f5d4\", \"f5e4\", \"f5f4\", \"f5g4\", \"f5h4\", \"f5a5\", \"f5b5\", \"f5c5\", \"f5d5\",\n    \"f5e5\", \"f5g5\", \"f5h5\", \"f5d6\", \"f5e6\", \"f5f6\", \"f5g6\", \"f5h6\", \"f5d7\",\n    \"f5e7\", \"f5f7\", \"f5g7\", \"f5h7\", \"f5c8\", \"f5f8\", \"g5c1\", \"g5g1\", \"g5d2\",\n    \"g5g2\", \"g5e3\", \"g5f3\", \"g5g3\", \"g5h3\", \"g5e4\", \"g5f4\", \"g5g4\", \"g5h4\",\n    \"g5a5\", \"g5b5\", \"g5c5\", \"g5d5\", \"g5e5\", \"g5f5\", \"g5h5\", \"g5e6\", \"g5f6\",\n    \"g5g6\", \"g5h6\", \"g5e7\", \"g5f7\", \"g5g7\", \"g5h7\", \"g5d8\", \"g5g8\", \"h5d1\",\n    \"h5h1\", \"h5e2\", \"h5h2\", \"h5f3\", \"h5g3\", \"h5h3\", \"h5f4\", \"h5g4\", \"h5h4\",\n    \"h5a5\", \"h5b5\", \"h5c5\", \"h5d5\", \"h5e5\", \"h5f5\", 
\"h5g5\", \"h5f6\", \"h5g6\",\n    \"h5h6\", \"h5f7\", \"h5g7\", \"h5h7\", \"h5e8\", \"h5h8\", \"a6a1\", \"a6f1\", \"a6a2\",\n    \"a6e2\", \"a6a3\", \"a6d3\", \"a6a4\", \"a6b4\", \"a6c4\", \"a6a5\", \"a6b5\", \"a6c5\",\n    \"a6b6\", \"a6c6\", \"a6d6\", \"a6e6\", \"a6f6\", \"a6g6\", \"a6h6\", \"a6a7\", \"a6b7\",\n    \"a6c7\", \"a6a8\", \"a6b8\", \"a6c8\", \"b6b1\", \"b6g1\", \"b6b2\", \"b6f2\", \"b6b3\",\n    \"b6e3\", \"b6a4\", \"b6b4\", \"b6c4\", \"b6d4\", \"b6a5\", \"b6b5\", \"b6c5\", \"b6d5\",\n    \"b6a6\", \"b6c6\", \"b6d6\", \"b6e6\", \"b6f6\", \"b6g6\", \"b6h6\", \"b6a7\", \"b6b7\",\n    \"b6c7\", \"b6d7\", \"b6a8\", \"b6b8\", \"b6c8\", \"b6d8\", \"c6c1\", \"c6h1\", \"c6c2\",\n    \"c6g2\", \"c6c3\", \"c6f3\", \"c6a4\", \"c6b4\", \"c6c4\", \"c6d4\", \"c6e4\", \"c6a5\",\n    \"c6b5\", \"c6c5\", \"c6d5\", \"c6e5\", \"c6a6\", \"c6b6\", \"c6d6\", \"c6e6\", \"c6f6\",\n    \"c6g6\", \"c6h6\", \"c6a7\", \"c6b7\", \"c6c7\", \"c6d7\", \"c6e7\", \"c6a8\", \"c6b8\",\n    \"c6c8\", \"c6d8\", \"c6e8\", \"d6d1\", \"d6d2\", \"d6h2\", \"d6a3\", \"d6d3\", \"d6g3\",\n    \"d6b4\", \"d6c4\", \"d6d4\", \"d6e4\", \"d6f4\", \"d6b5\", \"d6c5\", \"d6d5\", \"d6e5\",\n    \"d6f5\", \"d6a6\", \"d6b6\", \"d6c6\", \"d6e6\", \"d6f6\", \"d6g6\", \"d6h6\", \"d6b7\",\n    \"d6c7\", \"d6d7\", \"d6e7\", \"d6f7\", \"d6b8\", \"d6c8\", \"d6d8\", \"d6e8\", \"d6f8\",\n    \"e6e1\", \"e6a2\", \"e6e2\", \"e6b3\", \"e6e3\", \"e6h3\", \"e6c4\", \"e6d4\", \"e6e4\",\n    \"e6f4\", \"e6g4\", \"e6c5\", \"e6d5\", \"e6e5\", \"e6f5\", \"e6g5\", \"e6a6\", \"e6b6\",\n    \"e6c6\", \"e6d6\", \"e6f6\", \"e6g6\", \"e6h6\", \"e6c7\", \"e6d7\", \"e6e7\", \"e6f7\",\n    \"e6g7\", \"e6c8\", \"e6d8\", \"e6e8\", \"e6f8\", \"e6g8\", \"f6a1\", \"f6f1\", \"f6b2\",\n    \"f6f2\", \"f6c3\", \"f6f3\", \"f6d4\", \"f6e4\", \"f6f4\", \"f6g4\", \"f6h4\", \"f6d5\",\n    \"f6e5\", \"f6f5\", \"f6g5\", \"f6h5\", \"f6a6\", \"f6b6\", \"f6c6\", \"f6d6\", \"f6e6\",\n    \"f6g6\", \"f6h6\", \"f6d7\", \"f6e7\", \"f6f7\", \"f6g7\", 
\"f6h7\", \"f6d8\", \"f6e8\",\n    \"f6f8\", \"f6g8\", \"f6h8\", \"g6b1\", \"g6g1\", \"g6c2\", \"g6g2\", \"g6d3\", \"g6g3\",\n    \"g6e4\", \"g6f4\", \"g6g4\", \"g6h4\", \"g6e5\", \"g6f5\", \"g6g5\", \"g6h5\", \"g6a6\",\n    \"g6b6\", \"g6c6\", \"g6d6\", \"g6e6\", \"g6f6\", \"g6h6\", \"g6e7\", \"g6f7\", \"g6g7\",\n    \"g6h7\", \"g6e8\", \"g6f8\", \"g6g8\", \"g6h8\", \"h6c1\", \"h6h1\", \"h6d2\", \"h6h2\",\n    \"h6e3\", \"h6h3\", \"h6f4\", \"h6g4\", \"h6h4\", \"h6f5\", \"h6g5\", \"h6h5\", \"h6a6\",\n    \"h6b6\", \"h6c6\", \"h6d6\", \"h6e6\", \"h6f6\", \"h6g6\", \"h6f7\", \"h6g7\", \"h6h7\",\n    \"h6f8\", \"h6g8\", \"h6h8\", \"a7a1\", \"a7g1\", \"a7a2\", \"a7f2\", \"a7a3\", \"a7e3\",\n    \"a7a4\", \"a7d4\", \"a7a5\", \"a7b5\", \"a7c5\", \"a7a6\", \"a7b6\", \"a7c6\", \"a7b7\",\n    \"a7c7\", \"a7d7\", \"a7e7\", \"a7f7\", \"a7g7\", \"a7h7\", \"a7a8\", \"a7b8\", \"a7c8\",\n    \"b7b1\", \"b7h1\", \"b7b2\", \"b7g2\", \"b7b3\", \"b7f3\", \"b7b4\", \"b7e4\", \"b7a5\",\n    \"b7b5\", \"b7c5\", \"b7d5\", \"b7a6\", \"b7b6\", \"b7c6\", \"b7d6\", \"b7a7\", \"b7c7\",\n    \"b7d7\", \"b7e7\", \"b7f7\", \"b7g7\", \"b7h7\", \"b7a8\", \"b7b8\", \"b7c8\", \"b7d8\",\n    \"c7c1\", \"c7c2\", \"c7h2\", \"c7c3\", \"c7g3\", \"c7c4\", \"c7f4\", \"c7a5\", \"c7b5\",\n    \"c7c5\", \"c7d5\", \"c7e5\", \"c7a6\", \"c7b6\", \"c7c6\", \"c7d6\", \"c7e6\", \"c7a7\",\n    \"c7b7\", \"c7d7\", \"c7e7\", \"c7f7\", \"c7g7\", \"c7h7\", \"c7a8\", \"c7b8\", \"c7c8\",\n    \"c7d8\", \"c7e8\", \"d7d1\", \"d7d2\", \"d7d3\", \"d7h3\", \"d7a4\", \"d7d4\", \"d7g4\",\n    \"d7b5\", \"d7c5\", \"d7d5\", \"d7e5\", \"d7f5\", \"d7b6\", \"d7c6\", \"d7d6\", \"d7e6\",\n    \"d7f6\", \"d7a7\", \"d7b7\", \"d7c7\", \"d7e7\", \"d7f7\", \"d7g7\", \"d7h7\", \"d7b8\",\n    \"d7c8\", \"d7d8\", \"d7e8\", \"d7f8\", \"e7e1\", \"e7e2\", \"e7a3\", \"e7e3\", \"e7b4\",\n    \"e7e4\", \"e7h4\", \"e7c5\", \"e7d5\", \"e7e5\", \"e7f5\", \"e7g5\", \"e7c6\", \"e7d6\",\n    \"e7e6\", \"e7f6\", \"e7g6\", \"e7a7\", \"e7b7\", \"e7c7\", 
\"e7d7\", \"e7f7\", \"e7g7\",\n    \"e7h7\", \"e7c8\", \"e7d8\", \"e7e8\", \"e7f8\", \"e7g8\", \"f7f1\", \"f7a2\", \"f7f2\",\n    \"f7b3\", \"f7f3\", \"f7c4\", \"f7f4\", \"f7d5\", \"f7e5\", \"f7f5\", \"f7g5\", \"f7h5\",\n    \"f7d6\", \"f7e6\", \"f7f6\", \"f7g6\", \"f7h6\", \"f7a7\", \"f7b7\", \"f7c7\", \"f7d7\",\n    \"f7e7\", \"f7g7\", \"f7h7\", \"f7d8\", \"f7e8\", \"f7f8\", \"f7g8\", \"f7h8\", \"g7a1\",\n    \"g7g1\", \"g7b2\", \"g7g2\", \"g7c3\", \"g7g3\", \"g7d4\", \"g7g4\", \"g7e5\", \"g7f5\",\n    \"g7g5\", \"g7h5\", \"g7e6\", \"g7f6\", \"g7g6\", \"g7h6\", \"g7a7\", \"g7b7\", \"g7c7\",\n    \"g7d7\", \"g7e7\", \"g7f7\", \"g7h7\", \"g7e8\", \"g7f8\", \"g7g8\", \"g7h8\", \"h7b1\",\n    \"h7h1\", \"h7c2\", \"h7h2\", \"h7d3\", \"h7h3\", \"h7e4\", \"h7h4\", \"h7f5\", \"h7g5\",\n    \"h7h5\", \"h7f6\", \"h7g6\", \"h7h6\", \"h7a7\", \"h7b7\", \"h7c7\", \"h7d7\", \"h7e7\",\n    \"h7f7\", \"h7g7\", \"h7f8\", \"h7g8\", \"h7h8\", \"a8a1\", \"a8h1\", \"a8a2\", \"a8g2\",\n    \"a8a3\", \"a8f3\", \"a8a4\", \"a8e4\", \"a8a5\", \"a8d5\", \"a8a6\", \"a8b6\", \"a8c6\",\n    \"a8a7\", \"a8b7\", \"a8c7\", \"a8b8\", \"a8c8\", \"a8d8\", \"a8e8\", \"a8f8\", \"a8g8\",\n    \"a8h8\", \"b8b1\", \"b8b2\", \"b8h2\", \"b8b3\", \"b8g3\", \"b8b4\", \"b8f4\", \"b8b5\",\n    \"b8e5\", \"b8a6\", \"b8b6\", \"b8c6\", \"b8d6\", \"b8a7\", \"b8b7\", \"b8c7\", \"b8d7\",\n    \"b8a8\", \"b8c8\", \"b8d8\", \"b8e8\", \"b8f8\", \"b8g8\", \"b8h8\", \"c8c1\", \"c8c2\",\n    \"c8c3\", \"c8h3\", \"c8c4\", \"c8g4\", \"c8c5\", \"c8f5\", \"c8a6\", \"c8b6\", \"c8c6\",\n    \"c8d6\", \"c8e6\", \"c8a7\", \"c8b7\", \"c8c7\", \"c8d7\", \"c8e7\", \"c8a8\", \"c8b8\",\n    \"c8d8\", \"c8e8\", \"c8f8\", \"c8g8\", \"c8h8\", \"d8d1\", \"d8d2\", \"d8d3\", \"d8d4\",\n    \"d8h4\", \"d8a5\", \"d8d5\", \"d8g5\", \"d8b6\", \"d8c6\", \"d8d6\", \"d8e6\", \"d8f6\",\n    \"d8b7\", \"d8c7\", \"d8d7\", \"d8e7\", \"d8f7\", \"d8a8\", \"d8b8\", \"d8c8\", \"d8e8\",\n    \"d8f8\", \"d8g8\", \"d8h8\", \"e8e1\", \"e8e2\", \"e8e3\", 
\"e8a4\", \"e8e4\", \"e8b5\",\n    \"e8e5\", \"e8h5\", \"e8c6\", \"e8d6\", \"e8e6\", \"e8f6\", \"e8g6\", \"e8c7\", \"e8d7\",\n    \"e8e7\", \"e8f7\", \"e8g7\", \"e8a8\", \"e8b8\", \"e8c8\", \"e8d8\", \"e8f8\", \"e8g8\",\n    \"e8h8\", \"f8f1\", \"f8f2\", \"f8a3\", \"f8f3\", \"f8b4\", \"f8f4\", \"f8c5\", \"f8f5\",\n    \"f8d6\", \"f8e6\", \"f8f6\", \"f8g6\", \"f8h6\", \"f8d7\", \"f8e7\", \"f8f7\", \"f8g7\",\n    \"f8h7\", \"f8a8\", \"f8b8\", \"f8c8\", \"f8d8\", \"f8e8\", \"f8g8\", \"f8h8\", \"g8g1\",\n    \"g8a2\", \"g8g2\", \"g8b3\", \"g8g3\", \"g8c4\", \"g8g4\", \"g8d5\", \"g8g5\", \"g8e6\",\n    \"g8f6\", \"g8g6\", \"g8h6\", \"g8e7\", \"g8f7\", \"g8g7\", \"g8h7\", \"g8a8\", \"g8b8\",\n    \"g8c8\", \"g8d8\", \"g8e8\", \"g8f8\", \"g8h8\", \"h8a1\", \"h8h1\", \"h8b2\", \"h8h2\",\n    \"h8c3\", \"h8h3\", \"h8d4\", \"h8h4\", \"h8e5\", \"h8h5\", \"h8f6\", \"h8g6\", \"h8h6\",\n    \"h8f7\", \"h8g7\", \"h8h7\", \"h8a8\", \"h8b8\", \"h8c8\", \"h8d8\", \"h8e8\", \"h8f8\",\n    \"h8g8\", \"a7a8q\", \"a7a8r\", \"a7a8b\", \"a7b8q\", \"a7b8r\", \"a7b8b\", \"b7a8q\",\n    \"b7a8r\", \"b7a8b\", \"b7b8q\", \"b7b8r\", \"b7b8b\", \"b7c8q\", \"b7c8r\", \"b7c8b\",\n    \"c7b8q\", \"c7b8r\", \"c7b8b\", \"c7c8q\", \"c7c8r\", \"c7c8b\", \"c7d8q\", \"c7d8r\",\n    \"c7d8b\", \"d7c8q\", \"d7c8r\", \"d7c8b\", \"d7d8q\", \"d7d8r\", \"d7d8b\", \"d7e8q\",\n    \"d7e8r\", \"d7e8b\", \"e7d8q\", \"e7d8r\", \"e7d8b\", \"e7e8q\", \"e7e8r\", \"e7e8b\",\n    \"e7f8q\", \"e7f8r\", \"e7f8b\", \"f7e8q\", \"f7e8r\", \"f7e8b\", \"f7f8q\", \"f7f8r\",\n    \"f7f8b\", \"f7g8q\", \"f7g8r\", \"f7g8b\", \"g7f8q\", \"g7f8r\", \"g7f8b\", \"g7g8q\",\n    \"g7g8r\", \"g7g8b\", \"g7h8q\", \"g7h8r\", \"g7h8b\", \"h7g8q\", \"h7g8r\", \"h7g8b\",\n    \"h7h8q\", \"h7h8r\", \"h7h8b\"\n]\n"
  },
  {
    "path": "tf/requirements.txt",
    "content": "numpy==1.13.3\ntensorflow==2.5.1\ntensorflow-tensorboard==0.4.0rc2\nprotobuf==3.12.1\n"
  },
  {
    "path": "tf/shufflebuffer.py",
    "content": "#!/usr/bin/env python3\n#\n#    This file is part of Leela Chess.\n#    Copyright (C) 2018 Michael O\n#\n#    Leela Chess is free software: you can redistribute it and/or modify\n#    it under the terms of the GNU General Public License as published by\n#    the Free Software Foundation, either version 3 of the License, or\n#    (at your option) any later version.\n#\n#    Leela Chess is distributed in the hope that it will be useful,\n#    but WITHOUT ANY WARRANTY; without even the implied warranty of\n#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n#    GNU General Public License for more details.\n#\n#    You should have received a copy of the GNU General Public License\n#    along with Leela Chess.  If not, see <http://www.gnu.org/licenses/>.\n\nimport random\nimport unittest\n\n\nclass ShuffleBuffer:\n    def __init__(self, elem_size, elem_count):\n        \"\"\"\n            A shuffle buffer for fixed sized elements.\n\n            Manages 'elem_count' items in a fixed buffer, each item being exactly\n            'elem_size' bytes.\n        \"\"\"\n        assert elem_size > 0, elem_size\n        assert elem_count > 0, elem_count\n        # Size of each element.\n        self.elem_size = elem_size\n        # Number of elements in the buffer.\n        self.elem_count = elem_count\n        # Fixed size buffer used to hold all the element.\n        self.buffer = bytearray(elem_size * elem_count)\n        # Number of elements actually contained in the buffer.\n        self.used = 0\n\n    def extract(self):\n        \"\"\"\n            Return an item from the shuffle buffer.\n\n            If the buffer is empty, returns None\n        \"\"\"\n        if self.used < 1:\n            return None\n        # The items in the shuffle buffer are held in shuffled order\n        # so returning the last item is sufficient.\n        self.used -= 1\n        i = self.used\n        return self.buffer[i * self.elem_size:(i + 1) * 
self.elem_size]\n\n    def insert_or_replace(self, item):\n        \"\"\"\n            Inserts 'item' into the shuffle buffer, returning\n            a random item.\n\n            If the buffer is not yet full, returns None\n        \"\"\"\n        assert len(item) == self.elem_size, len(item)\n        # Putting the new item in a random location, and appending\n        # the displaced item to the end of the buffer, achieves a full\n        # random shuffle (Fisher-Yates).\n        if self.used > 0:\n            # Swap 'item' with a random item in the buffer.\n            i = random.randint(0, self.used - 1)\n            old_item = self.buffer[i * self.elem_size:(i + 1) * self.elem_size]\n            self.buffer[i * self.elem_size:(i + 1) * self.elem_size] = item\n            item = old_item\n        # If the buffer isn't yet full, append 'item' to the end of the buffer.\n        if self.used < self.elem_count:\n            # Not yet full, so place the returned item at the end of the buffer.\n            i = self.used\n            self.buffer[i * self.elem_size:(i + 1) * self.elem_size] = item\n            self.used += 1\n            return None\n        return item\n\n\nclass ShuffleBufferTest(unittest.TestCase):\n    def test_extract(self):\n        sb = ShuffleBuffer(3, 1)\n        r = sb.extract()\n        assert r is None, r  # empty buffer => None\n        r = sb.insert_or_replace(b'111')\n        assert r is None, r  # buffer not yet full => None\n        r = sb.extract()\n        assert r == b'111', r  # one item in buffer => item\n        r = sb.extract()\n        assert r is None, r  # buffer empty => None\n\n    def test_wrong_size(self):\n        sb = ShuffleBuffer(3, 1)\n        with self.assertRaises(AssertionError):\n            sb.insert_or_replace(b'1')  # wrong length, so should throw.\n\n    def test_insert_or_replace(self):\n        n = 10  # number of test items.\n        items = [bytes([x, x, x]) for x in range(n)]\n        sb = ShuffleBuffer(elem_size=3, elem_count=2)\n        out = []\n        for i in items:\n            r = sb.insert_or_replace(i)\n            if r is not None:\n                out.append(r)\n        # Buffer size is 2, 10 items, so 8 should have been seen so far.\n        assert len(out) == n - 2, len(out)\n        # Get the last two items.\n        out.append(sb.extract())\n        out.append(sb.extract())\n        assert sorted(items) == sorted(out), (items, out)\n        # Check that the buffer is empty.\n        r = sb.extract()\n        assert r is None, r\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tf/start.sh",
"content": "#!/usr/bin/env bash\n\nset -e\n\nNETARCHS=(64x6)\nREPO=\"origin\"\nROOT=\"/work/lc0\"\nNETDIR=\"$ROOT/networks/upload\"\nGAMEFILE=\"$HOME/.lc0.dat\"\nLATESTFILE=\"$HOME/.lc0.latest.dat\"\nRAMDISK=\"/ramdisk\"\nMIN_GAP=10\n\nfunction usage()\n{\n  echo \"Starts a training pipeline\"\n  echo \"\"\n  echo \"./start.sh\"\n  echo \"  -h --help\"\n  echo \"  -c --cfg    The configuration directory\"\n  echo \"  -g --games  The number of games between training cycles\"\n  echo \"  -b --branch The git branch to push configs to\"\n  echo \"\"\n  echo \"Example: ./start.sh -c=/tmp/cfgdir -g=40000 -b=test\"\n  echo \"\"\n}\n\nwhile [ \"$1\" != \"\" ]\ndo\n  PARAM=`echo $1 | awk -F= '{print $1}'`\n  VALUE=`echo $1 | awk -F= '{print $2}'`\n  case $PARAM in\n    -h | --help)\n      usage\n      exit\n      ;;\n    -c | --cfg)\n      CONFIGDIR=$VALUE\n      ;;\n    -g | --games)\n      GAMES=$VALUE\n      ;;\n    -b | --branch)\n      BRANCH=$VALUE\n      ;;\n    *)\n      echo \"ERROR: unknown parameter \\\"$PARAM\\\"\"\n      usage\n      exit 1\n      ;;\n  esac\n  shift\ndone\nif [ ! -f \"$GAMEFILE\" ]\nthen\n  echo \"File $GAMEFILE must contain a single number, exiting now!\"\n  exit 1\nfi\n\nif [ -z \"$LC0LOCKFILE\" ]\nthen\n  echo \"env var LC0LOCKFILE not set\"\n  exit 1\nfi\n\n\ngame_num=$(cat $GAMEFILE)\ngame_num=$((game_num + GAMES))\nfile=\"training.${game_num}.gz\"\n\necho \"Starting with '$file' as last game in window\"\n\ntrain() {\n  unbuffer ./train.py --cfg=$1 --output=$2 2>&1 | tee \"$ROOT/logs/$(date +%Y%m%d-%H%M%S).log\"\n  mv -v $2.pb.gz $NETDIR\n}\n\ndelay_count=$((MIN_GAP+1))\n\nwhile true\ndo\n  latest_num=0\n  if [ -f \"$LATESTFILE\" ]\n  then\n    latest_num=$(cat $LATESTFILE)\n  fi\n  if [[ $delay_count -gt $MIN_GAP && ( $latest_num -gt $game_num || -f \"$ROOT/data-rescored/$file\" ) ]]\n  then\n    if [ $latest_num -gt $game_num ]\n    then\n      game_num=$latest_num\n    fi\n    echo \"\"\n\n    # prepare ramdisk\n    (\n    flock -e 200\n    rsync -aq --delete-during $ROOT/split/{train,test} $RAMDISK\n    ) 200>$LC0LOCKFILE\n\n    # train all networks\n    for netarch in ${NETARCHS[@]}\n    do\n      echo \"Training $netarch:\"\n      train \"$CONFIGDIR/$netarch.yaml\" \"${netarch}-$(date +\"%Y_%m%d_%H%M_%S_%3N\")\"\n    done\n\n    # wait for next cycle\n    echo $game_num > $GAMEFILE\n    game_num=$((game_num + GAMES))\n    file=\"training.${game_num}.gz\"\n    echo \"Waiting for '$file'\"\n    delay_count=1\n  else\n    echo -n \".\"\n    sleep 60\n    delay_count=$((delay_count+1))\n  fi\ndone\n"
  },
  {
    "path": "tf/tfprocess.py",
    "content": "#!/usr/bin/env python3\n#\n#    This file is part of Leela Zero.\n#    Copyright (C) 2017-2018 Gian-Carlo Pascutto\n#\n#    Leela Zero is free software: you can redistribute it and/or modify\n#    it under the terms of the GNU General Public License as published by\n#    the Free Software Foundation, either version 3 of the License, or\n#    (at your option) any later version.\n#\n#    Leela Zero is distributed in the hope that it will be useful,\n#    but WITHOUT ANY WARRANTY; without even the implied warranty of\n#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n#    GNU General Public License for more details.\n#\n#    You should have received a copy of the GNU General Public License\n#    along with Leela Zero.  If not, see <http://www.gnu.org/licenses/>.\n\nimport numpy as np\nimport os\nimport random\nimport tensorflow as tf\nimport time\nimport bisect\nimport lc0_az_policy_map\nimport attention_policy_map as apm\nimport proto.net_pb2 as pb\nfrom functools import reduce\nimport operator\n\nfrom net import Net\n\n\ndef square_relu(x):\n    return tf.nn.relu(x)**2\n\n\nclass Gating(tf.keras.layers.Layer):\n\n    def __init__(self, name=None, additive=True, init_value=None, **kwargs):\n        self.additive = additive\n        if init_value is None:\n            init_value = 0 if self.additive else 1\n        self.init_value = init_value\n        super().__init__(name=name, **kwargs)\n\n    def build(self, input_shape):\n        self.gate = self.add_weight(name='gate',\n                                    shape=input_shape[1:],\n                                    constraint=tf.keras.constraints.NonNeg()\n                                    if not self.additive else None,\n                                    initializer=tf.constant_initializer(\n                                        self.init_value),\n                                    trainable=True)\n\n    def call(self, inputs):\n        return tf.add(inputs, self.gate) if 
self.additive else tf.multiply(\n            inputs, self.gate)\n\n\ndef ma_gating(inputs, name):\n    out = Gating(name=name + '/mult_gate', additive=False)(inputs)\n    out = Gating(name=name + '/add_gate', additive=True)(out)\n    return out\n\n\nclass ApplySqueezeExcitation(tf.keras.layers.Layer):\n\n    def __init__(self, **kwargs):\n        super(ApplySqueezeExcitation, self).__init__(**kwargs)\n\n    def build(self, input_dimens):\n        self.reshape_size = input_dimens[1][1]\n\n    def call(self, inputs):\n        x = inputs[0]\n        excited = inputs[1]\n        gammas, betas = tf.split(tf.reshape(excited,\n                                            [-1, self.reshape_size, 1, 1]),\n                                 2,\n                                 axis=1)\n        return tf.nn.sigmoid(gammas) * x + betas\n\n\nclass ApplyPolicyMap(tf.keras.layers.Layer):\n\n    def __init__(self, **kwargs):\n        super(ApplyPolicyMap, self).__init__(**kwargs)\n        self.fc1 = tf.constant(lc0_az_policy_map.make_map())\n\n    def call(self, inputs):\n        h_conv_pol_flat = tf.reshape(inputs, [-1, 80 * 8 * 8])\n        return tf.matmul(h_conv_pol_flat,\n                         tf.cast(self.fc1, h_conv_pol_flat.dtype))\n\n\nclass ApplyAttentionPolicyMap(tf.keras.layers.Layer):\n\n    def __init__(self, **kwargs):\n        super(ApplyAttentionPolicyMap, self).__init__(**kwargs)\n        self.fc1 = tf.constant(apm.make_map())\n\n    def call(self, logits, pp_logits):\n        logits = tf.concat([\n            tf.reshape(logits, [-1, 64 * 64]),\n            tf.reshape(pp_logits, [-1, 8 * 24])\n        ],\n                           axis=1)\n        return tf.matmul(logits, tf.cast(self.fc1, logits.dtype))\n\n\nclass Metric:\n\n    def __init__(self, short_name, long_name, suffix='', **kwargs):\n        self.short_name = short_name\n        self.long_name = long_name\n        self.suffix = suffix\n        self.value = 0.0\n        self.count = 0\n\n    def 
assign(self, value):\n        self.value = value\n        self.count = 1\n\n    def accumulate(self, value):\n        if self.count > 0:\n            self.value = self.value + value\n            self.count = self.count + 1\n        else:\n            self.assign(value)\n\n    def merge(self, other):\n        assert self.short_name == other.short_name\n        self.value = self.value + other.value\n        self.count = self.count + other.count\n\n    def get(self):\n        if self.count == 0:\n            return self.value\n        return self.value / self.count\n\n    def reset(self):\n        self.value = 0.0\n        self.count = 0\n\n\nclass TFProcess:\n\n    def __init__(self, cfg):\n        self.cfg = cfg\n        self.net = Net()\n        self.root_dir = os.path.join(self.cfg['training']['path'],\n                                     self.cfg['name'])\n\n        # Network structure\n        self.RESIDUAL_FILTERS = self.cfg['model'].get('filters', 0)\n        self.RESIDUAL_BLOCKS = self.cfg['model'].get('residual_blocks', 0)\n        self.SE_ratio = self.cfg['model'].get('se_ratio', 0)\n        self.encoder_layers = self.cfg['model'].get('encoder_layers', 0)\n        self.encoder_heads = self.cfg['model'].get('encoder_heads', 2)\n        assert (self.RESIDUAL_BLOCKS > 0) != (self.encoder_layers > 0), \\\n                \"Nets with both encoder layers and residual blocks are not supported\"\n        if self.encoder_layers > 0:\n            self.RESIDUAL_FILTERS = self.cfg['model']['embedding_size']\n            self.embedding_size = self.RESIDUAL_FILTERS\n\n        self.policy_channels = self.cfg['model'].get('policy_channels', 32)\n        self.pol_embedding_size = self.cfg['model'].get(\n            'pol_embedding_size', self.RESIDUAL_FILTERS)\n        self.val_embedding_size = self.cfg['model'].get(\n            'value_embedding_size', 32)\n        self.mov_embedding_size = self.cfg['model'].get(\n            'moves_left_embedding_size', 8)\n        
# policy head\n        self.pol_encoder_layers = (0 if self.encoder_layers > 0 else 1)\n        # logic is to explicitly warn users who set both in yaml\n        if self.cfg['model'].get('pol_encoder_layers') is not None:\n            self.pol_encoder_layers = self.cfg['model'].get(\n                'pol_encoder_layers')\n        assert not ((self.pol_encoder_layers > 0) and (self.encoder_layers > 0)), \\\n                \"Nets with both body encoder layers and policy encoder layers are not supported\"\n        self.pol_encoder_heads = self.cfg['model'].get('pol_encoder_heads', 2)\n        self.pol_encoder_d_model = self.cfg['model'].get(\n            'pol_encoder_d_model', self.RESIDUAL_FILTERS)\n        self.pol_encoder_dff = self.cfg['model'].get(\n            'pol_encoder_dff', (self.RESIDUAL_FILTERS * 1.5) // 1)\n        self.policy_d_model = self.cfg['model'].get('policy_d_model',\n                                                    self.RESIDUAL_FILTERS)\n\n        # encoder body\n        self.input_gate = self.cfg['model'].get('input_gate')\n        self.encoder_d_model = self.cfg['model'].get('encoder_d_model')\n        self.encoder_dff = self.cfg['model'].get(\n            'encoder_dff', (self.RESIDUAL_FILTERS * 1.5) // 1)\n        self.arc_encoding = self.cfg['model'].get('arc_encoding', True)\n        self.square_relu_ffn = self.cfg['model'].get('square_relu_ffn', False)\n\n        self.use_smolgen = self.cfg['model'].get('use_smolgen', False)\n        self.smolgen_hidden_channels = self.cfg['model'].get(\n            'smolgen_hidden_channels', 16)\n        self.smolgen_hidden_sz = self.cfg['model'].get('smolgen_hidden_sz',\n                                                       128)\n        self.smolgen_gen_sz = self.cfg['model'].get('smolgen_gen_sz', 128)\n        self.smolgen_activation = 
self.cfg['model'].get(\n            'smolgen_activation', 'swish')\n\n        self.dropout_rate = self.cfg['model'].get('dropout_rate', 0.0)\n        precision = self.cfg['training'].get('precision', 'single')\n        loss_scale = self.cfg['training'].get('loss_scale', 128)\n        self.virtual_batch_size = self.cfg['model'].get(\n            'virtual_batch_size', None)\n\n        if precision == 'single':\n            self.model_dtype = tf.float32\n        elif precision == 'half':\n            self.model_dtype = tf.float16\n        else:\n            raise ValueError(\"Unknown precision: {}\".format(precision))\n\n        # Scale the loss to prevent gradient underflow\n        self.loss_scale = 1 if self.model_dtype == tf.float32 else loss_scale\n\n        policy_head = self.cfg['model'].get('policy', 'convolution')\n        value_head = self.cfg['model'].get('value', 'wdl')\n        moves_left_head = self.cfg['model'].get('moves_left', 'v1')\n        input_mode = self.cfg['model'].get('input_type', 'classic')\n        default_activation = self.cfg['model'].get('default_activation',\n                                                   'relu')\n\n        self.POLICY_HEAD = None\n        self.VALUE_HEAD = None\n        self.MOVES_LEFT_HEAD = None\n        self.INPUT_MODE = None\n        self.DEFAULT_ACTIVATION = None\n\n        if policy_head == \"classical\":\n            self.POLICY_HEAD = pb.NetworkFormat.POLICY_CLASSICAL\n        elif policy_head == \"convolution\":\n            self.POLICY_HEAD = pb.NetworkFormat.POLICY_CONVOLUTION\n        elif policy_head == \"attention\":\n            self.POLICY_HEAD = pb.NetworkFormat.POLICY_ATTENTION\n            if self.pol_encoder_layers > 0:\n                self.net.set_pol_headcount(self.pol_encoder_heads)\n        else:\n            raise ValueError(\n                \"Unknown policy head format: {}\".format(policy_head))\n\n        self.net.set_policyformat(self.POLICY_HEAD)\n\n        if value_head == 
\"classical\":\n            self.VALUE_HEAD = pb.NetworkFormat.VALUE_CLASSICAL\n            self.wdl = False\n        elif value_head == \"wdl\":\n            self.VALUE_HEAD = pb.NetworkFormat.VALUE_WDL\n            self.wdl = True\n        else:\n            raise ValueError(\n                \"Unknown value head format: {}\".format(value_head))\n\n        self.net.set_valueformat(self.VALUE_HEAD)\n\n        if moves_left_head == \"none\":\n            self.MOVES_LEFT_HEAD = pb.NetworkFormat.MOVES_LEFT_NONE\n            self.moves_left = False\n        elif moves_left_head == \"v1\":\n            self.MOVES_LEFT_HEAD = pb.NetworkFormat.MOVES_LEFT_V1\n            self.moves_left = True\n        else:\n            raise ValueError(\n                \"Unknown moves left head format: {}\".format(moves_left_head))\n\n        self.net.set_movesleftformat(self.MOVES_LEFT_HEAD)\n\n        if input_mode == \"classic\":\n            self.INPUT_MODE = pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE\n        elif input_mode == \"frc_castling\":\n            self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CASTLING_PLANE\n        elif input_mode == \"canonical\":\n            self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION\n        elif input_mode == \"canonical_100\":\n            self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES\n        elif input_mode == \"canonical_armageddon\":\n            self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON\n        elif input_mode == \"canonical_v2\":\n            self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2\n        elif input_mode == \"canonical_v2_armageddon\":\n            self.INPUT_MODE = pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON\n        else:\n            raise ValueError(\n                \"Unknown input mode format: {}\".format(input_mode))\n\n        self.net.set_input(self.INPUT_MODE)\n\n        if 
default_activation == \"relu\":\n            self.net.set_defaultactivation(\n                pb.NetworkFormat.DEFAULT_ACTIVATION_RELU)\n            self.DEFAULT_ACTIVATION = 'relu'\n        elif default_activation == \"mish\":\n            self.net.set_defaultactivation(\n                pb.NetworkFormat.DEFAULT_ACTIVATION_MISH)\n            try:\n                self.DEFAULT_ACTIVATION = tf.keras.activations.mish\n            except AttributeError:\n                import tensorflow_addons as tfa\n                self.DEFAULT_ACTIVATION = tfa.activations.mish\n        else:\n            raise ValueError(\"Unknown default activation type: {}\".format(\n                default_activation))\n\n        if self.encoder_layers > 0:\n            self.net.set_headcount(self.encoder_heads)\n            self.net.set_networkformat(\n                pb.NetworkFormat.NETWORK_ATTENTIONBODY_WITH_HEADFORMAT)\n            self.net.set_smolgen_activation(\n                self.net.activation(self.smolgen_activation))\n            self.net.set_ffn_activation(\n                self.net.activation(\n                    'sqrrelu' if self.square_relu_ffn else 'default'))\n\n        self.swa_enabled = self.cfg['training'].get('swa', False)\n\n        # Limit momentum of SWA exponential average to 1 - 1/(swa_max_n + 1)\n        self.swa_max_n = self.cfg['training'].get('swa_max_n', 0)\n\n        self.renorm_enabled = self.cfg['training'].get('renorm', False)\n        self.renorm_max_r = self.cfg['training'].get('renorm_max_r', 1)\n        self.renorm_max_d = self.cfg['training'].get('renorm_max_d', 0)\n        self.renorm_momentum = self.cfg['training'].get(\n            'renorm_momentum', 0.99)\n\n        if self.cfg['gpu'] == 'all':\n            gpus = tf.config.experimental.list_physical_devices('GPU')\n            for gpu in gpus:\n                tf.config.experimental.set_memory_growth(gpu, True)\n            self.strategy = tf.distribute.MirroredStrategy()\n            
tf.distribute.experimental_set_strategy(self.strategy)\n        else:\n            gpus = tf.config.experimental.list_physical_devices('GPU')\n            print(gpus)\n            tf.config.experimental.set_visible_devices(gpus[self.cfg['gpu']],\n                                                       'GPU')\n            tf.config.experimental.set_memory_growth(gpus[self.cfg['gpu']],\n                                                     True)\n            self.strategy = None\n        if self.model_dtype == tf.float16:\n            tf.keras.mixed_precision.experimental.set_policy('mixed_float16')\n\n        self.global_step = tf.Variable(0,\n                                       name='global_step',\n                                       trainable=False,\n                                       dtype=tf.int64)\n\n    def init(self, train_dataset, test_dataset, validation_dataset=None):\n        if self.strategy is not None:\n            self.train_dataset = self.strategy.experimental_distribute_dataset(\n                train_dataset)\n        else:\n            self.train_dataset = train_dataset\n        self.train_iter = iter(self.train_dataset)\n        if self.strategy is not None:\n            self.test_dataset = self.strategy.experimental_distribute_dataset(\n                test_dataset)\n        else:\n            self.test_dataset = test_dataset\n        self.test_iter = iter(self.test_dataset)\n        if self.strategy is not None and validation_dataset is not None:\n            self.validation_dataset = self.strategy.experimental_distribute_dataset(\n                validation_dataset)\n        else:\n            self.validation_dataset = validation_dataset\n        if self.strategy is not None:\n            this = self\n            with self.strategy.scope():\n                this.init_net()\n        else:\n            self.init_net()\n\n    def init_net(self):\n        self.l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001))\n        input_var = 
tf.keras.Input(shape=(112, 8, 8))\n        outputs = self.construct_net(input_var)\n        self.model = tf.keras.Model(inputs=input_var, outputs=outputs)\n\n        # swa_count is the count of networks accumulated into SWA; it is\n        # initialized regardless of swa_enabled to make checkpoint code simpler.\n        self.swa_count = tf.Variable(0., name='swa_count', trainable=False)\n        self.swa_weights = None\n        if self.swa_enabled:\n            # Per-weight accumulators holding the SWA running average.\n            self.swa_weights = [\n                tf.Variable(w, trainable=False) for w in self.model.weights\n            ]\n\n        self.active_lr = tf.Variable(0.01, trainable=False)\n        # All 'new' (TF 2.10 or newer non-legacy) optimizers must have learning_rate updated manually.\n        self.update_lr_manually = False\n        # Be sure not to set new_optimizer before TF 2.11, or unless you edit the code to specify a new optimizer explicitly.\n        if self.cfg['training'].get('new_optimizer'):\n            self.optimizer = tf.keras.optimizers.SGD(\n                learning_rate=self.active_lr, momentum=0.9, nesterov=True)\n            self.update_lr_manually = True\n        else:\n            try:\n                self.optimizer = tf.keras.optimizers.legacy.SGD(\n                    learning_rate=lambda: self.active_lr,\n                    momentum=0.9,\n                    nesterov=True)\n            except AttributeError:\n                self.optimizer = tf.keras.optimizers.SGD(\n                    learning_rate=lambda: self.active_lr,\n                    momentum=0.9,\n                    nesterov=True)\n        self.orig_optimizer = self.optimizer\n        try:\n            self.aggregator = self.orig_optimizer.aggregate_gradients\n        except AttributeError:\n            self.aggregator = self.orig_optimizer.gradient_aggregator\n        if self.loss_scale != 1:\n            self.optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(\n                self.optimizer, self.loss_scale)\n        if 
self.cfg['training'].get('lookahead_optimizer'):\n            import tensorflow_addons as tfa\n            self.optimizer = tfa.optimizers.Lookahead(self.optimizer)\n\n        def correct_policy(target, output):\n            output = tf.cast(output, tf.float32)\n            # Calculate loss on policy head\n            if self.cfg['training'].get('mask_legal_moves'):\n                # extract mask for legal moves from target policy\n                move_is_legal = tf.greater_equal(target, 0)\n                # Replace logits of illegal moves with a large negative value\n                # (so that it doesn't affect policy of legal moves), without\n                # gradient.\n                illegal_filler = tf.zeros_like(output) - 1.0e10\n                output = tf.where(move_is_legal, output, illegal_filler)\n            # target still has -1 on illegal moves, flush them to 0\n            target = tf.nn.relu(target)\n            return target, output\n\n        def policy_loss(target, output):\n            target, output = correct_policy(target, output)\n            policy_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(\n                labels=tf.stop_gradient(target), logits=output)\n            return tf.reduce_mean(input_tensor=policy_cross_entropy)\n\n        self.policy_loss_fn = policy_loss\n\n        def policy_accuracy(target, output):\n            target, output = correct_policy(target, output)\n            return tf.reduce_mean(\n                tf.cast(\n                    tf.equal(tf.argmax(input=target, axis=1),\n                             tf.argmax(input=output, axis=1)), tf.float32))\n\n        self.policy_accuracy_fn = policy_accuracy\n\n        def moves_left_mean_error_fn(target, output):\n            output = tf.cast(output, tf.float32)\n            return tf.reduce_mean(tf.abs(target - output))\n\n        self.moves_left_mean_error = moves_left_mean_error_fn\n\n        def policy_entropy(target, output):\n            target, output = correct_policy(target, output)\n            
softmaxed = tf.nn.softmax(output)\n            return tf.math.negative(\n                tf.reduce_mean(\n                    tf.reduce_sum(tf.math.xlogy(softmaxed, softmaxed),\n                                  axis=1)))\n\n        self.policy_entropy_fn = policy_entropy\n\n        def policy_uniform_loss(target, output):\n            uniform = tf.where(tf.greater_equal(target, 0),\n                               tf.ones_like(target), tf.zeros_like(target))\n            balanced_uniform = uniform / tf.reduce_sum(\n                uniform, axis=1, keepdims=True)\n            target, output = correct_policy(target, output)\n            policy_cross_entropy = \\\n                tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(balanced_uniform),\n                                                        logits=output)\n            return tf.reduce_mean(input_tensor=policy_cross_entropy)\n\n        self.policy_uniform_loss_fn = policy_uniform_loss\n\n        q_ratio = self.cfg['training'].get('q_ratio', 0)\n        assert 0 <= q_ratio <= 1\n\n        # Linear conversion to scalar to compute MSE with, for comparison to old values\n        wdl = tf.expand_dims(tf.constant([1.0, 0.0, -1.0]), 1)\n\n        self.qMix = lambda z, q: q * q_ratio + z * (1 - q_ratio)\n        # Loss on value head\n        if self.wdl:\n\n            def value_loss(target, output):\n                output = tf.cast(output, tf.float32)\n                value_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(\n                    labels=tf.stop_gradient(target), logits=output)\n                return tf.reduce_mean(input_tensor=value_cross_entropy)\n\n            self.value_loss_fn = value_loss\n\n            def mse_loss(target, output):\n                output = tf.cast(output, tf.float32)\n                scalar_z_conv = tf.matmul(tf.nn.softmax(output), wdl)\n                scalar_target = tf.matmul(target, wdl)\n                return 
tf.reduce_mean(input_tensor=tf.math.squared_difference(\n                    scalar_target, scalar_z_conv))\n\n            self.mse_loss_fn = mse_loss\n        else:\n\n            def value_loss(target, output):\n                return tf.constant(0)\n\n            self.value_loss_fn = value_loss\n\n            def mse_loss(target, output):\n                output = tf.cast(output, tf.float32)\n                scalar_target = tf.matmul(target, wdl)\n                return tf.reduce_mean(input_tensor=tf.math.squared_difference(\n                    scalar_target, output))\n\n            self.mse_loss_fn = mse_loss\n\n        if self.moves_left:\n\n            def moves_left_loss(target, output):\n                # Scale the loss to similar range as other losses.\n                scale = 20.0\n                target = target / scale\n                output = tf.cast(output, tf.float32) / scale\n                if self.strategy is not None:\n                    huber = tf.keras.losses.Huber(\n                        10.0 / scale, reduction=tf.keras.losses.Reduction.NONE)\n                else:\n                    huber = tf.keras.losses.Huber(10.0 / scale)\n                return tf.reduce_mean(huber(target, output))\n        else:\n            moves_left_loss = None\n\n        self.moves_left_loss_fn = moves_left_loss\n\n        pol_loss_w = self.cfg['training']['policy_loss_weight']\n        val_loss_w = self.cfg['training']['value_loss_weight']\n\n        if self.moves_left:\n            moves_loss_w = self.cfg['training']['moves_left_loss_weight']\n        else:\n            moves_loss_w = tf.constant(0.0, dtype=tf.float32)\n        reg_term_w = self.cfg['training'].get('reg_term_weight', 1.0)\n\n        def _lossMix(policy, value, moves_left, reg_term):\n            return pol_loss_w * policy + val_loss_w * value + moves_loss_w * moves_left + reg_term_w * reg_term\n\n        self.lossMix = _lossMix\n\n        def accuracy(target, output):\n            output = 
tf.cast(output, tf.float32)\n            return tf.reduce_mean(\n                tf.cast(\n                    tf.equal(tf.argmax(input=target, axis=1),\n                             tf.argmax(input=output, axis=1)), tf.float32))\n\n        self.accuracy_fn = accuracy\n\n        # Order must match the order in process_inner_loop\n        self.train_metrics = [\n            Metric('P', 'Policy Loss'),\n            Metric('V', 'Value Loss'),\n            Metric('ML', 'Moves Left Loss'),\n            Metric('Reg', 'Reg term'),\n            Metric('Total', 'Total Loss'),\n            Metric(\n                'V MSE', 'MSE Loss'\n            ),  # Long name here doesn't mention value for backwards compatibility reasons.\n        ]\n        self.time_start = None\n        self.last_steps = None\n\n        # Order must match the order in calculate_test_summaries_inner_loop\n        self.test_metrics = [\n            Metric('P', 'Policy Loss'),\n            Metric('V', 'Value Loss'),\n            Metric('ML', 'Moves Left Loss'),\n            Metric(\n                'V MSE', 'MSE Loss'\n            ),  # Long name here doesn't mention value for backwards compatibility reasons.\n            Metric('P Acc', 'Policy Accuracy', suffix='%'),\n            Metric('V Acc', 'Value Accuracy', suffix='%'),\n            Metric('ML Mean', 'Moves Left Mean Error'),\n            Metric('P Entropy', 'Policy Entropy'),\n            Metric('P UL', 'Policy UL'),\n        ]\n\n        # Set adaptive learning rate during training\n        self.cfg['training']['lr_boundaries'].sort()\n        self.warmup_steps = self.cfg['training'].get('warmup_steps', 0)\n        self.lr = self.cfg['training']['lr_values'][0]\n        self.test_writer = tf.summary.create_file_writer(\n            os.path.join(os.getcwd(),\n                         \"leelalogs/{}-test\".format(self.cfg['name'])))\n        self.train_writer = tf.summary.create_file_writer(\n            os.path.join(os.getcwd(),\n                 
        \"leelalogs/{}-train\".format(self.cfg['name'])))\n        if vars(self).get('validation_dataset', None) is not None:\n            self.validation_writer = tf.summary.create_file_writer(\n                os.path.join(\n                    os.getcwd(),\n                    \"leelalogs/{}-validation\".format(self.cfg['name'])))\n        if self.swa_enabled:\n            self.swa_writer = tf.summary.create_file_writer(\n                os.path.join(os.getcwd(),\n                             \"leelalogs/{}-swa-test\".format(self.cfg['name'])))\n            self.swa_validation_writer = tf.summary.create_file_writer(\n                os.path.join(\n                    os.getcwd(),\n                    \"leelalogs/{}-swa-validation\".format(self.cfg['name'])))\n        self.checkpoint = tf.train.Checkpoint(optimizer=self.orig_optimizer,\n                                              model=self.model,\n                                              global_step=self.global_step,\n                                              swa_count=self.swa_count)\n        self.checkpoint.listed = self.swa_weights\n        self.manager = tf.train.CheckpointManager(\n            self.checkpoint,\n            directory=self.root_dir,\n            max_to_keep=50,\n            keep_checkpoint_every_n_hours=24,\n            checkpoint_name=self.cfg['name'])\n\n    def replace_weights(self, proto_filename, ignore_errors=False):\n        self.net.parse_proto(proto_filename)\n\n        filters, blocks = self.net.filters(), self.net.blocks()\n        if not ignore_errors:\n            if self.RESIDUAL_FILTERS != filters:\n                raise ValueError(\"Number of filters doesn't match the network\")\n            if self.RESIDUAL_BLOCKS != blocks:\n                raise ValueError(\"Number of blocks doesn't match the network\")\n            if self.POLICY_HEAD != self.net.pb.format.network_format.policy:\n                raise ValueError(\"Policy head type doesn't match the network\")\n  
          if self.VALUE_HEAD != self.net.pb.format.network_format.value:\n                raise ValueError(\"Value head type doesn't match the network\")\n\n        # List all tensor names we need weights for.\n        names = []\n        for weight in self.model.weights:\n            names.append(weight.name)\n\n        new_weights = self.net.get_weights_v2(names)\n        for weight in self.model.weights:\n            if 'renorm' in weight.name:\n                # Renorm variables are not populated.\n                continue\n\n            try:\n                new_weight = new_weights[weight.name]\n            except KeyError:\n                error_string = 'No values for tensor {} in protobuf'.format(\n                    weight.name)\n                if ignore_errors:\n                    print(error_string)\n                    continue\n                else:\n                    raise KeyError(error_string)\n\n            if reduce(operator.mul, weight.shape.as_list(),\n                      1) != len(new_weight):\n                error_string = 'Tensor {} has wrong length. 
Tensorflow shape {}, size in protobuf {}'.format(\n                    weight.name, weight.shape.as_list(), len(new_weight))\n                if ignore_errors:\n                    print(error_string)\n                    continue\n                else:\n                    raise KeyError(error_string)\n\n            if weight.shape.ndims == 4:\n                # Rescale rule50 related weights as clients do not normalize the input.\n                if weight.name == 'input/conv2d/kernel:0' and self.net.pb.format.network_format.input < pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES:\n                    num_inputs = 112\n                    # 50 move rule is the 110th input, or 109 starting from 0.\n                    rule50_input = 109\n                    for i in range(len(new_weight)):\n                        if (i % (num_inputs * 9)) // 9 == rule50_input:\n                            new_weight[i] = new_weight[i] * 99\n\n                # Convolution weights need a transpose\n                #\n                # TF (kYXInputOutput)\n                # [filter_height, filter_width, in_channels, out_channels]\n                #\n                # Leela/cuDNN/Caffe (kOutputInputYX)\n                # [output, input, filter_size, filter_size]\n                s = weight.shape.as_list()\n                shape = [s[i] for i in [3, 2, 0, 1]]\n                new_weight = tf.constant(new_weight, shape=shape)\n                weight.assign(tf.transpose(a=new_weight, perm=[2, 3, 1, 0]))\n            elif weight.shape.ndims == 2:\n                # Fully connected layers are [in, out] in TF\n                #\n                # [out, in] in Leela\n                #\n                s = weight.shape.as_list()\n                shape = [s[i] for i in [1, 0]]\n                new_weight = tf.constant(new_weight, shape=shape)\n                weight.assign(tf.transpose(a=new_weight, perm=[1, 0]))\n            else:\n                # Biases, batchnorm etc\n       
         new_weight = tf.constant(new_weight, shape=weight.shape)\n                weight.assign(new_weight)\n        # Replace the SWA weights as well, ensuring swa accumulation is reset.\n        if self.swa_enabled:\n            self.swa_count.assign(tf.constant(0.))\n            self.update_swa()\n        # This should result in identical file to the starting one\n        # self.save_leelaz_weights('restored.pb.gz')\n\n    def restore(self):\n        if self.manager.latest_checkpoint is not None:\n            print(\"Restoring from {0}\".format(self.manager.latest_checkpoint))\n            self.checkpoint.restore(self.manager.latest_checkpoint)\n\n    def process_loop(self, batch_size, test_batches, batch_splits=1):\n        if self.swa_enabled:\n            # split half of test_batches between testing regular weights and SWA weights\n            test_batches //= 2\n        # Make sure that ghost batch norm can be applied\n        if self.virtual_batch_size and batch_size % self.virtual_batch_size != 0:\n            # Adjust required batch size for batch splitting.\n            required_factor = self.virtual_batch_size * self.cfg[\n                'training'].get('num_batch_splits', 1)\n            raise ValueError(\n                'batch_size must be a multiple of {}'.format(required_factor))\n\n        # Get the initial steps value in case this is a resume from a step count\n        # which is not a multiple of total_steps.\n        steps = self.global_step.read_value()\n        self.last_steps = steps\n        self.time_start = time.time()\n        self.profiling_start_step = None\n\n        total_steps = self.cfg['training']['total_steps']\n        for _ in range(steps % total_steps, total_steps):\n            self.process(batch_size, test_batches, batch_splits=batch_splits)\n\n    @tf.function()\n    def read_weights(self):\n        return [w.read_value() for w in self.model.weights]\n\n    @tf.function()\n    def process_inner_loop(self, x, y, z, q, 
m):\n        with tf.GradientTape() as tape:\n            outputs = self.model(x, training=True)\n            policy = outputs[0]\n            value = outputs[1]\n            policy_loss = self.policy_loss_fn(y, policy)\n            reg_term = sum(self.model.losses)\n            if self.wdl:\n                value_ce_loss = self.value_loss_fn(self.qMix(z, q), value)\n                value_loss = value_ce_loss\n            else:\n                value_mse_loss = self.mse_loss_fn(self.qMix(z, q), value)\n                value_loss = value_mse_loss\n            if self.moves_left:\n                moves_left = outputs[2]\n                moves_left_loss = self.moves_left_loss_fn(m, moves_left)\n            else:\n                moves_left_loss = tf.constant(0.)\n            total_loss = self.lossMix(policy_loss, value_loss, moves_left_loss,\n                                      reg_term)\n            if self.loss_scale != 1:\n                total_loss = self.optimizer.get_scaled_loss(total_loss)\n        if self.wdl:\n            mse_loss = self.mse_loss_fn(self.qMix(z, q), value)\n        else:\n            # The MSE was already computed inside the tape as value_mse_loss;\n            # reuse it here so mse_loss is defined on this path too.\n            mse_loss = value_mse_loss\n            value_loss = self.value_loss_fn(self.qMix(z, q), value)\n        metrics = [\n            policy_loss,\n            value_loss,\n            moves_left_loss,\n            reg_term,\n            total_loss,\n            # Google's paper scales MSE by 1/4 to a [0, 1] range, so do the same to\n            # get comparable values.\n            mse_loss / 4.0,\n        ]\n        return metrics, tape.gradient(total_loss, self.model.trainable_weights)\n\n    @tf.function()\n    def strategy_process_inner_loop(self, x, y, z, q, m):\n        metrics, new_grads = self.strategy.run(self.process_inner_loop,\n                                               args=(x, y, z, q, m))\n        metrics = [\n            self.strategy.reduce(tf.distribute.ReduceOp.MEAN, m, axis=None)\n            for m in metrics\n        ]\n        return metrics, new_grads\n\n    def 
apply_grads(self, grads, effective_batch_splits):\n        grads = [\n            g[0]\n            for g in self.aggregator(zip(grads, self.model.trainable_weights))\n        ]\n        if self.loss_scale != 1:\n            grads = self.optimizer.get_unscaled_gradients(grads)\n        max_grad_norm = self.cfg['training'].get(\n            'max_grad_norm', 10000.0) * effective_batch_splits\n        grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)\n        self.optimizer.apply_gradients(zip(grads,\n                                           self.model.trainable_weights),\n                                       experimental_aggregate_gradients=False)\n        return grad_norm\n\n    @tf.function()\n    def strategy_apply_grads(self, grads, effective_batch_splits):\n        grad_norm = self.strategy.run(self.apply_grads,\n                                      args=(grads, effective_batch_splits))\n        grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,\n                                         grad_norm,\n                                         axis=None)\n        return grad_norm\n\n    @tf.function()\n    def merge_grads(self, grads, new_grads):\n        return [tf.math.add(a, b) for (a, b) in zip(grads, new_grads)]\n\n    @tf.function()\n    def strategy_merge_grads(self, grads, new_grads):\n        return self.strategy.run(self.merge_grads, args=(grads, new_grads))\n\n    def train_step(self, steps, batch_size, batch_splits):\n        # need to add 1 to steps because steps will be incremented after gradient update\n        if (steps +\n                1) % self.cfg['training']['train_avg_report_steps'] == 0 or (\n                    steps + 1) % self.cfg['training']['total_steps'] == 0:\n            before_weights = self.read_weights()\n\n        # Run training for this batch\n        grads = None\n        for _ in range(batch_splits):\n            x, y, z, q, m = next(self.train_iter)\n            if self.strategy is not None:\n      
          metrics, new_grads = self.strategy_process_inner_loop(\n                    x, y, z, q, m)\n            else:\n                metrics, new_grads = self.process_inner_loop(x, y, z, q, m)\n            if not grads:\n                grads = new_grads\n            else:\n                if self.strategy is not None:\n                    grads = self.strategy_merge_grads(grads, new_grads)\n                else:\n                    grads = self.merge_grads(grads, new_grads)\n            # Keep running averages\n            for acc, val in zip(self.train_metrics, metrics):\n                acc.accumulate(val)\n        # Gradients of batch splits are summed, not averaged like usual, so need to scale lr accordingly to correct for this.\n        effective_batch_splits = batch_splits\n        if self.strategy is not None:\n            effective_batch_splits = batch_splits * self.strategy.num_replicas_in_sync\n        self.active_lr.assign(self.lr / effective_batch_splits)\n        if self.update_lr_manually:\n            self.orig_optimizer.learning_rate = self.active_lr\n        if self.strategy is not None:\n            grad_norm = self.strategy_apply_grads(grads,\n                                                  effective_batch_splits)\n        else:\n            grad_norm = self.apply_grads(grads, effective_batch_splits)\n\n        # Note: grads variable at this point has not been unscaled or\n        # had clipping applied. 
Since no code after this point depends\n        # upon that it seems fine for now.\n\n        # Update steps.\n        self.global_step.assign_add(1)\n        steps = self.global_step.read_value()\n\n        if steps % self.cfg['training'][\n                'train_avg_report_steps'] == 0 or steps % self.cfg['training'][\n                    'total_steps'] == 0:\n            time_end = time.time()\n            speed = 0\n            if self.time_start:\n                elapsed = time_end - self.time_start\n                steps_elapsed = steps - self.last_steps\n                speed = batch_size * (tf.cast(steps_elapsed, tf.float32) /\n                                      elapsed)\n            print(\"step {}, lr={:g}\".format(steps, self.lr), end='')\n            for metric in self.train_metrics:\n                print(\" {}={:g}{}\".format(metric.short_name, metric.get(),\n                                          metric.suffix),\n                      end='')\n            print(\" ({:g} pos/s)\".format(speed))\n\n            after_weights = self.read_weights()\n            with self.train_writer.as_default():\n                for metric in self.train_metrics:\n                    tf.summary.scalar(metric.long_name,\n                                      metric.get(),\n                                      step=steps)\n                tf.summary.scalar(\"LR\", self.lr, step=steps)\n                tf.summary.scalar(\"Gradient norm\",\n                                  grad_norm / effective_batch_splits,\n                                  step=steps)\n                self.compute_update_ratio(before_weights, after_weights, steps)\n            self.train_writer.flush()\n\n            self.time_start = time_end\n            self.last_steps = steps\n            for metric in self.train_metrics:\n                metric.reset()\n        return steps\n\n    def process(self, batch_size, test_batches, batch_splits):\n        # Get the initial steps value before we do a 
training step.\n        steps = self.global_step.read_value()\n\n        # Profiling is disabled by default: with the default freq of 1,\n        # steps % 1 is always 0, which never equals the default offset of 10.\n        if steps % self.cfg['training'].get('profile_step_freq',\n                                            1) == self.cfg['training'].get(\n                                                'profile_step_offset', 10):\n            self.profiling_start_step = steps\n            tf.profiler.experimental.start(\n                os.path.join(os.getcwd(),\n                             \"leelalogs/{}-profile\".format(self.cfg['name'])))\n\n        # Run test before the first step to see the delta since the end of the\n        # last run.\n        if steps % self.cfg['training']['total_steps'] == 0:\n            with tf.profiler.experimental.Trace(\"Test\", step_num=steps + 1):\n                # Steps is given as one higher than current in order to avoid it\n                # being equal to the value the end of a run is stored against.\n                self.calculate_test_summaries(test_batches, steps + 1)\n                if self.swa_enabled:\n                    self.calculate_swa_summaries(test_batches, steps + 1)\n\n        # Determine the learning rate.\n        lr_values = self.cfg['training']['lr_values']\n        lr_boundaries = self.cfg['training']['lr_boundaries']\n        steps_total = steps % self.cfg['training']['total_steps']\n        self.lr = lr_values[bisect.bisect_right(lr_boundaries, steps_total)]\n        if self.warmup_steps > 0 and steps < self.warmup_steps:\n            self.lr = self.lr * tf.cast(steps + 1,\n                                        tf.float32) / self.warmup_steps\n\n        with tf.profiler.experimental.Trace(\"Train\", step_num=steps):\n            steps = self.train_step(steps, batch_size, batch_splits)\n\n        if self.swa_enabled and steps % self.cfg['training']['swa_steps'] == 0:\n            self.update_swa()\n\n        # Calculate test values every 'test_steps', but also ensure there is\n        # one at the final step so the delta to the first 
step can be calculated.\n        if steps % self.cfg['training']['test_steps'] == 0 or steps % self.cfg[\n                'training']['total_steps'] == 0:\n            with tf.profiler.experimental.Trace(\"Test\", step_num=steps):\n                self.calculate_test_summaries(test_batches, steps)\n                if self.swa_enabled:\n                    self.calculate_swa_summaries(test_batches, steps)\n\n        if self.validation_dataset is not None and (\n                steps % self.cfg['training']['validation_steps'] == 0\n                or steps % self.cfg['training']['total_steps'] == 0):\n            with tf.profiler.experimental.Trace(\"Validate\", step_num=steps):\n                if self.swa_enabled:\n                    self.calculate_swa_validations(steps)\n                else:\n                    self.calculate_test_validations(steps)\n\n        # Save session and weights at end, and also optionally every 'checkpoint_steps'.\n        if steps % self.cfg['training']['total_steps'] == 0 or (\n                'checkpoint_steps' in self.cfg['training']\n                and steps % self.cfg['training']['checkpoint_steps'] == 0):\n            evaled_steps = steps.numpy()\n            self.manager.save(checkpoint_number=evaled_steps)\n            print(\"Model saved in file: {}\".format(\n                self.manager.latest_checkpoint))\n            path = os.path.join(self.root_dir, self.cfg['name'])\n            leela_path = path + \"-\" + str(evaled_steps)\n            swa_path = path + \"-swa-\" + str(evaled_steps)\n            self.net.pb.training_params.training_steps = evaled_steps\n            self.save_leelaz_weights(leela_path)\n            if self.swa_enabled:\n                self.save_swa_weights(swa_path)\n\n        if self.profiling_start_step is not None and (\n                steps >= self.profiling_start_step +\n                self.cfg['training'].get('profile_step_count', 0)\n                or steps % 
self.cfg['training']['total_steps'] == 0):\n            tf.profiler.experimental.stop()\n            self.profiling_start_step = None\n\n    def calculate_swa_summaries(self, test_batches, steps):\n        backup = self.read_weights()\n        for (swa, w) in zip(self.swa_weights, self.model.weights):\n            w.assign(swa.read_value())\n        true_test_writer, self.test_writer = self.test_writer, self.swa_writer\n        print('swa', end=' ')\n        self.calculate_test_summaries(test_batches, steps)\n        self.test_writer = true_test_writer\n        for (old, w) in zip(backup, self.model.weights):\n            w.assign(old)\n\n    @tf.function()\n    def calculate_test_summaries_inner_loop(self, x, y, z, q, m):\n        outputs = self.model(x, training=False)\n        policy = outputs[0]\n        value = outputs[1]\n        policy_loss = self.policy_loss_fn(y, policy)\n        policy_accuracy = self.policy_accuracy_fn(y, policy)\n        policy_entropy = self.policy_entropy_fn(y, policy)\n        policy_ul = self.policy_uniform_loss_fn(y, policy)\n        value_loss = self.value_loss_fn(self.qMix(z, q), value)\n        mse_loss = self.mse_loss_fn(self.qMix(z, q), value)\n        if self.wdl:\n            value_accuracy = self.accuracy_fn(self.qMix(z, q), value)\n        else:\n            value_accuracy = tf.constant(0.)\n        if self.moves_left:\n            moves_left = outputs[2]\n            moves_left_loss = self.moves_left_loss_fn(m, moves_left)\n            moves_left_mean_error = self.moves_left_mean_error(m, moves_left)\n        else:\n            moves_left_loss = tf.constant(0.)\n            moves_left_mean_error = tf.constant(0.)\n        metrics = [\n            policy_loss,\n            value_loss,\n            moves_left_loss,\n            mse_loss / 4,\n            policy_accuracy * 100,\n       
     value_accuracy * 100,\n            moves_left_mean_error,\n            policy_entropy,\n            policy_ul,\n        ]\n        return metrics\n\n    @tf.function()\n    def strategy_calculate_test_summaries_inner_loop(self, x, y, z, q, m):\n        metrics = self.strategy.run(self.calculate_test_summaries_inner_loop,\n                                    args=(x, y, z, q, m))\n        metrics = [\n            self.strategy.reduce(tf.distribute.ReduceOp.MEAN, m, axis=None)\n            for m in metrics\n        ]\n        return metrics\n\n    def calculate_test_summaries(self, test_batches, steps):\n        for metric in self.test_metrics:\n            metric.reset()\n        for _ in range(0, test_batches):\n            x, y, z, q, m = next(self.test_iter)\n            if self.strategy is not None:\n                metrics = self.strategy_calculate_test_summaries_inner_loop(\n                    x, y, z, q, m)\n            else:\n                metrics = self.calculate_test_summaries_inner_loop(\n                    x, y, z, q, m)\n            for acc, val in zip(self.test_metrics, metrics):\n                acc.accumulate(val)\n        self.net.pb.training_params.learning_rate = self.lr\n        self.net.pb.training_params.mse_loss = self.test_metrics[3].get()\n        self.net.pb.training_params.policy_loss = self.test_metrics[0].get()\n        # TODO store value and value accuracy in pb\n        self.net.pb.training_params.accuracy = self.test_metrics[4].get()\n        with self.test_writer.as_default():\n            for metric in self.test_metrics:\n                tf.summary.scalar(metric.long_name, metric.get(), step=steps)\n            for w in self.model.weights:\n                tf.summary.histogram(w.name, w, step=steps)\n        self.test_writer.flush()\n\n        print(\"step {},\".format(steps), end='')\n        for metric in self.test_metrics:\n            print(\" {}={:g}{}\".format(metric.short_name, metric.get(),\n                         
             metric.suffix),\n                  end='')\n        print()\n\n    def calculate_swa_validations(self, steps):\n        backup = self.read_weights()\n        for (swa, w) in zip(self.swa_weights, self.model.weights):\n            w.assign(swa.read_value())\n        true_validation_writer, self.validation_writer = (\n            self.validation_writer, self.swa_validation_writer)\n        print('swa', end=' ')\n        self.calculate_test_validations(steps)\n        self.validation_writer = true_validation_writer\n        for (old, w) in zip(backup, self.model.weights):\n            w.assign(old)\n\n    def calculate_test_validations(self, steps):\n        for metric in self.test_metrics:\n            metric.reset()\n        for (x, y, z, q, m) in self.validation_dataset:\n            if self.strategy is not None:\n                metrics = self.strategy_calculate_test_summaries_inner_loop(\n                    x, y, z, q, m)\n            else:\n                metrics = self.calculate_test_summaries_inner_loop(\n                    x, y, z, q, m)\n            for acc, val in zip(self.test_metrics, metrics):\n                acc.accumulate(val)\n        with self.validation_writer.as_default():\n            for metric in self.test_metrics:\n                tf.summary.scalar(metric.long_name, metric.get(), step=steps)\n        self.validation_writer.flush()\n\n        print(\"step {}, validation:\".format(steps), end='')\n        for metric in self.test_metrics:\n            print(\" {}={:g}{}\".format(metric.short_name, metric.get(),\n                                      metric.suffix),\n                  end='')\n        print()\n\n    @tf.function()\n    def compute_update_ratio(self, before_weights, after_weights, steps):\n        \"\"\"Compute the ratio of weight-update norm to weight norm.\n\n        Adapted from https://github.com/tensorflow/minigo/blob/c923cd5b11f7d417c9541ad61414bf175a84dc31/dual_net.py#L567\n        \"\"\"\n        deltas = [\n            after - 
before\n            for after, before in zip(after_weights, before_weights)\n        ]\n        delta_norms = [tf.math.reduce_euclidean_norm(d) for d in deltas]\n        weight_norms = [\n            tf.math.reduce_euclidean_norm(w) for w in before_weights\n        ]\n        ratios = [(tensor.name, tf.cond(w != 0., lambda: d / w, lambda: -1.))\n                  for d, w, tensor in zip(delta_norms, weight_norms,\n                                          self.model.weights)\n                  if 'moving' not in tensor.name]\n        for name, ratio in ratios:\n            tf.summary.scalar('update_ratios/' + name, ratio, step=steps)\n        # Filtering is hard, so just push infinities/NaNs to an unreasonably\n        # large value. 2.30258509299 is ln(10), so this converts ln to log10.\n        ratios = [\n            tf.cond(r > 0, lambda: tf.math.log(r) / 2.30258509299,\n                    lambda: 200.) for (_, r) in ratios\n        ]\n        tf.summary.histogram('update_ratios_log10',\n                             tf.stack(ratios),\n                             buckets=1000,\n                             step=steps)\n\n    def update_swa(self):\n        num = self.swa_count.read_value()\n        for (w, swa) in zip(self.model.weights, self.swa_weights):\n            swa.assign(swa.read_value() * (num / (num + 1.)) + w.read_value() *\n                       (1. 
/ (num + 1.)))\n        self.swa_count.assign(min(num + 1., self.swa_max_n))\n\n    def save_swa_weights(self, filename):\n        backup = self.read_weights()\n        for (swa, w) in zip(self.swa_weights, self.model.weights):\n            w.assign(swa.read_value())\n        self.save_leelaz_weights(filename)\n        for (old, w) in zip(backup, self.model.weights):\n            w.assign(old)\n\n    def save_leelaz_weights(self, filename):\n        numpy_weights = []\n        for weight in self.model.weights:\n            numpy_weights.append([weight.name, weight.numpy()])\n        self.net.fill_net_v2(numpy_weights)\n        self.net.save_proto(filename)\n\n    def batch_norm(self, inputs, name, scale=False):\n        if self.renorm_enabled:\n            clipping = {\n                \"rmin\": 1.0 / self.renorm_max_r,\n                \"rmax\": self.renorm_max_r,\n                \"dmax\": self.renorm_max_d\n            }\n            return tf.keras.layers.BatchNormalization(\n                epsilon=1e-5,\n                axis=1,\n                fused=False,\n                center=True,\n                scale=scale,\n                renorm=True,\n                renorm_clipping=clipping,\n                renorm_momentum=self.renorm_momentum,\n                name=name)(inputs)\n        else:\n            return tf.keras.layers.BatchNormalization(\n                epsilon=1e-5,\n                axis=1,\n                center=True,\n                scale=scale,\n                virtual_batch_size=self.virtual_batch_size,\n                name=name)(inputs)\n\n    def squeeze_excitation(self, inputs, channels, name):\n        assert channels % self.SE_ratio == 0\n\n        pooled = tf.keras.layers.GlobalAveragePooling2D(\n            data_format='channels_first')(inputs)\n        squeezed = tf.keras.layers.Activation(self.DEFAULT_ACTIVATION)(\n            tf.keras.layers.Dense(channels // self.SE_ratio,\n                                  
kernel_initializer='glorot_normal',\n                                  kernel_regularizer=self.l2reg,\n                                  name=name + '/se/dense1')(pooled))\n        excited = tf.keras.layers.Dense(2 * channels,\n                                        kernel_initializer='glorot_normal',\n                                        kernel_regularizer=self.l2reg,\n                                        name=name + '/se/dense2')(squeezed)\n        return ApplySqueezeExcitation()([inputs, excited])\n\n    def conv_block(self,\n                   inputs,\n                   filter_size,\n                   output_channels,\n                   name,\n                   bn_scale=False):\n        conv = tf.keras.layers.Conv2D(output_channels,\n                                      filter_size,\n                                      use_bias=False,\n                                      padding='same',\n                                      kernel_initializer='glorot_normal',\n                                      kernel_regularizer=self.l2reg,\n                                      data_format='channels_first',\n                                      name=name + '/conv2d')(inputs)\n        return tf.keras.layers.Activation(self.DEFAULT_ACTIVATION)(\n            self.batch_norm(conv, name=name + '/bn', scale=bn_scale))\n\n    def residual_block(self, inputs, channels, name):\n        conv1 = tf.keras.layers.Conv2D(channels,\n                                       3,\n                                       use_bias=False,\n                                       padding='same',\n                                       kernel_initializer='glorot_normal',\n                                       kernel_regularizer=self.l2reg,\n                                       data_format='channels_first',\n                                       name=name + '/1/conv2d')(inputs)\n        out1 = tf.keras.layers.Activation(self.DEFAULT_ACTIVATION)(\n            
self.batch_norm(conv1, name + '/1/bn', scale=False))\n        conv2 = tf.keras.layers.Conv2D(channels,\n                                       3,\n                                       use_bias=False,\n                                       padding='same',\n                                       kernel_initializer='glorot_normal',\n                                       kernel_regularizer=self.l2reg,\n                                       data_format='channels_first',\n                                       name=name + '/2/conv2d')(out1)\n\n        out2 = self.squeeze_excitation(self.batch_norm(conv2,\n                                                       name + '/2/bn',\n                                                       scale=True),\n                                       channels,\n                                       name=name + '/se')\n        return tf.keras.layers.Activation(self.DEFAULT_ACTIVATION)(\n            tf.keras.layers.add([inputs, out2]))\n\n    @staticmethod\n    def split_heads(inputs, batch_size: int, num_heads: int, depth: int):\n        if num_heads < 2:\n            return inputs\n        reshaped = tf.reshape(inputs, (batch_size, 64, num_heads, depth))\n        # (batch_size, num_heads, 64, depth)\n        return tf.transpose(reshaped, perm=[0, 2, 1, 3])\n\n    def scaled_dot_product_attention(self,\n                                     q,\n                                     k,\n                                     v,\n                                     name: str = None,\n                                     inputs=None):\n\n        # 0 h 64 d, 0 h d 64\n        matmul_qk = tf.matmul(q, k, transpose_b=True)\n        dk = tf.cast(tf.shape(k)[-1], self.model_dtype)\n        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)\n        heads = scaled_attention_logits.shape[1]\n\n        if self.use_smolgen:\n            smolgen_weights = self.smolgen_weights(\n                inputs,\n                heads,\n                
self.smolgen_hidden_channels,\n                self.smolgen_hidden_sz,\n                self.smolgen_gen_sz,\n                name=name + '/smolgen',\n                activation=self.smolgen_activation)\n            scaled_attention_logits = scaled_attention_logits + smolgen_weights\n\n        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)\n        output = tf.matmul(attention_weights, v)\n        return output, scaled_attention_logits\n\n    # multi-head attention in encoder blocks\n\n    def mha(self, inputs, emb_size: int, d_model: int, num_heads: int,\n            initializer, name: str):\n        assert d_model % num_heads == 0\n        depth = d_model // num_heads\n        # query, key, and value vectors for self-attention\n        # inputs b, 64, sz\n        q = tf.keras.layers.Dense(d_model,\n                                  name=name + '/wq',\n                                  kernel_initializer='glorot_normal')(inputs)\n        k = tf.keras.layers.Dense(d_model,\n                                  name=name + '/wk',\n                                  kernel_initializer='glorot_normal')(inputs)\n        v = tf.keras.layers.Dense(d_model,\n                                  name=name + '/wv',\n                                  kernel_initializer=initializer)(inputs)\n\n        # split q, k and v into smaller vectors of size 'depth' -- one for each head in multi-head attention\n        batch_size = tf.shape(q)[0]\n        q = self.split_heads(q, batch_size, num_heads, depth)\n        k = self.split_heads(k, batch_size, num_heads, depth)\n        v = self.split_heads(v, batch_size, num_heads, depth)\n\n        scaled_attention, attention_weights = self.scaled_dot_product_attention(\n            q, k, v, name=name, inputs=inputs)\n        if num_heads > 1:\n            scaled_attention = tf.transpose(scaled_attention,\n                                            perm=[0, 2, 1, 3])\n            scaled_attention = tf.reshape(\n                
scaled_attention,\n                (batch_size, -1, d_model))  # concatenate heads\n\n        # final dense layer\n        output = tf.keras.layers.Dense(\n            emb_size, name=name + \"/dense\",\n            kernel_initializer=initializer)(scaled_attention)\n        return output, attention_weights\n\n    # 2-layer dense feed-forward network in encoder blocks\n    def ffn(self, inputs, emb_size: int, dff: int, initializer, name: str):\n        if self.encoder_layers > 0:\n            activation = square_relu if self.square_relu_ffn else tf.keras.activations.get(\n                self.DEFAULT_ACTIVATION)\n        else:\n            activation = \"selu\"\n        dense1 = tf.keras.layers.Dense(dff,\n                                       name=name + \"/dense1\",\n                                       kernel_initializer=initializer,\n                                       activation=activation)(inputs)\n        out = tf.keras.layers.Dense(emb_size,\n                                    name=name + \"/dense2\",\n                                    kernel_initializer=initializer)(dense1)\n        return out\n\n    def encoder_layer(self, inputs, emb_size: int, d_model: int,\n                      num_heads: int, dff: int, name: str):\n        initializer = None\n        if self.encoder_layers > 0:\n            # DeepNorm\n            alpha = tf.cast(tf.math.pow(2. * self.encoder_layers, 0.25),\n                            self.model_dtype)\n            beta = tf.cast(tf.math.pow(8. 
* self.encoder_layers, -0.25),\n                           self.model_dtype)\n            xavier_norm = tf.keras.initializers.VarianceScaling(\n                scale=beta, mode='fan_avg', distribution='truncated_normal')\n            initializer = xavier_norm\n        else:\n            alpha = 1\n            initializer = \"glorot_normal\"\n        # multihead attention\n        attn_output, attn_wts = self.mha(inputs,\n                                         emb_size,\n                                         d_model,\n                                         num_heads,\n                                         initializer,\n                                         name=name + \"/mha\")\n        # dropout for weight regularization\n        attn_output = tf.keras.layers.Dropout(self.dropout_rate,\n                                              name=name +\n                                              \"/dropout1\")(attn_output)\n        # skip connection + layernorm\n        out1 = tf.keras.layers.LayerNormalization(\n            epsilon=1e-6, name=name + \"/ln1\")(inputs * alpha + attn_output)\n        # feed-forward network\n        ffn_output = self.ffn(out1,\n                              emb_size,\n                              dff,\n                              initializer,\n                              name=name + \"/ffn\")\n        ffn_output = tf.keras.layers.Dropout(self.dropout_rate,\n                                             name=name +\n                                             \"/dropout2\")(ffn_output)\n        out2 = tf.keras.layers.LayerNormalization(\n            epsilon=1e-6, name=name + \"/ln2\")(out1 * alpha + ffn_output)\n        return out2, attn_wts\n\n    def smolgen_weights(self,\n                        inputs,\n                        heads: int,\n                        hidden_channels: int,\n                        hidden_sz: int,\n                        gen_sz: int,\n                        name: str,\n                      
  activation='swish'):\n        compressed = tf.keras.layers.Dense(hidden_channels,\n                                           name=name + '/compress',\n                                           use_bias=False)(inputs)\n        compressed = tf.reshape(compressed, [-1, 64 * hidden_channels])\n\n        hidden = tf.keras.layers.Dense(hidden_sz,\n                                       name=name + '/hidden1_dense',\n                                       activation=activation)(compressed)\n        hidden = tf.keras.layers.LayerNormalization(name=name +\n                                                    '/hidden1_ln')(hidden)\n\n        gen_from = tf.keras.layers.Dense(heads * gen_sz,\n                                         name=name + '/gen_from',\n                                         activation=activation)(hidden)\n        gen_from = tf.keras.layers.LayerNormalization(name=name +\n                                                      '/gen_from_ln')(gen_from)\n        gen_from = tf.reshape(gen_from, [-1, heads, gen_sz])\n        out = self.smol_weight_gen_dense(gen_from)\n        return tf.reshape(out, [-1, heads, 64, 64])\n\n    def create_residual_body(self, inputs):\n        flow = self.conv_block(inputs,\n                               filter_size=3,\n                               output_channels=self.RESIDUAL_FILTERS,\n                               name='input',\n                               bn_scale=True)\n        for i in range(self.RESIDUAL_BLOCKS):\n            flow = self.residual_block(flow,\n                                       self.RESIDUAL_FILTERS,\n                                       name='residual_{}'.format(i + 1))\n        return flow\n\n    def create_encoder_body(self, inputs, embedding_size):\n        # Policy head\n        assert self.POLICY_HEAD == pb.NetworkFormat.POLICY_ATTENTION\n\n        # do some input processing\n        if self.use_smolgen:\n            self.smol_weight_gen_dense = tf.keras.layers.Dense(\n              
  64 * 64, name='smol_weight_gen', use_bias=False)\n\n        flow = tf.transpose(inputs, perm=[0, 2, 3, 1])\n        flow = tf.reshape(flow, [-1, 64, tf.shape(inputs)[1]])\n        # add positional encoding for each square to the input\n        if self.arc_encoding:\n            self.POS_ENC = apm.make_pos_enc()\n            positional_encoding = tf.broadcast_to(\n                tf.convert_to_tensor(self.POS_ENC, dtype=flow.dtype),\n                [tf.shape(flow)[0], 64,\n                 tf.shape(self.POS_ENC)[2]])\n            flow = tf.concat([flow, positional_encoding], axis=2)\n\n        # square embedding\n        flow = tf.keras.layers.Dense(embedding_size,\n                                     kernel_initializer='glorot_normal',\n                                     kernel_regularizer=self.l2reg,\n                                     activation=self.DEFAULT_ACTIVATION,\n                                     name='embedding')(flow)\n\n        # !!! input gate\n        flow = ma_gating(flow, name='embedding')\n        attn_wts = []\n        for i in range(self.encoder_layers):\n            flow, attn_wts_l = self.encoder_layer(flow,\n                                                  embedding_size,\n                                                  self.encoder_d_model,\n                                                  self.encoder_heads,\n                                                  self.encoder_dff,\n                                                  name='encoder_{}'.format(i +\n                                                                           1))\n            attn_wts.append(attn_wts_l)\n        return flow, attn_wts\n\n    def apply_promotion_logits(self, queries, keys, attn_wts):\n        # PAWN PROMOTION: create promotion logits using scalar offsets generated from the promotion-rank keys\n        dk = tf.math.sqrt(tf.cast(tf.shape(keys)[-1],\n                                  self.model_dtype))  # constant for scaling\n        
promotion_keys = keys[:, -8:, :]\n        # queen, rook, bishop, knight order\n        promotion_offsets = tf.keras.layers.Dense(\n            4,\n            kernel_initializer='glorot_normal',\n            name='policy/attention/ppo',\n            use_bias=False)(promotion_keys)\n        promotion_offsets = tf.transpose(promotion_offsets,\n                                         perm=[0, 2, 1]) * dk  # Bx4x8\n        # knight offset is added to the other three\n        promotion_offsets = promotion_offsets[:, :\n                                              3, :] + promotion_offsets[:,\n                                                                        3:4, :]\n\n        # POLICY SELF-ATTENTION: self-attention weights are interpreted as from->to policy\n        matmul_qk = tf.matmul(\n            queries, keys,\n            transpose_b=True)  # Bx64x64 (from 64 queries, 64 keys)\n\n        # q, r, and b promotions are offset from the default promotion logit (knight)\n        n_promo_logits = matmul_qk[:, -16:-8,\n                                   -8:]  # default traversals from penultimate rank to promotion rank\n        q_promo_logits = tf.expand_dims(n_promo_logits +\n                                        promotion_offsets[:, 0:1, :],\n                                        axis=3)  # Bx8x8x1\n        r_promo_logits = tf.expand_dims(n_promo_logits +\n                                        promotion_offsets[:, 1:2, :],\n                                        axis=3)\n        b_promo_logits = tf.expand_dims(n_promo_logits +\n                                        promotion_offsets[:, 2:3, :],\n                                        axis=3)\n        promotion_logits = tf.concat(\n            [q_promo_logits, r_promo_logits, b_promo_logits],\n            axis=3)  # Bx8x8x3\n        promotion_logits = tf.reshape(\n            promotion_logits,\n            [-1, 8, 24])  # logits now alternate a7a8q,a7a8r,a7a8b,...,\n\n        # scale the logits by 
dividing them by sqrt(d_model) to stabilize gradients\n        promotion_logits = promotion_logits / dk  # Bx8x24 (8 from-squares, 3x8 promotions)\n        policy_attn_logits = matmul_qk / dk  # Bx64x64 (64 from-squares, 64 to-squares)\n\n        attn_wts.append(promotion_logits)\n        attn_wts.append(policy_attn_logits)\n\n        # APPLY POLICY MAP: output becomes Bx1856\n        h_fc1 = ApplyAttentionPolicyMap()(policy_attn_logits, promotion_logits)\n        return h_fc1\n\n    def construct_net(self, inputs, name=''):\n\n        if self.encoder_layers > 0:\n            flow, attn_wts = self.create_encoder_body(inputs,\n                                                      self.embedding_size)\n        else:\n            flow = self.create_residual_body(inputs)\n\n        # Policy head\n        if self.POLICY_HEAD == pb.NetworkFormat.POLICY_CONVOLUTION:\n            conv_pol = self.conv_block(flow,\n                                       filter_size=3,\n                                       output_channels=self.RESIDUAL_FILTERS,\n                                       name='policy1')\n            conv_pol2 = tf.keras.layers.Conv2D(\n                80,\n                3,\n                use_bias=True,\n                padding='same',\n                kernel_initializer='glorot_normal',\n                kernel_regularizer=self.l2reg,\n                bias_regularizer=self.l2reg,\n                data_format='channels_first',\n                name='policy')(conv_pol)\n            h_fc1 = ApplyPolicyMap()(conv_pol2)\n        elif self.POLICY_HEAD == pb.NetworkFormat.POLICY_CLASSICAL:\n            conv_pol = self.conv_block(flow,\n                                       filter_size=1,\n                                       output_channels=self.policy_channels,\n                                       name='policy')\n            h_conv_pol_flat = tf.keras.layers.Flatten()(conv_pol)\n            h_fc1 = tf.keras.layers.Dense(1858,\n                                
          kernel_initializer='glorot_normal',\n                                          kernel_regularizer=self.l2reg,\n                                          bias_regularizer=self.l2reg,\n                                          name='policy/dense')(h_conv_pol_flat)\n        elif self.POLICY_HEAD == pb.NetworkFormat.POLICY_ATTENTION:\n            if self.encoder_layers == 0:\n                attn_wts = []\n            if self.RESIDUAL_BLOCKS > 0:\n                # transpose and reshape\n                tokens = tf.transpose(flow, perm=[0, 2, 3, 1])\n                tokens = tf.reshape(tokens, [-1, 64, self.RESIDUAL_FILTERS])\n                embed_activation = 'selu'\n            else:\n                tokens = flow\n                embed_activation = self.DEFAULT_ACTIVATION\n            # SQUARE EMBEDDING: found to increase attention head performance\n            tokens = tf.keras.layers.Dense(self.pol_embedding_size,\n                                           kernel_initializer='glorot_normal',\n                                           kernel_regularizer=self.l2reg,\n                                           activation=embed_activation,\n                                           name='policy/embedding')(tokens)\n            if self.RESIDUAL_BLOCKS > 0:\n                # ENCODER LAYERS: intermediate layers of self-attention with residual connections\n                for i in range(self.pol_encoder_layers):\n                    tokens, attn_wts_l = self.encoder_layer(\n                        tokens,\n                        self.pol_embedding_size,\n                        self.pol_encoder_d_model,\n                        self.pol_encoder_heads,\n                        self.pol_encoder_dff,\n                        name='policy/enc_layer_{}'.format(i + 1))\n                    attn_wts.append(attn_wts_l)\n\n            # create queries and keys for policy self-attention\n            queries = tf.keras.layers.Dense(self.policy_d_model,\n              
                              kernel_initializer='glorot_normal',\n                                            name='policy/attention/wq')(tokens)\n            keys = tf.keras.layers.Dense(self.policy_d_model,\n                                         kernel_initializer='glorot_normal',\n                                         name='policy/attention/wk')(tokens)\n\n            h_fc1 = self.apply_promotion_logits(queries, keys, attn_wts)\n\n        else:\n            raise ValueError(\"Unknown policy head type {}\".format(\n                self.POLICY_HEAD))\n\n        # Value head\n        if self.encoder_layers > 0:\n            conv_val = tf.keras.layers.Dense(\n                self.val_embedding_size,\n                kernel_initializer='glorot_normal',\n                kernel_regularizer=self.l2reg,\n                activation=self.DEFAULT_ACTIVATION,\n                name='value/embedding')(flow)\n        else:\n            conv_val = self.conv_block(flow,\n                                       filter_size=1,\n                                       output_channels=32,\n                                       name='value')\n\n        h_conv_val_flat = tf.keras.layers.Flatten()(conv_val)\n        h_fc2 = tf.keras.layers.Dense(128,\n                                      kernel_initializer='glorot_normal',\n                                      kernel_regularizer=self.l2reg,\n                                      activation=self.DEFAULT_ACTIVATION,\n                                      name='value/dense1')(h_conv_val_flat)\n        if self.wdl:\n            h_fc3 = tf.keras.layers.Dense(3,\n                                          kernel_initializer='glorot_normal',\n                                          kernel_regularizer=self.l2reg,\n                                          bias_regularizer=self.l2reg,\n                                          name='value/dense2')(h_fc2)\n        else:\n            h_fc3 = tf.keras.layers.Dense(1,\n                      
                    kernel_initializer='glorot_normal',\n                                          kernel_regularizer=self.l2reg,\n                                          activation='tanh',\n                                          name='value/dense2')(h_fc2)\n\n        # Moves left head\n        if self.moves_left:\n            if self.encoder_layers > 0:\n                conv_mov = tf.keras.layers.Dense(\n                    self.mov_embedding_size,\n                    kernel_initializer='glorot_normal',\n                    kernel_regularizer=self.l2reg,\n                    activation=self.DEFAULT_ACTIVATION,\n                    name='moves_left/embedding')(flow)\n            else:\n                conv_mov = self.conv_block(flow,\n                                           filter_size=1,\n                                           output_channels=8,\n                                           name='moves_left')\n            h_conv_mov_flat = tf.keras.layers.Flatten()(conv_mov)\n            h_fc4 = tf.keras.layers.Dense(\n                128,\n                kernel_initializer='glorot_normal',\n                kernel_regularizer=self.l2reg,\n                activation=self.DEFAULT_ACTIVATION,\n                name='moves_left/dense1')(h_conv_mov_flat)\n\n            h_fc5 = tf.keras.layers.Dense(1,\n                                          kernel_initializer='glorot_normal',\n                                          kernel_regularizer=self.l2reg,\n                                          activation='relu',\n                                          name='moves_left/dense2')(h_fc4)\n        else:\n            h_fc5 = None\n\n        # attention weights added as optional output for analysis -- ignored by backend\n        if self.POLICY_HEAD == pb.NetworkFormat.POLICY_ATTENTION:\n            if self.moves_left:\n                outputs = [h_fc1, h_fc3, h_fc5, attn_wts]\n            else:\n                outputs = [h_fc1, h_fc3, attn_wts]\n        elif 
self.moves_left:\n            outputs = [h_fc1, h_fc3, h_fc5]\n        else:\n            outputs = [h_fc1, h_fc3]\n\n        return outputs\n"
  },
  {
    "path": "tf/train.py",
    "content": "#!/usr/bin/env python3\n#\n#    This file is part of Leela Zero.\n#    Copyright (C) 2017 Gian-Carlo Pascutto\n#\n#    Leela Zero is free software: you can redistribute it and/or modify\n#    it under the terms of the GNU General Public License as published by\n#    the Free Software Foundation, either version 3 of the License, or\n#    (at your option) any later version.\n#\n#    Leela Zero is distributed in the hope that it will be useful,\n#    but WITHOUT ANY WARRANTY; without even the implied warranty of\n#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n#    GNU General Public License for more details.\n#\n#    You should have received a copy of the GNU General Public License\n#    along with Leela Zero.  If not, see <http://www.gnu.org/licenses/>.\n\nimport argparse\nimport os\nimport yaml\nimport sys\nimport glob\nimport gzip\nimport random\nimport multiprocessing as mp\nfrom chunkparser import ChunkParser\n\nSKIP = 32\n\n\ndef get_chunks(data_prefix):\n    return glob.glob(data_prefix + \"*.gz\")\n\n\ndef get_all_chunks(path):\n    if isinstance(path, list):\n        print(\"getting chunks for\", path)\n        chunks = []\n        for i in path:\n            chunks += get_all_chunks(i)\n        return chunks\n    chunks = []\n    for d in glob.glob(path):\n        chunks += get_chunks(d)\n    print(\"got\", len(chunks), \"chunks for\", path)\n    return chunks\n\n\ndef get_latest_chunks(path, num_chunks, allow_less, sort_key_fn):\n    chunks = get_all_chunks(path)\n    if len(chunks) < num_chunks:\n        if allow_less:\n            print(\"sorting {} chunks...\".format(len(chunks)),\n                  end='',\n                  flush=True)\n            chunks.sort(key=sort_key_fn, reverse=True)\n            print(\"[done]\")\n            print(\"{} - {}\".format(os.path.basename(chunks[-1]),\n                                   os.path.basename(chunks[0])))\n            random.shuffle(chunks)\n            return chunks\n 
       else:\n            print(\"Not enough chunks {}\".format(len(chunks)))\n            sys.exit(1)\n\n    print(\"sorting {} chunks...\".format(len(chunks)), end='', flush=True)\n    chunks.sort(key=sort_key_fn, reverse=True)\n    print(\"[done]\")\n    chunks = chunks[:num_chunks]\n    print(\"{} - {}\".format(os.path.basename(chunks[-1]),\n                           os.path.basename(chunks[0])))\n    random.shuffle(chunks)\n    return chunks\n\n\ndef identity_function(name):\n    return name\n\n\ndef game_number_for_name(name):\n    num_str = os.path.basename(name).upper().strip(\n        \"ABCDEFGHIJKLMNOPQRSTUVWXYZ_-.\")\n    return int(num_str)\n\n\ndef get_input_mode(cfg):\n    import proto.net_pb2 as pb\n    input_mode = cfg['model'].get('input_type', 'classic')\n\n    if input_mode == \"classic\":\n        return pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE\n    elif input_mode == \"frc_castling\":\n        return pb.NetworkFormat.INPUT_112_WITH_CASTLING_PLANE\n    elif input_mode == \"canonical\":\n        return pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION\n    elif input_mode == \"canonical_100\":\n        return pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES\n    elif input_mode == \"canonical_armageddon\":\n        return pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_HECTOPLIES_ARMAGEDDON\n    elif input_mode == \"canonical_v2\":\n        return pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2\n    elif input_mode == \"canonical_v2_armageddon\":\n        return pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON\n    else:\n        raise ValueError(\"Unknown input mode format: {}\".format(input_mode))\n\n\ndef main(cmd):\n    cfg = yaml.safe_load(cmd.cfg.read())\n    print(yaml.dump(cfg, default_flow_style=False))\n\n    num_chunks = cfg['dataset']['num_chunks']\n    allow_less = cfg['dataset'].get('allow_less_chunks', False)\n    train_ratio = cfg['dataset']['train_ratio']\n    num_train = int(num_chunks * 
train_ratio)\n    num_test = num_chunks - num_train\n    sort_type = cfg['dataset'].get('sort_type', 'mtime')\n    if sort_type == 'mtime':\n        sort_key_fn = os.path.getmtime\n    elif sort_type == 'number':\n        sort_key_fn = game_number_for_name\n    elif sort_type == 'name':\n        sort_key_fn = identity_function\n    else:\n        raise ValueError('Unknown dataset sort_type: {}'.format(sort_type))\n    if 'input_test' in cfg['dataset']:\n        train_chunks = get_latest_chunks(cfg['dataset']['input_train'],\n                                         num_train, allow_less, sort_key_fn)\n        test_chunks = get_latest_chunks(cfg['dataset']['input_test'], num_test,\n                                        allow_less, sort_key_fn)\n    else:\n        chunks = get_latest_chunks(cfg['dataset']['input'], num_chunks,\n                                   allow_less, sort_key_fn)\n        if allow_less:\n            num_train = int(len(chunks) * train_ratio)\n            num_test = len(chunks) - num_train\n        train_chunks = chunks[:num_train]\n        test_chunks = chunks[num_train:]\n\n    shuffle_size = cfg['training']['shuffle_size']\n    total_batch_size = cfg['training']['batch_size']\n    batch_splits = cfg['training'].get('num_batch_splits', 1)\n    train_workers = cfg['dataset'].get('train_workers', None)\n    test_workers = cfg['dataset'].get('test_workers', None)\n    if total_batch_size % batch_splits != 0:\n        raise ValueError('num_batch_splits must divide batch_size evenly')\n    split_batch_size = total_batch_size // batch_splits\n\n    diff_focus_min = cfg['training'].get('diff_focus_min', 1)\n    diff_focus_slope = cfg['training'].get('diff_focus_slope', 0)\n    diff_focus_q_weight = cfg['training'].get('diff_focus_q_weight', 6.0)\n    diff_focus_pol_scale = cfg['training'].get('diff_focus_pol_scale', 3.5)\n\n    root_dir = os.path.join(cfg['training']['path'], cfg['name'])\n    if not os.path.exists(root_dir):\n        
os.makedirs(root_dir)\n\n    train_parser = ChunkParser(train_chunks,\n                               get_input_mode(cfg),\n                               shuffle_size=shuffle_size,\n                               sample=SKIP,\n                               batch_size=split_batch_size,\n                               diff_focus_min=diff_focus_min,\n                               diff_focus_slope=diff_focus_slope,\n                               diff_focus_q_weight=diff_focus_q_weight,\n                               diff_focus_pol_scale=diff_focus_pol_scale,\n                               workers=train_workers)\n    test_shuffle_size = int(shuffle_size * (1.0 - train_ratio))\n    # no diff focus for test_parser\n    test_parser = ChunkParser(test_chunks,\n                              get_input_mode(cfg),\n                              shuffle_size=test_shuffle_size,\n                              sample=SKIP,\n                              batch_size=split_batch_size,\n                              workers=test_workers)\n    if 'input_validation' in cfg['dataset']:\n        valid_chunks = get_all_chunks(cfg['dataset']['input_validation'])\n        validation_parser = ChunkParser(valid_chunks,\n                                        get_input_mode(cfg),\n                                        sample=1,\n                                        batch_size=split_batch_size,\n                                        workers=0)\n\n    import tensorflow as tf\n    from chunkparsefunc import parse_function\n    from tfprocess import TFProcess\n    tfprocess = TFProcess(cfg)\n    train_dataset = tf.data.Dataset.from_generator(\n        train_parser.parse,\n        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string))\n    train_dataset = train_dataset.map(parse_function)\n    test_dataset = tf.data.Dataset.from_generator(\n        test_parser.parse,\n        output_types=(tf.string, tf.string, tf.string, tf.string, tf.string))\n    test_dataset = 
test_dataset.map(parse_function)\n\n    validation_dataset = None\n    if 'input_validation' in cfg['dataset']:\n        validation_dataset = tf.data.Dataset.from_generator(\n            validation_parser.sequential,\n            output_types=(tf.string, tf.string, tf.string, tf.string,\n                          tf.string))\n        validation_dataset = validation_dataset.map(parse_function)\n\n    if tfprocess.strategy is None:  #Mirrored strategy appends prefetch itself with a value depending on number of replicas\n        train_dataset = train_dataset.prefetch(4)\n        test_dataset = test_dataset.prefetch(4)\n        if validation_dataset is not None:\n            validation_dataset = validation_dataset.prefetch(4)\n    else:\n        options = tf.data.Options()\n        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF\n        train_dataset = train_dataset.with_options(options)\n        test_dataset = test_dataset.with_options(options)\n        if validation_dataset is not None:\n            validation_dataset = validation_dataset.with_options(options)\n    tfprocess.init(train_dataset, test_dataset, validation_dataset)\n\n    tfprocess.restore()\n\n    # If number of test positions is not given\n    # sweeps through all test chunks statistically\n    # Assumes average of 10 samples per test game.\n    # For simplicity, testing can use the split batch size instead of total batch size.\n    # This does not affect results, because test results are simple averages that are independent of batch size.\n    num_evals = cfg['training'].get('num_test_positions',\n                                    len(test_chunks) * 10)\n    num_evals = max(1, num_evals // split_batch_size)\n    print(\"Using {} evaluation batches\".format(num_evals))\n    tfprocess.total_batch_size = total_batch_size\n    tfprocess.process_loop(total_batch_size,\n                           num_evals,\n                           
batch_splits=batch_splits)\n\n    if cmd.output is not None:\n        if cfg['training'].get('swa_output', False):\n            tfprocess.save_swa_weights(cmd.output)\n        else:\n            tfprocess.save_leelaz_weights(cmd.output)\n\n    train_parser.shutdown()\n    test_parser.shutdown()\n\n\nif __name__ == \"__main__\":\n    # freeze_support() must be the first statement after the __main__ guard\n    # so that frozen Windows executables can spawn worker processes.\n    mp.freeze_support()\n    argparser = argparse.ArgumentParser(description=\\\n    'TensorFlow pipeline for training Leela Chess.')\n    argparser.add_argument('--cfg',\n                           type=argparse.FileType('r'),\n                           help='yaml configuration with training parameters')\n    argparser.add_argument('--output',\n                           type=str,\n                           help='file to store weights in')\n\n    #mp.set_start_method('spawn')\n    main(argparser.parse_args())\n"
  },
  {
    "path": "tf/update_steps.py",
"content": "#!/usr/bin/env python3\nimport argparse\nimport os\nimport yaml\nimport tensorflow as tf\nfrom tfprocess import TFProcess\n\n\ndef main(cmd):\n    cfg = yaml.safe_load(cmd.cfg.read())\n    print(yaml.dump(cfg, default_flow_style=False))\n\n    root_dir = os.path.join(cfg['training']['path'], cfg['name'])\n    if not os.path.exists(root_dir):\n        os.makedirs(root_dir)\n\n    tfprocess = TFProcess(cfg)\n    tfprocess.init_net()\n\n    tfprocess.restore()\n\n    start_from = cmd.start\n\n    tfprocess.global_step.assign(start_from)\n    tfprocess.manager.save(checkpoint_number=start_from)\n\n\nif __name__ == \"__main__\":\n    argparser = argparse.ArgumentParser(description=\\\n    'Convert current checkpoint to new step count.')\n    argparser.add_argument('--cfg',\n                           type=argparse.FileType('r'),\n                           help='yaml configuration with training parameters')\n    argparser.add_argument('--start',\n                           type=int,\n                           default=0,\n                           help='Offset to set global_step to.')\n\n    main(argparser.parse_args())\n"
  }
]