Repository: LeelaChessZero/lczero-training
Branch: master
Commit: a60c7d281ebc
Files: 219
Total size: 1.1 MB
Directory structure:
gitextract_k9ephqoc/
├── .clang-format
├── .gitignore
├── .gitmodules
├── AGENTS.md
├── README.md
├── csrc/
│ ├── loader/
│ │ ├── chunk_source/
│ │ │ ├── chunk_source.h
│ │ │ ├── chunk_source_view.h
│ │ │ ├── debug_chunk_source.cc
│ │ │ ├── debug_chunk_source.h
│ │ │ ├── rawfile_chunk_source.cc
│ │ │ ├── rawfile_chunk_source.h
│ │ │ ├── tar_chunk_source.cc
│ │ │ └── tar_chunk_source.h
│ │ ├── data_loader.cc
│ │ ├── data_loader.h
│ │ ├── data_loader_metrics.cc
│ │ ├── data_loader_metrics.h
│ │ ├── data_loader_test.cc
│ │ ├── frame_type.h
│ │ ├── loader_main.cpp
│ │ ├── pybind_module.cc
│ │ └── stages/
│ │ ├── chunk_rescorer.cc
│ │ ├── chunk_rescorer.h
│ │ ├── chunk_rescorer_test.cc
│ │ ├── chunk_source_loader.cc
│ │ ├── chunk_source_loader.h
│ │ ├── chunk_source_loader_test.cc
│ │ ├── chunk_source_splitter.cc
│ │ ├── chunk_source_splitter.h
│ │ ├── chunk_source_splitter_test.cc
│ │ ├── chunk_unpacker.cc
│ │ ├── chunk_unpacker.h
│ │ ├── chunk_unpacker_test.cc
│ │ ├── file_path_provider.cc
│ │ ├── file_path_provider.h
│ │ ├── file_path_provider_main.cc
│ │ ├── file_path_provider_test.cc
│ │ ├── join_stage.cc
│ │ ├── join_stage.h
│ │ ├── join_stage_test.cc
│ │ ├── position_sampling.cc
│ │ ├── position_sampling.h
│ │ ├── shuffling_chunk_pool.cc
│ │ ├── shuffling_chunk_pool.h
│ │ ├── shuffling_chunk_pool_test.cc
│ │ ├── shuffling_frame_sampler.cc
│ │ ├── shuffling_frame_sampler.h
│ │ ├── shuffling_frame_sampler_test.cc
│ │ ├── simple_chunk_extractor.cc
│ │ ├── simple_chunk_extractor.h
│ │ ├── simple_chunk_extractor_test.cc
│ │ ├── stage.cc
│ │ ├── stage.h
│ │ ├── stage_factory.cc
│ │ ├── stage_factory.h
│ │ ├── stage_factory_test.cc
│ │ ├── tensor_generator.cc
│ │ ├── tensor_generator.h
│ │ ├── tensor_generator_test.cc
│ │ └── training_chunk.h
│ ├── tools/
│ │ ├── dump_chunk_main.cc
│ │ ├── filter_chunks_main.cc
│ │ ├── position_weight_stats_main.cc
│ │ ├── rescore_chunk_main.cc
│ │ ├── result_distribution_main.cc
│ │ └── startpos_policy_distribution_main.cc
│ └── utils/
│ ├── gz.cc
│ ├── gz.h
│ ├── metrics/
│ │ ├── exponential_aggregator.h
│ │ ├── group.h
│ │ ├── load_metric.h
│ │ ├── load_metric_test.cc
│ │ ├── printer.h
│ │ ├── statistics_metric.h
│ │ └── stats_test.cc
│ ├── queue.h
│ ├── queue_test.cc
│ ├── stream_shuffler.cc
│ ├── stream_shuffler.h
│ ├── stream_shuffler_test.cc
│ ├── tensor.h
│ ├── tensor_test.cc
│ ├── thread_pool.h
│ ├── training_data_printer.cc
│ └── training_data_printer.h
├── docs/
│ ├── README.md
│ ├── architecture.md
│ ├── checkpoint_migration.md
│ ├── example.textproto
│ ├── heads.md
│ ├── index.md
│ ├── loader.md
│ ├── new_stage.md
│ ├── overview.md
│ ├── shuffling_pool_hanse_sampling.md
│ ├── training_tuple.md
│ ├── tui.md
│ └── weights_tool.md
├── init.sh
├── justfile
├── meson.build
├── native.ini
├── proto/
│ ├── checkpoint_migration_config.proto
│ ├── data_loader_config.proto
│ ├── export_config.proto
│ ├── metrics_config.proto
│ ├── model_config.proto
│ ├── root_config.proto
│ ├── stage_control.proto
│ ├── training_config.proto
│ └── training_metrics.proto
├── pyproject.toml
├── scripts/
│ ├── diff.py
│ ├── fixorder.py
│ ├── init.sh
│ ├── initsplit.py
│ ├── inittrainingname.py
│ ├── pack.py
│ ├── purge.py
│ ├── rescore.sh
│ ├── shuffle.py
│ ├── split.sh
│ ├── stage.sh
│ ├── unpack.py
│ └── upload.sh
├── src/
│ ├── lczero_training/
│ │ ├── __init__.py
│ │ ├── _lczero_training.pyi
│ │ ├── commands/
│ │ │ ├── __init__.py
│ │ │ ├── backfill_metrics.py
│ │ │ ├── common.py
│ │ │ ├── daemon.py
│ │ │ ├── dataloader_viz.py
│ │ │ ├── describe_training.py
│ │ │ ├── jax2leela.py
│ │ │ ├── leela2jax.py
│ │ │ ├── migrate_checkpoint.py
│ │ │ ├── overfit.py
│ │ │ ├── test_dataloader.py
│ │ │ ├── train.py
│ │ │ ├── training_eval.py
│ │ │ ├── training_init.py
│ │ │ ├── tui.py
│ │ │ ├── tune_lr.py
│ │ │ └── weights_tool.py
│ │ ├── convert/
│ │ │ ├── __init__.py
│ │ │ ├── jax_to_leela.py
│ │ │ ├── leela_pytree_visitor.py
│ │ │ ├── leela_to_jax.py
│ │ │ └── leela_to_modelconfig.py
│ │ ├── daemon/
│ │ │ ├── __init__.py
│ │ │ ├── daemon.py
│ │ │ ├── metrics.py
│ │ │ ├── metrics_base.py
│ │ │ ├── pipeline.py
│ │ │ ├── protocol/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── communicator.py
│ │ │ │ ├── messages.py
│ │ │ │ └── registry.py
│ │ │ └── rms_metrics.py
│ │ ├── dataloader/
│ │ │ └── __init__.py
│ │ ├── model/
│ │ │ ├── __init__.py
│ │ │ ├── embedding.py
│ │ │ ├── encoder.py
│ │ │ ├── loss_function.py
│ │ │ ├── model.py
│ │ │ ├── movesleft_head.py
│ │ │ ├── policy_head.py
│ │ │ ├── shared.py
│ │ │ ├── utils.py
│ │ │ └── value_head.py
│ │ ├── py.typed
│ │ ├── tests/
│ │ │ ├── test_protobuf.py
│ │ │ ├── test_protocol_registry.py
│ │ │ └── test_weights_tool.py
│ │ ├── tools/
│ │ │ ├── __init__.py
│ │ │ ├── weight_codecs.py
│ │ │ ├── weight_wrappers.py
│ │ │ └── weights_tool.py
│ │ ├── training/
│ │ │ ├── __init__.py
│ │ │ ├── backfill_metrics.py
│ │ │ ├── dataloader_probe.py
│ │ │ ├── describe.py
│ │ │ ├── eval.py
│ │ │ ├── init.py
│ │ │ ├── lr_schedule.py
│ │ │ ├── migrate_checkpoint.py
│ │ │ ├── optimizer.py
│ │ │ ├── overfit.py
│ │ │ ├── state.py
│ │ │ ├── tensorboard.py
│ │ │ ├── test_lr_schedule.py
│ │ │ ├── training.py
│ │ │ ├── tune_lr.py
│ │ │ └── utils.py
│ │ └── tui/
│ │ ├── __init__.py
│ │ ├── app.py
│ │ ├── app.tcss
│ │ ├── data_pipeline_pane.py
│ │ ├── dataloader_widgets.py
│ │ ├── log_pane.py
│ │ └── training_widgets.py
│ └── proto/
│ └── __init__.py
└── tf/
├── attention_policy_map.py
├── chunkparsefunc.py
├── chunkparser.py
├── configs/
│ └── example.yaml
├── decode_training.py
├── lc0_az_policy_map.py
├── make_model.py
├── model_to_net.py
├── net.py
├── net_to_model.py
├── policy_index.py
├── requirements.txt
├── shufflebuffer.py
├── start.sh
├── tfprocess.py
├── train.py
└── update_steps.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .clang-format
================================================
BasedOnStyle: Google
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*_pb2.py
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# protobuf stuff
tf/proto/
**/*_pb2.py
**/*_pb2.pyi
# Meson
builddir/
subprojects/
# LLM Agents
# Use AGENTS.md; symlink other files to it if needed.
.claude/
CLAUDE.local.md
.claude.json
CLAUDE.md
GEMINI.md
# IDE
.vscode/
.idea/
================================================
FILE: .gitmodules
================================================
[submodule "libs/lc0"]
path = libs/lc0
url = https://github.com/LeelaChessZero/lc0.git
================================================
FILE: AGENTS.md
================================================
# AGENTS.md
This repository contains the training scripts for the Leela Chess Zero project.
They are being rewritten.
* Old code is located in the `tf/` directory.
* New python code is located in the `src/` directory.
* New C++ code is located in the `csrc/` directory.
The old code is Python/TensorFlow-based, new code is Python/JAX-based with
modules written in C++, operating through pybind11.
The build system for C++ code is meson. During development, the project is built
in the `builddir/`.
## Testing and Building
* C++ tests use GTest framework
* Do not insert sleeps in tests; it slows down presubmit. Instead use e.g.
absl::Notification or std::future.
* Tests are defined in `meson.build` with `test()` function
* When debugging, don't forget to build them before running `meson test` or
`builddir/test`
* Run tests: `meson test -C builddir/`
* Python tests use `pytest` framework
* Do not add a custom main function, exception catching to report errors, any
"test passed" messages, etc. Use `pytest` fixtures and assertions.
* Build: `meson compile -C builddir/` from build directory
* Format code: `just format`
* There is a commit hook that runs `just pre-commit`, which runs tests and
checks formatting. You may want to run it before attempting to commit.
* We use Google C++ style guide.
* That means 80 columns.
* That means comments should be in full sentences with periods in the end.
* When conditional or loop fits one line, it must be written as one line
without braces, for example:
`if (condition) return value;`
* Prefer `absl` to `std` (e.g. `absl::c_` algorithms, `absl::Mutex`,
`absl::StrCat`, etc.); see the sketch at the end of this section.
* We use `uv` for Python package and venv management, and to run the
application.
* Run TUI app: `uv run tui --config=<path_to_config>`
* Do not attempt to run TUI — it messes up the Agent interface and session has
to be killed. Ask me to check it for you manually instead.
* Do not commit unless explicitly asked.
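For example, a minimal sketch of the preferred style under the rules above (the function and identifiers are made up for illustration and do not come from the codebase):
```cpp
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"

// Illustrative only: absl helpers plus a brace-free one-line conditional.
std::string DescribeFirstMatch(const std::vector<std::string>& names,
                               const std::string& target) {
  auto it = absl::c_find(names, target);
  if (it == names.end()) return "not found";
  return absl::StrCat("found: ", *it);
}
```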
## IMPORTANT
* NEVER add `# type: ignore` or other ways to mask/silence errors instead of
fixing them.
* Rely on protobuf default values. DO NOT write code like
`config.has_foo() ? config.foo() : default_value;`
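A minimal sketch of what this means in practice (the `foo` field and `kDefaultFoo` constant are hypothetical, used only for illustration):
```cpp
// Bad: redundant guard that duplicates the default already declared in the
// .proto file.
// int foo = config.has_foo() ? config.foo() : kDefaultFoo;

// Good: rely on the default value declared in the .proto file.
int foo = config.foo();
```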
## Documentation
* Documentation is in the `docs/` directory.
* The table of contents is in [the index](docs/index.md)
================================================
FILE: README.md
================================================
# Training
The training pipeline resides in `tf`. It requires TensorFlow running on Linux (Ubuntu 16.04 in this case). (It can be made to work on Windows too, but it takes more effort.)
## Installation
Install the requirements under `tf/requirements.txt`, then call `./init.sh` to compile the protobuf files.
## Data preparation
To start a training session you first need to download training data from https://storage.lczero.org/files/training_data/. Several chunks/games are packed into a tar file, and each tar file contains an hour's worth of chunks. Preparing the data requires the following steps:
```
wget https://storage.lczero.org/files/training_data/training-run1--20200711-2017.tar
tar -xzf training-run1--20200711-2017.tar
```
## Training pipeline
Now that the data is in the right format, one can configure a training pipeline. This configuration is achieved through a YAML file; see `tf/configs/example.yaml`:
```yaml
%YAML 1.2
---
name: 'kb1-64x6'                        # ideally no spaces
gpu: 0                                  # gpu id to process on
dataset:
  num_chunks: 100000                    # newest nof chunks to parse
  train_ratio: 0.90                     # trainingset ratio
  # For separated test and train data.
  input_train: '/path/to/chunks/*/draw/' # supports glob
  input_test: '/path/to/chunks/*/draw/'  # supports glob
  # For a one-shot run with all data in one directory.
  # input: '/path/to/chunks/*/draw/'
training:
  batch_size: 2048                      # training batch
  total_steps: 140000                   # terminate after these steps
  test_steps: 2000                      # eval test set values after this many steps
  # checkpoint_steps: 10000             # optional frequency for checkpointing before finish
  shuffle_size: 524288                  # size of the shuffle buffer
  lr_values:                            # list of learning rates
    - 0.02
    - 0.002
    - 0.0005
  lr_boundaries:                        # list of boundaries
    - 100000
    - 130000
  policy_loss_weight: 1.0               # weight of policy loss
  value_loss_weight: 1.0                # weight of value loss
  path: '/path/to/store/networks'       # network storage dir
model:
  filters: 64
  residual_blocks: 6
...
```
The configuration is pretty self-explanatory; if you're new to training, I suggest looking at the [machine learning glossary](https://developers.google.com/machine-learning/glossary/) by Google. Now you can invoke training with the following command:
```bash
./train.py --cfg configs/example.yaml --output /tmp/mymodel.txt
```
This will initialize the pipeline and start training a new neural network. You can view progress by invoking tensorboard:
```bash
tensorboard --logdir leelalogs
```
If you now point your browser at localhost:6006 you'll see the training progress as the training steps pass by. Have fun!
## Restoring models
The training pipeline will automatically restore from a previous model if it exists in your `training:path` as configured by your yaml config. For initializing from a raw `weights.txt` file you can use `tf/net_to_model.py`; this will create a checkpoint for you.
## Supervised training
Generating training data from PGN files is currently broken and has low priority; feel free to create a PR.
## Building the 2025-08 version
1. Make sure `uv` and `justfile` are installed (plus `meson` and other stuff, potentially `protoc`).
2. `git submodule update`
3. `uv venv` (important: do this before running meson; otherwise meson will build the module for the wrong Python)
4. `uv sync`
5. `CXX=clang++ CC=clang uv run meson setup build/release/ --buildtype=release --native-file=native.ini` (clang is optional, should build fine with default compiler)
6. `just build-proto`
7. `meson compile -C build/release/`
8. `ln -s -T ../../build/release/_lczero_training.cpython-311-x86_64-linux-gnu.so src/lczero_training/_lczero_training.so`
9. Run it! `uv run tui --config docs/example.textproto`
================================================
FILE: csrc/loader/chunk_source/chunk_source.h
================================================
#pragma once
#include <cstddef>
#include <optional>
#include <string>
#include <vector>
#include "loader/frame_type.h"
namespace lczero {
namespace training {
// Interface for providing training data chunks.
// A chunk source provides access to one or more chunks of training data.
// It's assumed that all chunks in a source form one group for sorting
// purposes, therefore GetChunkSortKey() returns just one key for the entire
// source. This makes it possible to know the key before reading/indexing the
// chunks.
class ChunkSource {
public:
virtual ~ChunkSource() = default;
// Returns a sort key (e.g. filename or a timestamp).
virtual std::string GetChunkSortKey() const = 0;
// Returns the number of chunks in this source.
virtual size_t GetChunkCount() const = 0;
// Returns the data for the chunk at the given index. Returns std::nullopt if
// the chunk could not be read or if the data size is not a multiple of the
// expected frame size.
virtual std::optional<std::vector<FrameType>> GetChunkData(size_t index) = 0;
};
} // namespace training
} // namespace lczero
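// Editor's note: the following is an illustrative sketch only, not part of the
// original header. It shows how a ChunkSource is typically consumed; the
// ProcessFrames() handler is hypothetical.
//
//   void DrainSource(ChunkSource& source) {
//     for (size_t i = 0; i < source.GetChunkCount(); ++i) {
//       std::optional<std::vector<FrameType>> frames = source.GetChunkData(i);
//       if (!frames) continue;  // Unreadable or malformed chunk.
//       ProcessFrames(*frames);
//     }
//   }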
================================================
FILE: csrc/loader/chunk_source/chunk_source_view.h
================================================
#pragma once
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "loader/chunk_source/chunk_source.h"
namespace lczero {
namespace training {
// ChunkSourceView provides a view over another ChunkSource.
// It exposes a remapped subset/order of chunks defined by indices into the
// underlying source. It does not own or copy the data; it forwards calls to
// the wrapped source.
class ChunkSourceView : public ChunkSource {
public:
// Constructs a view over an existing chunk source. The indices vector maps
// local indices in the view to indices of the underlying source.
ChunkSourceView(std::shared_ptr<ChunkSource> source,
std::vector<uint32_t> indices)
: source_(std::move(source)), indices_(std::move(indices)) {}
~ChunkSourceView() override = default;
private:
std::string GetChunkSortKey() const override {
return source_->GetChunkSortKey();
}
size_t GetChunkCount() const override { return indices_.size(); }
std::optional<std::vector<FrameType>> GetChunkData(size_t index) override {
if (index >= indices_.size()) return std::nullopt;
const size_t src_index = static_cast<size_t>(indices_[index]);
return source_->GetChunkData(src_index);
}
std::shared_ptr<ChunkSource> source_;
std::vector<uint32_t> indices_;
};
} // namespace training
} // namespace lczero
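// Editor's note: illustrative sketch only, not part of the original header.
// Assuming `base` is an existing std::shared_ptr<ChunkSource>, this view
// exposes the first two chunks of `base` in reverse order:
//
//   std::shared_ptr<ChunkSource> view = std::make_shared<ChunkSourceView>(
//       base, std::vector<uint32_t>{1, 0});
//   // view->GetChunkCount() == 2; the view's chunk 0 maps to base's chunk 1.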
================================================
FILE: csrc/loader/chunk_source/debug_chunk_source.cc
================================================
#include "loader/chunk_source/debug_chunk_source.h"
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstring>
#include <random>
#include <utility>
#include "absl/hash/hash.h"
#include "absl/strings/str_format.h"
namespace lczero {
namespace training {
DebugChunkSource::DebugChunkSource(uint64_t id, double mean_chunk_count)
: id_(id), mean_chunk_count_(mean_chunk_count) {}
std::string DebugChunkSource::GetChunkSortKey() const {
return absl::StrFormat("%08" PRIu64, id_);
}
size_t DebugChunkSource::GetChunkCount() const {
if (!cached_chunk_count_.has_value()) {
std::mt19937_64 rng(id_);
const double stddev = std::max(1.0, mean_chunk_count_ / 4.0);
std::normal_distribution<double> distribution(mean_chunk_count_, stddev);
const double sampled = distribution(rng);
const auto rounded =
static_cast<long long>(std::llround(std::max(sampled, 1.0)));
cached_chunk_count_ = static_cast<size_t>(rounded);
}
return *cached_chunk_count_;
}
std::optional<std::vector<FrameType>> DebugChunkSource::GetChunkData(
size_t index) {
const auto seed_pair = std::make_pair(id_, index);
const uint64_t seed = static_cast<uint64_t>(
absl::Hash<std::pair<uint64_t, size_t>>{}(seed_pair));
std::mt19937_64 rng(seed);
std::uniform_int_distribution<int> frame_count_distribution(1, 200);
const int frame_count = frame_count_distribution(rng);
std::vector<FrameType> result(frame_count);
for (int frame_index = 0; frame_index < frame_count; ++frame_index) {
result[frame_index] = FrameType{};
result[frame_index].planes[0] = static_cast<uint64_t>(id_);
result[frame_index].planes[1] = static_cast<uint64_t>(index);
result[frame_index].planes[2] = static_cast<uint64_t>(frame_index);
}
return result;
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/chunk_source/debug_chunk_source.h
================================================
#pragma once
#include <cstddef>
#include <cstdint>
#include <optional>
#include <random>
#include <string>
#include "loader/chunk_source/chunk_source.h"
namespace lczero {
namespace training {
// DebugChunkSource synthesizes deterministic pseudo-random chunks for loader
// debugging. Each instance is identified by an integer id. The class produces
// a chunk count sampled from a normal distribution with the provided mean and
// mean / 4 standard deviation. The id serves as the seed, which keeps the
// number of chunks stable across runs. Individual chunks contain a
// pseudo-random number of FrameType frames (between one and 200) that are
// generated on demand. The generation seed depends on both the source id and
// chunk index. This lets shuffling logic exercise variable chunk sizes while
// keeping the content reproducible. Each generated frame is zero-initialized,
// but the first three entries of the planes array encode, respectively, the
// source id, the chunk index, and the frame index within the chunk. This makes
// it easy to reason about ordering and grouping when inspecting chunk payloads.
class DebugChunkSource : public ChunkSource {
public:
DebugChunkSource(uint64_t id, double mean_chunk_count);
private:
std::string GetChunkSortKey() const override;
size_t GetChunkCount() const override;
std::optional<std::vector<FrameType>> GetChunkData(size_t index) override;
uint64_t id_;
double mean_chunk_count_;
mutable std::optional<size_t> cached_chunk_count_;
};
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/chunk_source/rawfile_chunk_source.cc
================================================
#include "loader/chunk_source/rawfile_chunk_source.h"
#include <absl/log/log.h>
#include <fstream>
#include <stdexcept>
#include "trainingdata/trainingdata_v6.h"
#include "utils/files.h"
#include "utils/gz.h"
namespace lczero {
namespace training {
RawFileChunkSource::RawFileChunkSource(
const std::filesystem::path& filename,
ChunkSourceLoaderConfig::FrameFormat frame_format)
: filename_(filename), frame_format_(frame_format) {}
RawFileChunkSource::~RawFileChunkSource() = default;
std::string RawFileChunkSource::GetChunkSortKey() const {
return std::filesystem::path(filename_).filename().string();
}
size_t RawFileChunkSource::GetChunkCount() const { return 1; }
std::optional<std::vector<FrameType>> RawFileChunkSource::GetChunkData(
size_t index) {
if (index != 0) return std::nullopt;
std::string data = ReadFileToString(filename_);
if (data.empty()) return std::nullopt;
const size_t input_size =
frame_format_ == ChunkSourceLoaderConfig::V7TrainingData
? sizeof(V7TrainingData)
: sizeof(V6TrainingData);
if (data.size() % input_size != 0) {
LOG(WARNING) << "File " << filename_ << " size " << data.size()
<< " is not a multiple of input frame size " << input_size;
return std::nullopt;
}
const size_t num_frames = data.size() / input_size;
std::vector<V7TrainingData> result(num_frames);
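// Frames are value-initialized (zeroed) V7 structs. For V6 input only the
// first sizeof(V6TrainingData) bytes of each frame are copied from the V6
// record, so the trailing V7-only bytes remain zero.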
if (frame_format_ == ChunkSourceLoaderConfig::V7TrainingData) {
std::memcpy(result.data(), data.data(), data.size());
} else {
const auto* v6_data = reinterpret_cast<const V6TrainingData*>(data.data());
for (size_t i = 0; i < num_frames; ++i) {
std::memcpy(&result[i], &v6_data[i], sizeof(V6TrainingData));
}
}
return result;
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/chunk_source/rawfile_chunk_source.h
================================================
#pragma once
#include <filesystem>
#include <string>
#include "loader/chunk_source/chunk_source.h"
#include "proto/data_loader_config.pb.h"
namespace lczero {
namespace training {
// A chunk source that reads a single (potentially gzipped) file as a single
// chunk.
class RawFileChunkSource : public ChunkSource {
public:
RawFileChunkSource(const std::filesystem::path& filename,
ChunkSourceLoaderConfig::FrameFormat frame_format);
~RawFileChunkSource();
private:
std::string GetChunkSortKey() const override;
size_t GetChunkCount() const override;
std::optional<std::vector<FrameType>> GetChunkData(size_t index) override;
std::string filename_;
ChunkSourceLoaderConfig::FrameFormat frame_format_;
};
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/chunk_source/tar_chunk_source.cc
================================================
#include "loader/chunk_source/tar_chunk_source.h"
#include <absl/log/log.h>
#include <absl/strings/str_cat.h>
#include <zlib.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <fcntl.h>
#include <stdexcept>
#include <unistd.h>
#include "trainingdata/trainingdata_v6.h"
#include "utils/gz.h"
namespace lczero {
namespace training {
namespace {
struct TarHeader {
std::array<char, 100> name;
std::array<uint8_t, 8> mode;
std::array<uint8_t, 8> uid;
std::array<uint8_t, 8> gid;
std::array<uint8_t, 12> size;
std::array<uint8_t, 12> mtime;
std::array<uint8_t, 8> chksum;
uint8_t typeflag;
std::array<uint8_t, 100> linkname;
std::array<uint8_t, 6> magic;
std::array<uint8_t, 2> version;
std::array<uint8_t, 32> uname;
std::array<uint8_t, 32> gname;
std::array<uint8_t, 8> devmajor;
std::array<uint8_t, 8> devminor;
std::array<uint8_t, 155> prefix;
std::array<uint8_t, 12> padding;
};
static_assert(sizeof(TarHeader) == 512, "TarHeader must be exactly 512 bytes");
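// Parses a NUL-terminated octal field from a tar header (here the 12-byte
// size field). For example, "00000001750\0" yields 0o1750 == 1000 bytes.
// Parsing stops at the first NUL byte and does not skip leading spaces.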
uint64_t ParseOctal(const std::array<uint8_t, 12>& octal) {
uint64_t value = 0;
for (uint8_t digit : octal) {
if (!digit) break;
value = (value << 3) + (digit - '0');
}
return value;
}
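// Reads exactly `size` bytes at `offset` via pread(), looping over short
// reads. Returns false if EOF or an I/O error occurs before `size` bytes
// have been read.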
bool ReadExact(int fd, off_t offset, void* buffer, size_t size) {
char* out = static_cast<char*>(buffer);
size_t read_total = 0;
while (read_total < size) {
const ssize_t read_now =
pread(fd, out + read_total, size - read_total, offset + read_total);
if (read_now <= 0) return false;
read_total += static_cast<size_t>(read_now);
}
return true;
}
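// Streams a gzip member stored at [offset, offset + size) through zlib and
// returns up to `max_bytes` of decompressed output, or std::nullopt on a
// read or inflate error. This allows peeking at a compressed chunk without
// decompressing it entirely.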
std::optional<std::string> ReadGzipPrefix(int fd, off_t offset, size_t size,
size_t max_bytes) {
if (max_bytes == 0) return std::string();
z_stream strm = {};
if (inflateInit2(&strm, 16 + MAX_WBITS) != Z_OK) {
return std::nullopt;
}
constexpr size_t kChunkSize = 16384;
std::array<uint8_t, kChunkSize> input_buffer;
std::array<char, kChunkSize> output_buffer;
std::string output;
output.reserve(std::min<size_t>(max_bytes, kChunkSize));
size_t remaining = size;
off_t current_offset = offset;
bool finished = false;
while (remaining > 0 && !finished && output.size() < max_bytes) {
const size_t to_read = std::min(remaining, kChunkSize);
if (!ReadExact(fd, current_offset, input_buffer.data(), to_read)) {
inflateEnd(&strm);
return std::nullopt;
}
remaining -= to_read;
current_offset += static_cast<off_t>(to_read);
strm.next_in = reinterpret_cast<Bytef*>(input_buffer.data());
strm.avail_in = static_cast<uInt>(to_read);
while (strm.avail_in > 0 && output.size() < max_bytes) {
strm.next_out = reinterpret_cast<Bytef*>(output_buffer.data());
strm.avail_out = kChunkSize;
const int ret = inflate(&strm, Z_NO_FLUSH);
if (ret == Z_STREAM_ERROR || ret == Z_NEED_DICT || ret == Z_DATA_ERROR ||
ret == Z_MEM_ERROR) {
inflateEnd(&strm);
return std::nullopt;
}
const size_t produced = kChunkSize - strm.avail_out;
const size_t to_copy = std::min(produced, max_bytes - output.size());
output.append(output_buffer.data(), to_copy);
if (ret == Z_STREAM_END) {
finished = true;
break;
}
}
}
inflateEnd(&strm);
return output;
}
} // namespace
TarChunkSource::TarChunkSource(
const std::filesystem::path& filename,
ChunkSourceLoaderConfig::FrameFormat frame_format)
: path_(filename),
filename_(filename.filename().string()),
frame_format_(frame_format) {
fd_ = open(path_.c_str(), O_RDONLY | O_CLOEXEC);
if (fd_ < 0) {
throw std::runtime_error(
absl::StrCat("Failed to open tar file: ", path_.string(), ": ",
std::strerror(errno)));
}
// Perform indexing during construction.
Index();
}
TarChunkSource::~TarChunkSource() { Close(); }
void TarChunkSource::Close() {
if (fd_ >= 0 && close(fd_) != 0) {
PLOG(WARNING) << "Failed to close tar file descriptor for " << path_;
}
fd_ = -1;
}
std::string TarChunkSource::GetChunkSortKey() const { return filename_; }
void TarChunkSource::Index() {
assert(files_.empty());
off_t offset = 0;
while (true) {
TarHeader header;
if (!ReadExact(fd_, offset, &header, sizeof(header))) {
LOG(WARNING) << "Truncated tar file: " << filename_;
break;
}
offset += sizeof(header);
if (header.name[0] == '\0') break; // End of file.
switch (header.typeflag) {
case '5': // Directory
continue;
case '0': // Regular file
break;
default:
LOG(WARNING) << "Unsupported tar header type: " << header.typeflag;
continue;
}
std::string_view fname(const_cast<const char*>(header.name.data()));
const std::filesystem::path filepath = std::filesystem::path(fname);
const off_t file_offset = offset;
const size_t size = ParseOctal(header.size);
const size_t padded_size = ((size + 511) / 512) * 512;
offset = file_offset + static_cast<off_t>(padded_size);
if (filepath.filename() == "LICENSE") continue;
files_.push_back({file_offset, size, filepath.extension() == ".gz"});
}
LOG(INFO) << "Read " << files_.size() << " entries from " << filename_;
}
size_t TarChunkSource::GetChunkCount() const { return files_.size(); }
std::optional<std::vector<FrameType>> TarChunkSource::GetChunkData(
size_t index) {
if (index >= files_.size()) {
throw std::out_of_range("File index out of range");
}
const auto& file_entry = files_[index];
std::string content(file_entry.size, '\0');
if (!ReadExact(fd_, file_entry.offset, content.data(), file_entry.size)) {
return std::nullopt;
}
if (file_entry.is_gzip) {
try {
content = GunzipBuffer(content);
} catch (const GunzipError& e) {
return std::nullopt;
}
}
if (content.empty()) return std::nullopt;
const size_t input_size =
frame_format_ == ChunkSourceLoaderConfig::V7TrainingData
? sizeof(V7TrainingData)
: sizeof(V6TrainingData);
if (content.size() % input_size != 0) {
LOG(WARNING) << "Chunk " << index << " from " << filename_ << " size "
<< content.size() << " is not a multiple of input frame size "
<< input_size;
return std::nullopt;
}
const size_t num_frames = content.size() / input_size;
std::vector<V7TrainingData> result(num_frames);
if (frame_format_ == ChunkSourceLoaderConfig::V7TrainingData) {
std::memcpy(result.data(), content.data(), content.size());
} else {
const auto* v6_data =
reinterpret_cast<const V6TrainingData*>(content.data());
for (size_t i = 0; i < num_frames; ++i) {
std::memcpy(&result[i], &v6_data[i], sizeof(V6TrainingData));
}
}
return result;
}
std::optional<std::string> TarChunkSource::GetChunkPrefix(size_t index,
size_t max_bytes) {
if (index >= files_.size()) {
throw std::out_of_range("File index out of range");
}
const auto& file_entry = files_[index];
if (file_entry.is_gzip) {
return ReadGzipPrefix(fd_, file_entry.offset, file_entry.size, max_bytes);
}
const size_t to_read = std::min(file_entry.size, max_bytes);
std::string content(to_read, '\0');
if (!ReadExact(fd_, file_entry.offset, content.data(), to_read)) {
return std::nullopt;
}
return content;
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/chunk_source/tar_chunk_source.h
================================================
#pragma once
#include <sys/types.h>
#include <filesystem>
#include <string>
#include <vector>
#include "loader/chunk_source/chunk_source.h"
#include "proto/data_loader_config.pb.h"
namespace lczero {
namespace training {
// A chunk source that reads a tar archive and provides access to its files as
// chunks. Each file in the tar is treated as a separate chunk.
class TarChunkSource : public ChunkSource {
public:
TarChunkSource(const std::filesystem::path& filename,
ChunkSourceLoaderConfig::FrameFormat frame_format);
~TarChunkSource() override;
std::string GetChunkSortKey() const override;
size_t GetChunkCount() const override;
std::optional<std::vector<FrameType>> GetChunkData(size_t index) override;
std::optional<std::string> GetChunkPrefix(size_t index, size_t max_bytes);
private:
// Performs one-time indexing during construction. Not part of the interface.
void Index();
struct FileEntry {
off_t offset;
size_t size;
bool is_gzip;
};
void Close();
int fd_ = -1;
std::vector<FileEntry> files_;
std::filesystem::path path_;
std::string filename_;
ChunkSourceLoaderConfig::FrameFormat frame_format_;
};
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/data_loader.cc
================================================
#include "loader/data_loader.h"
#include <absl/algorithm/container.h>
#include <absl/log/log.h>
#include <absl/strings/str_cat.h>
#include <absl/strings/str_join.h>
#include <chrono>
#include <cmath>
#include <optional>
#include <stdexcept>
#include "loader/data_loader_metrics.h"
namespace lczero {
namespace training {
DataLoaderConfig DataLoader::ParseConfig(const std::string& serialized_config) {
DataLoaderConfig config;
config.ParseFromString(serialized_config);
return config;
}
DataLoader::DataLoader(const std::string& serialized_data_loader_config)
: metrics_aggregator_(
[](DataLoaderMetricsProto& m) { m.Clear(); },
[](DataLoaderMetricsProto& dest, const DataLoaderMetricsProto& src) {
UpdateFrom(dest, src);
}) {
DataLoaderConfig config = ParseConfig(serialized_data_loader_config);
AddStages(config);
BuildOutputMapping(config);
LOG(INFO) << "DataLoader initialized with " << stage_registry_.size()
<< " stage(s) and " << outputs_.size() << " output(s).";
}
DataLoader::~DataLoader() { Stop(); }
void DataLoader::AddStages(const std::string& serialized_data_loader_config) {
AddStages(ParseConfig(serialized_data_loader_config));
}
void DataLoader::AddStages(const DataLoaderConfig& config) {
for (const auto& stage_config : config.stage()) AddStage(stage_config);
for (const auto& stage_config : config.stage()) SetStageInputs(stage_config);
}
void DataLoader::AddStage(const StageConfig& stage_config) {
if (started_) {
throw std::runtime_error("Cannot add stages after DataLoader has started.");
}
auto stage = CreateStage(stage_config);
if (!stage_config.has_name()) {
throw std::runtime_error("Stage configuration is missing name.");
}
LOG(INFO) << "Adding stage '" << stage_config.name() << "'.";
stage_registry_.AddStage(stage_config.name(), std::move(stage));
}
void DataLoader::SetStageInputs(const StageConfig& stage_config) {
// Resolve input names to queue pointers.
std::vector<QueueBase*> input_queues;
input_queues.reserve(stage_config.input_size());
for (const auto& input_name : stage_config.input()) {
QueueBase* queue = stage_registry_.GetStageOutput(input_name);
if (!queue) {
throw std::runtime_error(absl::StrCat("Input stage '", input_name,
"' not found for stage '",
stage_config.name(), "'."));
}
input_queues.push_back(queue);
}
// Wire up inputs.
auto it = absl::c_find_if(stage_registry_.stages(), [&](const auto& p) {
return p.first == stage_config.name();
});
if (it == stage_registry_.stages().end()) {
throw std::runtime_error(absl::StrCat("Stage '", stage_config.name(),
"' not found in registry."));
}
it->second->SetInputs(absl::MakeSpan(input_queues));
}
void DataLoader::Start() {
if (started_) {
throw std::runtime_error("DataLoader has already been started.");
}
for (auto& [name, stage] : stage_registry_.stages()) {
LOG(INFO) << "Starting stage '" << name << "'.";
stage->Start();
}
metrics_thread_ = std::jthread(
[this](std::stop_token stop_token) { MetricsThread(stop_token); });
started_ = true;
stopped_ = false;
LOG(INFO) << "DataLoader started.";
}
TensorTuple DataLoader::GetNext(std::string_view alias) {
Queue<TensorTuple>* q = GetOutputQueue(alias);
if (!q) {
std::string alias_list = absl::StrJoin(
outputs_, ", ",
[](std::string* out, const auto& p) { absl::StrAppend(out, p.first); });
throw std::runtime_error(absl::StrCat("Unknown DataLoader output: '", alias,
"'. Available outputs: [", alias_list,
"]."));
}
return q->Get();
}
std::optional<TensorTuple> DataLoader::MaybeGetNext(std::string_view alias) {
Queue<TensorTuple>* q = GetOutputQueue(alias);
if (!q) {
std::string alias_list = absl::StrJoin(
outputs_, ", ",
[](std::string* out, const auto& p) { absl::StrAppend(out, p.first); });
throw std::runtime_error(absl::StrCat("Unknown DataLoader output: '", alias,
"'. Available outputs: [", alias_list,
"]."));
}
return q->MaybeGet();
}
void DataLoader::Stop() {
if (stopped_) return;
LOG(INFO) << "Stopping DataLoader.";
if (metrics_thread_.joinable()) {
metrics_thread_.request_stop();
metrics_thread_.join();
}
for (auto& [name, stage] : stage_registry_.stages()) {
LOG(INFO) << "Stopping stage '" << name << "'.";
stage->Stop();
}
stopped_ = true;
started_ = false;
LOG(INFO) << "DataLoader stopped.";
}
std::pair<std::string, float> DataLoader::GetBucketMetrics(
int time_period, bool include_pending) const {
auto [metrics, duration] = metrics_aggregator_.GetBucketMetrics(
static_cast<TimePeriod>(time_period),
include_pending ? std::make_optional(std::chrono::steady_clock::now())
: std::nullopt);
float duration_seconds = std::chrono::duration<float>(duration).count();
return {metrics.OutputAsString(), duration_seconds};
}
std::pair<std::string, float> DataLoader::GetAggregateEndingNow(
float duration_seconds, bool include_pending) const {
std::chrono::nanoseconds duration_ns =
std::isinf(duration_seconds)
? std::chrono::nanoseconds::max()
: std::chrono::nanoseconds(
static_cast<int64_t>(duration_seconds * 1e9));
auto [metrics, duration] = metrics_aggregator_.GetAggregateEndingNow(
duration_ns, include_pending
? std::make_optional(std::chrono::steady_clock::now())
: std::nullopt);
float result_duration_seconds =
std::chrono::duration<float>(duration).count();
return {metrics.OutputAsString(), result_duration_seconds};
}
void DataLoader::MetricsThread(std::stop_token stop_token) {
while (!stop_token.stop_requested()) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
DataLoaderMetricsProto metrics;
for (auto& [name, stage] : stage_registry_.stages()) {
StageMetricProto stage_metric = stage->FlushMetrics();
stage_metric.set_name(name);
*metrics.add_stage_metrics() = std::move(stage_metric);
}
metrics_aggregator_.RecordMetrics(std::move(metrics));
metrics_aggregator_.Advance(std::chrono::steady_clock::now());
}
LOG(INFO) << "Metrics thread stopping.";
}
std::vector<std::pair<std::string, StageControlResponse>>
DataLoader::SendControlMessage(const StageControlRequest& request) {
std::vector<std::pair<std::string, StageControlResponse>> responses;
responses.reserve(stage_registry_.size());
for (auto& [name, stage] : stage_registry_.stages()) {
std::optional<StageControlResponse> response = stage->Control(request);
if (response.has_value()) {
responses.emplace_back(name, std::move(*response));
}
}
return responses;
}
void DataLoader::BuildOutputMapping(const DataLoaderConfig& config) {
outputs_.clear();
outputs_.reserve(config.output_size());
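// Each output spec is either "stage_name" (registered under the empty alias)
// or "alias:stage_name".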
for (const auto& out_spec : config.output()) {
auto it = absl::c_find(out_spec, ':');
std::string alias = it == out_spec.end()
? std::string("")
: std::string(out_spec.begin(), it);
std::string stage_name = it == out_spec.end()
? std::string(out_spec)
: std::string(it + 1, out_spec.end());
// Ensure alias is unique.
if (absl::c_find_if(outputs_, [&](const auto& p) {
return p.first == alias;
}) != outputs_.end()) {
throw std::runtime_error(
absl::StrCat("Duplicate output alias specified: ", alias));
}
Queue<TensorTuple>* queue =
stage_registry_.GetTypedStageOutput<TensorTuple>(stage_name);
if (queue == nullptr) {
throw std::runtime_error(
absl::StrCat("Output stage not found or wrong type: ", stage_name));
}
outputs_.emplace_back(std::move(alias), queue);
}
}
Queue<TensorTuple>* DataLoader::GetOutputQueue(std::string_view alias) const {
auto it = absl::c_find_if(outputs_,
[&](const auto& p) { return p.first == alias; });
if (it == outputs_.end()) return nullptr;
return it->second;
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/data_loader.h
================================================
#pragma once
#include <memory>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>
#include "loader/stages/stage_factory.h"
#include "proto/data_loader_config.pb.h"
#include "proto/stage_control.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/exponential_aggregator.h"
#include "utils/queue.h"
#include "utils/tensor.h"
namespace lczero {
namespace training {
using TensorDict =
std::vector<std::pair<std::string, std::unique_ptr<TensorBase>>>;
class DataLoader {
public:
using MetricsAggregator = ExponentialAggregator<DataLoaderMetricsProto,
TimePeriod::k250Milliseconds>;
explicit DataLoader(const std::string& serialized_data_loader_config);
~DataLoader();
void Start();
TensorTuple GetNext(std::string_view alias);
std::optional<TensorTuple> MaybeGetNext(std::string_view alias);
void Stop();
std::pair<std::string, float> GetBucketMetrics(int time_period,
bool include_pending) const;
std::pair<std::string, float> GetAggregateEndingNow(
float duration_seconds, bool include_pending) const;
void AddStages(const DataLoaderConfig& config);
void AddStages(const std::string& serialized_data_loader_config);
std::vector<std::pair<std::string, StageControlResponse>> SendControlMessage(
const StageControlRequest& request);
private:
void AddStage(const StageConfig& stage_config);
void SetStageInputs(const StageConfig& stage_config);
static DataLoaderConfig ParseConfig(
const std::string& serialized_data_loader_config);
void MetricsThread(std::stop_token stop_token);
void BuildOutputMapping(const DataLoaderConfig& config);
Queue<TensorTuple>* GetOutputQueue(std::string_view alias) const;
StageRegistry stage_registry_;
std::vector<std::pair<std::string, Queue<TensorTuple>*>> outputs_;
MetricsAggregator metrics_aggregator_;
std::jthread metrics_thread_;
bool started_ = false;
bool stopped_ = false;
};
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/data_loader_metrics.cc
================================================
// ABOUTME: Implementation of UpdateFrom functions for data loader metric
// protobuf messages. ABOUTME: Handles aggregation of generic stage metrics.
#include "loader/data_loader_metrics.h"
#include <algorithm>
#include "absl/strings/string_view.h"
#include "utils/metrics/statistics_metric.h"
namespace lczero {
namespace training {
namespace {
template <typename ProtoT>
ProtoT* FindByName(std::vector<ProtoT>* entries, absl::string_view name) {
for (auto& entry : *entries) {
if (entry.name() == name) return &entry;
}
return nullptr;
}
} // namespace
void UpdateFrom(QueueMetricProto& dest, const QueueMetricProto& src) {
if (src.has_name()) dest.set_name(src.name());
dest.set_put_count(dest.put_count() + src.put_count());
dest.set_get_count(dest.get_count() + src.get_count());
dest.set_drop_count(dest.drop_count() + src.drop_count());
UpdateFrom(*dest.mutable_queue_fullness(), src.queue_fullness());
if (src.has_queue_capacity()) dest.set_queue_capacity(src.queue_capacity());
}
void UpdateFrom(CountMetricProto& dest, const CountMetricProto& src) {
if (src.has_name()) dest.set_name(src.name());
if (src.has_count()) dest.set_count(dest.count() + src.count());
}
void UpdateFrom(GaugeMetricProto& dest, const GaugeMetricProto& src) {
if (src.has_name()) dest.set_name(src.name());
if (src.has_value()) dest.set_value(src.value());
if (src.has_capacity()) dest.set_capacity(src.capacity());
}
void UpdateFrom(StageMetricProto& dest, const StageMetricProto& src) {
if (src.has_name()) dest.set_name(src.name());
for (const auto& load_metrics : src.load_metrics()) {
LoadMetricProto* dest_load =
load_metrics.has_name()
? FindByName(dest.mutable_load_metrics(), load_metrics.name())
: nullptr;
if (dest_load == nullptr) {
dest_load = dest.add_load_metrics();
}
UpdateFrom(*dest_load, load_metrics);
}
for (const auto& queue_metrics : src.queue_metrics()) {
QueueMetricProto* dest_queue =
queue_metrics.has_name()
? FindByName(dest.mutable_queue_metrics(), queue_metrics.name())
: nullptr;
if (dest_queue == nullptr) {
dest_queue = dest.add_queue_metrics();
}
UpdateFrom(*dest_queue, queue_metrics);
}
for (const auto& count_metrics : src.count_metrics()) {
CountMetricProto* dest_count =
count_metrics.has_name()
? FindByName(dest.mutable_count_metrics(), count_metrics.name())
: nullptr;
if (dest_count == nullptr) {
dest_count = dest.add_count_metrics();
}
UpdateFrom(*dest_count, count_metrics);
}
for (const auto& gauge_metrics : src.gauge_metrics()) {
GaugeMetricProto* dest_gauge =
gauge_metrics.has_name()
? FindByName(dest.mutable_gauge_metrics(), gauge_metrics.name())
: nullptr;
if (dest_gauge == nullptr) {
dest_gauge = dest.add_gauge_metrics();
}
UpdateFrom(*dest_gauge, gauge_metrics);
}
for (const auto& statistics_metrics : src.statistics_metrics()) {
StatisticsProtoDouble* dest_stats =
statistics_metrics.has_name()
? FindByName(dest.mutable_statistics_metrics(),
statistics_metrics.name())
: nullptr;
if (dest_stats == nullptr) {
dest_stats = dest.add_statistics_metrics();
}
UpdateFrom(*dest_stats, statistics_metrics);
}
if (src.has_last_chunk_key()) dest.set_last_chunk_key(src.last_chunk_key());
if (src.has_anchor()) dest.set_anchor(src.anchor());
}
void UpdateFrom(DataLoaderMetricsProto& dest,
const DataLoaderMetricsProto& src) {
for (const auto& stage_metrics : src.stage_metrics()) {
StageMetricProto* dest_stage =
stage_metrics.has_name()
? FindByName(dest.mutable_stage_metrics(), stage_metrics.name())
: nullptr;
if (dest_stage == nullptr) {
dest_stage = dest.add_stage_metrics();
}
UpdateFrom(*dest_stage, stage_metrics);
}
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/data_loader_metrics.h
================================================
// ABOUTME: Header for UpdateFrom functions for data loader metric protobuf
// messages. ABOUTME: Declares functions for aggregating FilePathProvider and
// DataLoader metrics.
#pragma once
#include "absl/strings/string_view.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/load_metric.h"
#include "utils/metrics/statistics_metric.h"
#include "utils/queue.h"
namespace lczero {
namespace training {
void UpdateFrom(QueueMetricProto& dest, const QueueMetricProto& src);
void UpdateFrom(CountMetricProto& dest, const CountMetricProto& src);
void UpdateFrom(GaugeMetricProto& dest, const GaugeMetricProto& src);
void UpdateFrom(StageMetricProto& dest, const StageMetricProto& src);
void UpdateFrom(DataLoaderMetricsProto& dest,
const DataLoaderMetricsProto& src);
template <typename T>
QueueMetricProto MetricsFromQueue(absl::string_view name, Queue<T>& queue) {
QueueMetricProto result;
result.set_name(std::string(name));
result.set_put_count(queue.GetTotalPutCount(true));
result.set_get_count(queue.GetTotalGetCount(true));
result.set_drop_count(queue.GetTotalDropCount(true));
AddSample(*result.mutable_queue_fullness(), queue.Size());
result.set_queue_capacity(queue.Capacity());
return result;
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/data_loader_test.cc
================================================
#include "loader/data_loader.h"
#include <gtest/gtest.h>
#include <stdexcept>
namespace lczero {
namespace training {
TEST(DataLoaderTest, AllowsNoOutputsConfigured) {
DataLoaderConfig config;
auto* file_stage = config.add_stage();
file_stage->set_name("file_path_provider");
file_stage->mutable_file_path_provider()->set_directory(".");
EXPECT_NO_THROW(DataLoader(config.OutputAsString()));
}
TEST(DataLoaderTest, ThrowsOnDuplicateStageName) {
DataLoaderConfig config;
auto* first_stage = config.add_stage();
first_stage->set_name("duplicate");
first_stage->mutable_file_path_provider()->set_directory(".");
auto* second_stage = config.add_stage();
second_stage->set_name("duplicate");
second_stage->mutable_file_path_provider()->set_directory(".");
EXPECT_THROW(DataLoader(config.OutputAsString()), std::runtime_error);
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/frame_type.h
================================================
/*
This file is part of Leela Chess Zero.
Copyright (C) 2025 The LCZero Authors
Leela Chess is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Leela Chess is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.
Additional permission under GNU GPL version 3 section 7
If you modify this Program, or any covered work, by linking or
combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
modified version of those libraries), containing parts covered by the
terms of the respective license agreement, the licensors of this
Program grant you additional permission to convey the resulting work.
*/
#pragma once
#include "trainingdata/trainingdata_v7.h"
namespace lczero {
namespace training {
using FrameType = V7TrainingData;
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/loader_main.cpp
================================================
#include <absl/flags/flag.h>
#include <absl/flags/parse.h>
#include <absl/log/globals.h>
#include <absl/log/initialize.h>
#include <absl/log/log.h>
#include <absl/strings/str_cat.h>
#include <absl/strings/str_format.h>
#include <absl/time/clock.h>
#include <absl/time/time.h>
#include <chrono>
#include <iostream>
#include "data_loader.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/printer.h"
ABSL_FLAG(std::string, directory, "/home/crem/tmp/2025-07/lczero-training/",
"Directory to watch for training data files");
ABSL_FLAG(size_t, chunk_pool_size, 1000000, "Size of the chunk shuffle buffer");
ABSL_FLAG(size_t, reservoir_size_per_thread, 1000000,
"Size of the reservoir for frame sampling per thread");
namespace lczero {
namespace training {
void Run() {
DataLoaderConfig config;
// Configure file path provider stage.
auto* file_stage = config.add_stage();
file_stage->set_name("file_path_provider");
auto* file_path_provider = file_stage->mutable_file_path_provider();
file_path_provider->set_directory(absl::GetFlag(FLAGS_directory));
// Configure chunk source loader stage.
auto* chunk_loader_stage = config.add_stage();
chunk_loader_stage->set_name("chunk_source_loader");
chunk_loader_stage->add_input(file_stage->name());
chunk_loader_stage->mutable_chunk_source_loader();
// Configure shuffling chunk pool stage.
auto* chunk_pool_stage = config.add_stage();
chunk_pool_stage->set_name("shuffling_chunk_pool");
chunk_pool_stage->add_input(chunk_loader_stage->name());
auto* shuffling_chunk_pool = chunk_pool_stage->mutable_shuffling_chunk_pool();
shuffling_chunk_pool->set_chunk_pool_size(
absl::GetFlag(FLAGS_chunk_pool_size));
// Configure chunk unpacker stage.
auto* unpacker_stage = config.add_stage();
unpacker_stage->set_name("chunk_unpacker");
unpacker_stage->add_input(chunk_pool_stage->name());
unpacker_stage->mutable_chunk_unpacker();
// Configure shuffling frame sampler stage.
auto* sampler_stage = config.add_stage();
sampler_stage->set_name("shuffling_frame_sampler");
sampler_stage->add_input(unpacker_stage->name());
auto* shuffling_frame_sampler =
sampler_stage->mutable_shuffling_frame_sampler();
shuffling_frame_sampler->set_reservoir_size_per_thread(
absl::GetFlag(FLAGS_reservoir_size_per_thread));
// Configure tensor generator stage.
auto* tensor_stage = config.add_stage();
tensor_stage->set_name("tensor_generator");
tensor_stage->add_input(sampler_stage->name());
tensor_stage->mutable_tensor_generator();
// Serialize config and create loader
config.add_output("tensor_generator");
std::string config_string = config.OutputAsString();
DataLoader loader(config_string);
return;
std::atomic<size_t> batch_count = 0;
auto start_time = absl::Now();
// Start logging thread
std::atomic<bool> should_stop{false};
std::thread logging_thread([&]() {
while (!should_stop.load()) {
std::this_thread::sleep_for(std::chrono::seconds(1));
auto current_time = absl::Now();
auto total_elapsed = current_time - start_time;
double rate = batch_count / absl::ToDoubleSeconds(total_elapsed);
auto [stats_string, duration] =
loader.GetBucketMetrics(0, false); // k1Second = 0
DataLoaderMetricsProto metrics;
metrics.ParseFromString(stats_string);
std::string metrics_json = metrics.OutputAsJson();
LOG(INFO) << absl::StrCat("Processed ", batch_count.load(),
" batches in ",
absl::ToDoubleSeconds(total_elapsed),
"s. Rate: ", absl::StrFormat("%.2f", rate),
" batches/sec. ", "Metrics: ", metrics_json);
}
});
while (true) {
TensorTuple batch = loader.GetNext("train");
++batch_count;
}
}
} // namespace training
} // namespace lczero
int main(int argc, char* argv[]) {
absl::ParseCommandLine(argc, argv);
absl::InitializeLog();
absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
lczero::training::Run();
return 0;
}
================================================
FILE: csrc/loader/pybind_module.cc
================================================
// ABOUTME: PyBind11 binding module exposing C++ DataLoader to Python.
// ABOUTME: Handles configuration conversion and tensor memory management for
// numpy arrays.
#include <absl/log/globals.h>
#include <absl/log/initialize.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl/filesystem.h>
#include <pybind11/stl_bind.h>
#include <stdexcept>
#include <string>
#include "loader/data_loader.h"
#include "loader/stages/chunk_source_loader.h"
#include "loader/stages/chunk_unpacker.h"
#include "loader/stages/file_path_provider.h"
#include "loader/stages/shuffling_chunk_pool.h"
#include "loader/stages/shuffling_frame_sampler.h"
#include "loader/stages/tensor_generator.h"
#include "utils/tensor.h"
namespace py = pybind11;
namespace lczero {
namespace training {
namespace {
std::string SerializePyProto(const py::handle& obj, const char* expected_type) {
if (py::isinstance<py::bytes>(obj)) {
return obj.cast<std::string>();
}
if (py::hasattr(obj, "SerializeToString")) {
py::object bytes_obj = obj.attr("SerializeToString")();
return bytes_obj.cast<py::bytes>().cast<std::string>();
}
throw std::invalid_argument(std::string("Expected ") + expected_type +
" protobuf message or bytes.");
}
template <typename ProtoT>
ProtoT ParsePyProto(const py::handle& obj, const char* expected_type) {
ProtoT proto;
proto.ParseFromString(SerializePyProto(obj, expected_type));
return proto;
}
py::object MakePythonProto(const char* module_name, const char* message_name,
const std::string& serialized) {
py::object message_cls = py::module::import(module_name).attr(message_name);
py::object message_obj = message_cls();
message_obj.attr("ParseFromString")(py::bytes(serialized));
return message_obj;
}
} // namespace
// Helper function to convert TensorBase to numpy array using buffer protocol.
py::array tensor_to_numpy(std::unique_ptr<TensorBase> tensor) {
// Extract raw pointer and release ownership from unique_ptr.
TensorBase* raw_tensor = tensor.release();
// Create numpy array with take_ownership policy.
// This transfers memory ownership to Python/numpy.
return py::array(
py::dtype(raw_tensor->py_format()), raw_tensor->shape(),
raw_tensor->strides(), raw_tensor->data(),
py::cast(raw_tensor, py::return_value_policy::take_ownership));
}
// Convert TensorTuple to tuple of numpy arrays.
py::tuple tensor_tuple_to_numpy_tuple(TensorTuple tensor_tuple) {
py::tuple result(tensor_tuple.size());
for (size_t i = 0; i < tensor_tuple.size(); ++i) {
result[i] = tensor_to_numpy(std::move(tensor_tuple[i]));
}
return result;
}
PYBIND11_MODULE(_lczero_training, m) {
absl::InitializeLog();
absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
m.doc() = "Leela Chess Zero training data loader";
// Configuration is now handled via protobuf serialized strings
// Expose the main DataLoader class.
py::class_<DataLoader>(m, "DataLoader")
.def(py::init([](py::object config) {
std::string config_string =
SerializePyProto(config, "DataLoaderConfig");
py::gil_scoped_release release;
return new DataLoader(config_string);
}),
py::arg("config"),
"Create DataLoader from DataLoaderConfig proto or bytes.")
.def(
"add_stages",
[](DataLoader& self, py::object config) {
std::string config_string =
SerializePyProto(config, "DataLoaderConfig");
py::gil_scoped_release release;
self.AddStages(config_string);
},
py::arg("config"),
"Append stages from DataLoaderConfig proto or bytes.")
.def(
"send_control_message",
[](DataLoader& self, py::object request) {
StageControlRequest control_request =
ParsePyProto<StageControlRequest>(request,
"StageControlRequest");
auto responses = [&]() {
py::gil_scoped_release release;
return self.SendControlMessage(control_request);
}();
py::list result;
for (const auto& [stage_name, response] : responses) {
py::object response_obj = MakePythonProto(
"proto.stage_control_pb2", "StageControlResponse",
response.OutputAsString());
result.append(py::make_tuple(stage_name, response_obj));
}
return result;
},
py::arg("request"),
"Send StageControlRequest to stages and return (stage_name, "
"StageControlResponse) tuples.")
.def(
"get_next",
[](DataLoader& self, const std::string& alias) {
return tensor_tuple_to_numpy_tuple([&] {
py::gil_scoped_release release;
return self.GetNext(alias);
}());
},
py::arg("alias") = "",
"Get next batch for the given output alias (default empty) as a "
"tuple of numpy arrays")
.def(
"maybe_get_next",
[](DataLoader& self,
const std::string& alias) -> std::optional<py::tuple> {
auto result = [&] {
py::gil_scoped_release release;
return self.MaybeGetNext(alias);
}();
if (result.has_value()) {
return tensor_tuple_to_numpy_tuple(std::move(*result));
}
return std::nullopt;
},
py::arg("alias") = "",
"Non-blocking get next batch for the given output alias (default "
"empty). Returns tuple of numpy arrays or None if no data available")
.def(
"get_bucket_metrics",
[](const DataLoader& self, int time_period, bool include_pending) {
auto [metrics, duration] = [&] {
py::gil_scoped_release release;
return self.GetBucketMetrics(time_period, include_pending);
}();
return py::make_tuple(py::bytes(metrics), duration);
},
"Get serialized metrics for bucket and duration as (bytes, float)")
.def(
"get_aggregate_ending_now",
[](const DataLoader& self, float duration_seconds,
bool include_pending) {
auto [metrics, duration] = [&] {
py::gil_scoped_release release;
return self.GetAggregateEndingNow(duration_seconds,
include_pending);
}();
return py::make_tuple(py::bytes(metrics), duration);
},
"Get serialized metrics for aggregate duration and actual duration "
"as (bytes, float)")
.def("start", &DataLoader::Start, "Start the data loader processing")
.def("stop", &DataLoader::Stop, "Stop the data loader");
// Expose TensorBase for potential advanced usage.
py::class_<TensorBase>(m, "TensorBase")
.def("shape", &TensorBase::shape, py::return_value_policy::reference)
.def("strides", &TensorBase::strides, py::return_value_policy::reference)
.def("element_size", &TensorBase::element_size)
.def("py_format", &TensorBase::py_format);
}
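// Illustrative Python-side usage (a sketch only; the exact import paths of the
// compiled module and of the generated DataLoaderConfig proto depend on how
// the package is installed):
//
//   import _lczero_training as lt
//   from proto import data_loader_config_pb2 as dlc
//
//   config = dlc.DataLoaderConfig()     # fill in stages as needed
//   loader = lt.DataLoader(config)      # also accepts serialized bytes
//   loader.start()
//   batch = loader.get_next()           # tuple of numpy arrays
//   maybe = loader.maybe_get_next()     # tuple or None (non-blocking)
//   loader.stop()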
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_rescorer.cc
================================================
#include "loader/stages/chunk_rescorer.h"
#include <stdexcept>
#include <utility>
#include "absl/base/call_once.h"
#include "absl/log/log.h"
#include "chess/board.h"
#include "loader/data_loader_metrics.h"
namespace lczero {
namespace training {
namespace {
void V6ToV7(std::span<FrameType> data, float theta = 5.0f / 6.0f) {
if (data.empty()) return;
const float beta = 1.0f - theta;
float st_q = 0.0f;
float st_d = 0.0f;
// Iterate backwards to calculate EMA and lookaheads in one go.
for (size_t i = data.size(); i-- > 0;) {
FrameType& item = data[i];
// For Q, we operate in an alternating sign domain.
const float sign = (i % 2 != 0) ? -1.0f : 1.0f;
const float cur_q = item.root_q * sign;
if (i == data.size() - 1) {
st_q = cur_q;
st_d = item.root_d;
} else {
st_q = theta * st_q + beta * cur_q;
st_d = theta * st_d + beta * item.root_d;
}
item.version = 7;
item.q_st = st_q * sign;
item.d_st = std::max(st_d, 0.0f);
// Handle lookahead indices safely.
auto get_idx = [&](size_t offset) {
return (i + offset < data.size()) ? data[i + offset].played_idx : 65535;
};
item.opp_played_idx = get_idx(1);
item.next_played_idx = get_idx(2);
}
}
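// Illustrative walkthrough of the recurrence above (theta = 0.5 for
// readability; the production default is 5/6): for three frames with
// root_q = {q0, q1, q2}, iterating backwards in the alternating-sign domain
// gives
//   i=2: st_q = q2                          -> q_st[2] = q2
//   i=1: st_q = 0.5*q2 - 0.5*q1             -> q_st[1] = 0.5*q1 - 0.5*q2
//   i=0: st_q = 0.25*q2 - 0.25*q1 + 0.5*q0  -> q_st[0] = 0.25*q2 - 0.25*q1 + 0.5*q0
// so each position blends its own root_q with the sign-flipped outcomes of
// later positions; root_d is smoothed the same way without the sign flip and
// is clamped at zero.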
} // namespace
ChunkRescorer::ChunkRescorer(const ChunkRescorerConfig& config)
: SingleInputStage<ChunkRescorerConfig, InputType>(config),
SingleOutputStage<OutputType>(config.output()),
syzygy_paths_(config.syzygy_paths()),
dist_temp_(config.dist_temp()),
dist_offset_(config.dist_offset()),
dtz_boost_(config.dtz_boost()),
new_input_format_(config.new_input_format()),
thread_pool_(config.threads(), ThreadPoolOptions{}),
st_q_theta_(config.st_q_theta()) {
static absl::once_flag bitboards_initialized_flag;
absl::call_once(bitboards_initialized_flag, InitializeMagicBitboards);
if (config.has_deblunder_threshold() && config.has_deblunder_width()) {
RescorerDeblunderSetup(config.deblunder_threshold(),
config.deblunder_width());
}
if (config.has_gaviota_paths()) {
RescorerGaviotaSetup(std::string(config.gaviota_paths()));
}
LOG(INFO) << "Initializing ChunkRescorer with " << config.threads()
<< " worker thread(s)";
thread_contexts_.reserve(config.threads());
for (size_t i = 0; i < config.threads(); ++i) {
thread_contexts_.push_back(std::make_unique<ThreadContext>());
}
}
ChunkRescorer::~ChunkRescorer() { Stop(); }
void ChunkRescorer::InitializeTablebase() {
if (tablebase_initialized_) {
return;
}
LOG(INFO) << "ChunkRescorer initializing Syzygy tablebase with paths '"
<< syzygy_paths_ << "'.";
tablebase_initialized_ = tablebase_.init(syzygy_paths_);
if (tablebase_initialized_) {
LOG(INFO) << "ChunkRescorer Syzygy max cardinality: "
<< tablebase_.max_cardinality();
} else {
LOG(WARNING) << "ChunkRescorer failed to initialize Syzygy tablebase; "
"rescoring will continue without tablebase lookups.";
}
}
void ChunkRescorer::Start() {
LOG(INFO) << "Starting ChunkRescorer worker threads.";
InitializeTablebase();
for (size_t i = 0; i < thread_contexts_.size(); ++i) {
thread_pool_.Enqueue([this, i](std::stop_token stop_token) {
Worker(stop_token, thread_contexts_[i].get());
});
}
}
void ChunkRescorer::Stop() {
if (thread_pool_.stop_token().stop_requested()) return;
LOG(INFO) << "Stopping ChunkRescorer.";
thread_pool_.Shutdown();
output_queue()->Close();
LOG(INFO) << "ChunkRescorer stopped.";
}
void ChunkRescorer::Worker(std::stop_token stop_token, ThreadContext* context) {
auto producer = output_queue()->CreateProducer();
try {
while (true) {
TrainingChunk chunk = [&]() {
LoadMetricPauser pauser(context->load_metric_updater);
return input_queue()->Get(stop_token);
}();
try {
chunk.frames = RescoreTrainingData<FrameType>(
chunk.frames, &tablebase_, dist_temp_, dist_offset_, dtz_boost_,
new_input_format_);
if (chunk.frames.at(0).version == 6) V6ToV7(chunk.frames, st_q_theta_);
LoadMetricPauser pauser(context->load_metric_updater);
producer.Put(std::move(chunk), stop_token);
} catch (const std::exception& exception) {
failed_rescores_.fetch_add(1, std::memory_order_acq_rel);
LOG(ERROR) << "ChunkRescorer failed to rescore chunk: "
<< exception.what() << "; sort_key=" << chunk.sort_key
<< "; index_within_sort_key=" << chunk.index_within_sort_key
<< "; global_index=" << chunk.global_index
<< "; use_count=" << chunk.use_count
<< "; frame_count=" << chunk.frames.size();
continue;
}
}
} catch (const QueueClosedException&) {
LOG(INFO) << "ChunkRescorer worker stopping, queue closed.";
} catch (const QueueRequestCancelled&) {
LOG(INFO) << "ChunkRescorer worker stopping, request cancelled.";
}
}
StageMetricProto ChunkRescorer::FlushMetrics() {
StageMetricProto stage_metric;
LoadMetricProto aggregated_load;
aggregated_load.set_name("load");
for (const auto& context : thread_contexts_) {
UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());
}
*stage_metric.add_load_metrics() = std::move(aggregated_load);
auto* failed = stage_metric.add_count_metrics();
failed->set_name("failed_rescores");
failed->set_count(failed_rescores_.exchange(0, std::memory_order_acq_rel));
*stage_metric.add_queue_metrics() =
MetricsFromQueue("output", *output_queue());
return stage_metric;
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_rescorer.h
================================================
// ABOUTME: Stage that rescores training chunks using Syzygy tablebases.
// ABOUTME: Adjusts frame metadata by invoking the classic LCZero rescorer.
#pragma once
#include <atomic>
#include <functional>
#include <memory>
#include <stop_token>
#include <string>
#include <vector>
#include "libs/lc0/src/syzygy/syzygy.h"
#include "libs/lc0/src/trainingdata/rescorer.h"
#include "loader/frame_type.h"
#include "loader/stages/stage.h"
#include "loader/stages/training_chunk.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/load_metric.h"
#include "utils/queue.h"
#include "utils/thread_pool.h"
namespace lczero {
namespace training {
// Stage that takes TrainingChunk objects, applies tablebase-based rescoring and
// forwards the updated chunks downstream.
class ChunkRescorer
: public SingleInputStage<ChunkRescorerConfig, TrainingChunk>,
public SingleOutputStage<TrainingChunk> {
public:
using InputType = TrainingChunk;
using OutputType = TrainingChunk;
explicit ChunkRescorer(const ChunkRescorerConfig& config);
~ChunkRescorer() override;
void Start() override;
void Stop() override;
StageMetricProto FlushMetrics() override;
private:
struct ThreadContext {
LoadMetricUpdater load_metric_updater;
};
void Worker(std::stop_token stop_token, ThreadContext* context);
void InitializeTablebase();
SyzygyTablebase tablebase_;
bool tablebase_initialized_ = false;
std::string syzygy_paths_;
float dist_temp_;
float dist_offset_;
float dtz_boost_;
int new_input_format_;
ThreadPool thread_pool_;
std::vector<std::unique_ptr<ThreadContext>> thread_contexts_;
std::atomic<uint64_t> failed_rescores_{0};
float st_q_theta_;
};
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_rescorer_test.cc
================================================
#include "loader/stages/chunk_rescorer.h"
#include <string>
#include <utility>
#include <vector>
#include "gtest/gtest.h"
#include "loader/stages/training_chunk.h"
#include "proto/data_loader_config.pb.h"
#include "utils/queue.h"
namespace lczero {
namespace training {
namespace {
template <typename T>
class PassthroughStage : public Stage {
public:
explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
void Start() override {}
void Stop() override {}
StageMetricProto FlushMetrics() override { return StageMetricProto(); }
QueueBase* GetOutput(std::string_view name = "") override {
(void)name;
return queue_;
}
void SetInputs(absl::Span<QueueBase* const> inputs) override {
if (!inputs.empty()) {
throw std::runtime_error("PassthroughStage expects no inputs");
}
}
private:
Queue<T>* queue_;
};
} // namespace
class ChunkRescorerTest : public ::testing::Test {
protected:
void SetUp() override {
input_queue_ = std::make_unique<Queue<TrainingChunk>>(10);
config_.set_threads(1);
config_.mutable_output()->set_queue_capacity(10);
config_.set_syzygy_paths("");
config_.set_dist_temp(0.75f);
config_.set_dist_offset(0.1f);
config_.set_dtz_boost(0.2f);
config_.set_new_input_format(-1);
}
TrainingChunk MakeChunk(std::vector<FrameType> frames,
std::string sort_key = "alpha", size_t index = 3,
uint32_t use = 7) {
TrainingChunk chunk;
chunk.sort_key = std::move(sort_key);
chunk.index_within_sort_key = index;
chunk.use_count = use;
chunk.frames = std::move(frames);
return chunk;
}
std::unique_ptr<Queue<TrainingChunk>> input_queue_;
ChunkRescorerConfig config_;
};
TEST_F(ChunkRescorerTest, HandlesInputQueueClosure) {
ChunkRescorer rescorer(config_);
rescorer.SetInputs({input_queue_.get()});
rescorer.Start();
input_queue_->Close();
EXPECT_THROW(rescorer.output_queue()->Get(), QueueClosedException);
rescorer.Stop();
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_source_loader.cc
================================================
#include "loader/stages/chunk_source_loader.h"
#include <filesystem>
#include <utility>
#include "absl/log/log.h"
#include "loader/chunk_source/rawfile_chunk_source.h"
#include "loader/chunk_source/tar_chunk_source.h"
#include "loader/data_loader_metrics.h"
#include "proto/data_loader_config.pb.h"
namespace lczero {
namespace training {
std::unique_ptr<ChunkSource> CreateChunkSourceFromFile(
const std::filesystem::path& filepath,
ChunkSourceLoaderConfig::FrameFormat frame_format) {
auto extension = filepath.extension();
try {
if (extension == ".gz") {
return std::make_unique<RawFileChunkSource>(filepath, frame_format);
}
if (extension == ".tar") {
return std::make_unique<TarChunkSource>(filepath, frame_format);
}
} catch (const std::exception& e) {
LOG(ERROR) << "Failed to create chunk source for " << filepath << ": "
<< e.what();
return nullptr;
}
return nullptr;
}
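// Note: construction failures (e.g. an unreadable or corrupt archive) are
// logged and mapped to nullptr as well, so callers treat them exactly like
// unsupported extensions and count the file as skipped.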
ChunkSourceLoader::ChunkSourceLoader(const ChunkSourceLoaderConfig& config)
: SingleInputStage<ChunkSourceLoaderConfig, InputType>(config),
SingleOutputStage<OutputType>(config.output()),
thread_pool_(config.threads(), ThreadPoolOptions{}),
frame_format_(config.frame_format()) {
LOG(INFO) << "Initializing ChunkSourceLoader with " << config.threads()
<< " worker threads";
// Initialize thread contexts but don't start worker threads yet.
thread_contexts_.reserve(config.threads());
for (size_t i = 0; i < config.threads(); ++i) {
thread_contexts_.push_back(std::make_unique<ThreadContext>());
}
}
ChunkSourceLoader::~ChunkSourceLoader() { Stop(); }
void ChunkSourceLoader::Start() {
LOG(INFO) << "Starting ChunkSourceLoader worker threads.";
for (size_t i = 0; i < thread_contexts_.size(); ++i) {
thread_pool_.Enqueue([this, i](std::stop_token stop_token) {
Worker(stop_token, thread_contexts_[i].get());
});
}
}
void ChunkSourceLoader::Stop() {
if (thread_pool_.stop_token().stop_requested()) return;
LOG(INFO) << "Stopping ChunkSourceLoader.";
thread_pool_.Shutdown();
output_queue()->Close();
LOG(INFO) << "ChunkSourceLoader stopped.";
}
void ChunkSourceLoader::Worker(std::stop_token stop_token,
ThreadContext* context) {
auto producer = output_queue()->CreateProducer();
LOG(INFO) << "ChunkSourceLoader worker@" << static_cast<const void*>(context)
<< " started.";
try {
while (true) {
auto file = [&]() {
LoadMetricPauser pauser(context->load_metric_updater);
return input_queue()->Get(stop_token);
}();
if (file.message_type ==
FilePathProvider::MessageType::kInitialScanComplete) {
LOG(INFO)
<< "ChunkSourceLoader received initial scan completion marker.";
bool should_forward;
{
absl::MutexLock lock(&phase_mutex_);
sentinel_received_ = true;
should_forward = (pre_sentinel_work_count_ == 0);
}
if (should_forward) {
LOG(INFO) << "ChunkSourceLoader forwarding initial scan completion "
"marker.";
producer.Put({.source = nullptr, .message_type = file.message_type},
stop_token);
}
continue;
}
// Track pre-sentinel work.
bool is_pre_sentinel;
{
absl::MutexLock lock(&phase_mutex_);
is_pre_sentinel = !sentinel_received_;
if (is_pre_sentinel) pre_sentinel_work_count_++;
}
// Create ChunkSource from the file.
LOG_EVERY_N(INFO, 1000)
<< "ChunkSourceLoader preparing chunk source for " << file.filepath;
auto source = CreateChunkSourceFromFile(file.filepath, frame_format_);
if (source) {
{
absl::MutexLock lock(&last_chunk_key_mutex_);
last_chunk_key_ = source->GetChunkSortKey();
}
ChunkSourceWithPhase output{.source = std::move(source),
.message_type = file.message_type};
LoadMetricPauser pauser(context->load_metric_updater);
producer.Put(std::move(output), stop_token);
} else {
LOG_EVERY_N(INFO, 100)
<< "ChunkSourceLoader skipping unsupported file: " << file.filepath;
skipped_files_count_++;
}
// Complete pre-sentinel work tracking.
if (is_pre_sentinel) {
absl::MutexLock lock(&phase_mutex_);
if (--pre_sentinel_work_count_ == 0 && sentinel_received_) {
LOG(INFO) << "ChunkSourceLoader forwarding initial scan completion "
"marker after all pre-sentinel work completed.";
producer.Put(
{.source = nullptr,
.message_type =
FilePathProvider::MessageType::kInitialScanComplete},
stop_token);
}
}
}
} catch (const QueueClosedException&) {
LOG(INFO) << "ChunkSourceLoader worker@"
<< static_cast<const void*>(context)
<< " stopping, queue closed.";
} catch (const QueueRequestCancelled&) {
LOG(INFO) << "ChunkSourceLoader worker@"
<< static_cast<const void*>(context)
<< " stopping, request cancelled.";
} catch (const std::exception& e) {
LOG(ERROR) << "ChunkSourceLoader worker@"
<< static_cast<const void*>(context)
<< " exiting due to exception: " << e.what();
throw;
}
LOG(INFO) << "ChunkSourceLoader worker@" << static_cast<const void*>(context)
<< " exiting loop.";
}
StageMetricProto ChunkSourceLoader::FlushMetrics() {
StageMetricProto stage_metric;
LoadMetricProto aggregated_load;
aggregated_load.set_name("load");
for (const auto& context : thread_contexts_) {
UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());
}
*stage_metric.add_load_metrics() = std::move(aggregated_load);
auto* skipped_metric = stage_metric.add_count_metrics();
skipped_metric->set_name("skipped_files");
skipped_metric->set_count(skipped_files_count_.exchange(0));
// Get the last chunk key.
{
absl::MutexLock lock(&last_chunk_key_mutex_);
if (!last_chunk_key_.empty()) {
stage_metric.set_last_chunk_key(last_chunk_key_);
}
}
*stage_metric.add_queue_metrics() =
MetricsFromQueue("output", *output_queue());
return stage_metric;
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_source_loader.h
================================================
#pragma once
#include <atomic>
#include <filesystem>
#include <memory>
#include <stop_token>
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "loader/chunk_source/chunk_source.h"
#include "loader/stages/file_path_provider.h"
#include "loader/stages/stage.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/load_metric.h"
#include "utils/queue.h"
#include "utils/thread_pool.h"
namespace lczero {
namespace training {
// Creates a ChunkSource based on file extension. Returns RawFileChunkSource for
// .gz files, TarChunkSource for .tar files, or nullptr for unsupported types.
std::unique_ptr<ChunkSource> CreateChunkSourceFromFile(
const std::filesystem::path& filepath,
ChunkSourceLoaderConfig::FrameFormat frame_format);
struct ChunkSourceWithPhase {
std::unique_ptr<ChunkSource> source;
FilePathProvider::MessageType message_type;
};
// Worker pool that converts FilePathProvider output to ChunkSource objects.
// Takes FilePathProvider::File as input and outputs ChunkSourceWithPhase.
class ChunkSourceLoader
: public SingleInputStage<ChunkSourceLoaderConfig, FilePathProvider::File>,
public SingleOutputStage<ChunkSourceWithPhase> {
public:
using InputType = FilePathProvider::File;
using OutputType = ChunkSourceWithPhase;
explicit ChunkSourceLoader(const ChunkSourceLoaderConfig& config);
~ChunkSourceLoader();
void Start() override;
void Stop() override;
StageMetricProto FlushMetrics() override;
private:
struct ThreadContext {
LoadMetricUpdater load_metric_updater;
};
void Worker(std::stop_token stop_token, ThreadContext* context);
ThreadPool thread_pool_;
std::vector<std::unique_ptr<ThreadContext>> thread_contexts_;
std::atomic<uint64_t> skipped_files_count_{0};
absl::Mutex last_chunk_key_mutex_;
std::string last_chunk_key_;
ChunkSourceLoaderConfig::FrameFormat frame_format_;
// Synchronization for sentinel barrier.
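// The kInitialScanComplete sentinel is held back until every file dequeued
// before it has been processed: workers count in-flight pre-sentinel items,
// and the marker is forwarded exactly once, either by the thread that
// receives it (when nothing is in flight) or by the thread that finishes the
// last pre-sentinel item.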
absl::Mutex phase_mutex_;
int pre_sentinel_work_count_ ABSL_GUARDED_BY(phase_mutex_) = 0;
bool sentinel_received_ ABSL_GUARDED_BY(phase_mutex_) = false;
};
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_source_loader_test.cc
================================================
#include "loader/stages/chunk_source_loader.h"
#include <gtest/gtest.h>
#include <filesystem>
#include "loader/stages/file_path_provider.h"
#include "utils/queue.h"
namespace lczero {
namespace training {
namespace {
template <typename T>
class PassthroughStage : public Stage {
public:
explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
void Start() override {}
void Stop() override {}
StageMetricProto FlushMetrics() override { return StageMetricProto(); }
QueueBase* GetOutput(std::string_view name = "") override {
(void)name;
return queue_;
}
void SetInputs(absl::Span<QueueBase* const> inputs) override {
if (!inputs.empty()) {
throw std::runtime_error("PassthroughStage expects no inputs");
}
}
private:
Queue<T>* queue_;
};
} // namespace
TEST(ChunkSourceLoaderTest, ProcessesFiles) {
Queue<FilePathProvider::File> input_queue(10);
ChunkSourceLoaderConfig config;
config.set_threads(1);
config.mutable_output()->set_queue_capacity(10);
ChunkSourceLoader feed(config);
feed.SetInputs({&input_queue});
feed.Start();
{
auto producer = input_queue.CreateProducer();
// Add a file with unsupported extension (should not create ChunkSource)
producer.Put(FilePathProvider::File{
.filepath =
std::filesystem::path("/test.txt"), // unsupported extension
.message_type = FilePathProvider::MessageType::kFile});
} // Producer destroyed here, closing input queue
// Try to get output - there should be no valid ChunkSources for unsupported
// files
try {
while (true) {
auto output = feed.output_queue()->Get();
// If we get output, it means a ChunkSource was created, which shouldn't
// happen for unsupported files
FAIL() << "Expected no output for unsupported file extension";
}
} catch (const QueueClosedException&) {
// Expected: queue should be closed when input is done and no output
// produced
SUCCEED();
}
}
TEST(ChunkSourceLoaderTest, HandlesPhases) {
Queue<FilePathProvider::File> input_queue(10);
ChunkSourceLoaderConfig config;
config.set_threads(1);
config.mutable_output()->set_queue_capacity(10);
ChunkSourceLoader feed(config);
feed.SetInputs({&input_queue});
feed.Start();
{
auto producer = input_queue.CreateProducer();
// Test different phases - all should be passed through even if no
// ChunkSource is created
producer.Put(FilePathProvider::File{
.filepath = std::filesystem::path("/test1.gz"),
.message_type = FilePathProvider::MessageType::kFile});
producer.Put(FilePathProvider::File{
.filepath = std::filesystem::path("/test2.gz"),
.message_type = FilePathProvider::MessageType::kFile});
} // Producer destroyed here, closing input queue
// Queue should eventually close when input is done
try {
while (true) {
feed.output_queue()->Get();
}
} catch (const QueueClosedException&) {
SUCCEED();
}
}
TEST(ChunkSourceLoaderTest, PassesThroughInitialScanComplete) {
Queue<FilePathProvider::File> input_queue(10);
ChunkSourceLoaderConfig config;
config.set_threads(1);
config.mutable_output()->set_queue_capacity(10);
ChunkSourceLoader feed(config);
feed.SetInputs({&input_queue});
feed.Start();
{
auto producer = input_queue.CreateProducer();
producer.Put(FilePathProvider::File{
.filepath = std::filesystem::path(""),
.message_type = FilePathProvider::MessageType::kInitialScanComplete});
} // Producer destroyed here, closing input queue
// Should get kInitialScanComplete in output with null ChunkSource
auto output = feed.output_queue()->Get();
EXPECT_EQ(output.message_type,
FilePathProvider::MessageType::kInitialScanComplete);
EXPECT_EQ(output.source, nullptr);
// Queue should be closed after the single message
try {
feed.output_queue()->Get();
FAIL() << "Expected queue to be closed";
} catch (const QueueClosedException&) {
SUCCEED();
}
}
TEST(ChunkSourceLoaderTest, SentinelBarrierWithMultipleThreads) {
Queue<FilePathProvider::File> input_queue(100);
ChunkSourceLoaderConfig config;
config.set_threads(4);
config.mutable_output()->set_queue_capacity(100);
ChunkSourceLoader feed(config);
feed.SetInputs({&input_queue});
feed.Start();
{
auto producer = input_queue.CreateProducer();
// Add files that will be processed before sentinel.
for (int i = 0; i < 20; ++i) {
producer.Put(FilePathProvider::File{
.filepath =
std::filesystem::path("/test" + std::to_string(i) + ".txt"),
.message_type = FilePathProvider::MessageType::kFile});
}
// Add sentinel.
producer.Put(FilePathProvider::File{
.filepath = std::filesystem::path(""),
.message_type = FilePathProvider::MessageType::kInitialScanComplete});
// Add files that arrive after sentinel.
for (int i = 20; i < 30; ++i) {
producer.Put(FilePathProvider::File{
.filepath =
std::filesystem::path("/test" + std::to_string(i) + ".txt"),
.message_type = FilePathProvider::MessageType::kFile});
}
} // Producer destroyed here, closing input queue
// Read all outputs and verify sentinel comes after all pre-sentinel files.
int files_before_sentinel = 0;
int files_after_sentinel = 0;
bool sentinel_seen = false;
try {
while (true) {
auto output = feed.output_queue()->Get();
if (output.message_type ==
FilePathProvider::MessageType::kInitialScanComplete) {
EXPECT_FALSE(sentinel_seen) << "Sentinel should appear exactly once";
sentinel_seen = true;
} else {
if (sentinel_seen) {
files_after_sentinel++;
} else {
files_before_sentinel++;
}
}
}
} catch (const QueueClosedException&) {
}
// Verify sentinel was seen.
EXPECT_TRUE(sentinel_seen);
// All 20 pre-sentinel files have unsupported extensions, so none of them
// reach the output before the sentinel.
EXPECT_EQ(files_before_sentinel, 0);
// Likewise, none of the 10 post-sentinel files reach the output after it.
EXPECT_EQ(files_after_sentinel, 0);
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_source_splitter.cc
================================================
#include "loader/stages/chunk_source_splitter.h"
#include <algorithm>
#include <numeric>
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "loader/data_loader_metrics.h"
namespace lczero {
namespace training {
ChunkSourceSplitter::ChunkSourceSplitter(
const ChunkSourceSplitterConfig& config)
: SingleInputStage<ChunkSourceSplitterConfig, InputType>(config) {
if (config.output().empty()) {
throw std::runtime_error("ChunkSourceSplitter requires at least 1 output.");
}
// Validate parallel arrays have same size.
if (config.output_size() != config.weight_size()) {
throw std::runtime_error(absl::StrCat(
"ChunkSourceSplitter output and weight arrays must have same size: ",
config.output_size(), " vs ", config.weight_size()));
}
// Create output queues from parallel arrays.
outputs_.reserve(config.output_size());
for (size_t i = 0; i < static_cast<size_t>(config.output_size()); ++i) {
const auto& queue_cfg = config.output(static_cast<int>(i));
const uint64_t weight = i < static_cast<size_t>(config.weight_size())
? config.weight(static_cast<int>(i))
: 1;
if (absl::c_any_of(outputs_, [&](const auto& existing_out) {
return existing_out->name == queue_cfg.name();
})) {
throw std::runtime_error(std::string(absl::StrCat(
"Duplicate output name in ChunkSourceSplitter: ", queue_cfg.name())));
}
auto* out = outputs_
.emplace_back(std::make_unique<Output>(
queue_cfg.name(), weight, queue_cfg.queue_capacity(),
ToOverflowBehavior(queue_cfg.overflow_behavior())))
.get();
LOG(INFO) << "ChunkSourceSplitter configured output '" << out->name
<< "' weight=" << out->weight
<< " capacity=" << queue_cfg.queue_capacity();
}
// Precompute cumulative weights for fast assignment.
cumulative_.resize(outputs_.size());
std::transform_inclusive_scan(
outputs_.begin(), outputs_.end(), cumulative_.begin(),
std::plus<uint64_t>{},
[](const std::unique_ptr<Output>& out) { return out->weight; });
// Validate total weight is positive.
if (cumulative_.back() == 0) {
throw std::runtime_error(
"ChunkSourceSplitter requires at least one output with positive "
"weight.");
}
}
ChunkSourceSplitter::~ChunkSourceSplitter() { Stop(); }
void ChunkSourceSplitter::Start() {
LOG(INFO) << "Starting ChunkSourceSplitter worker.";
thread_pool_.Enqueue(
[this](std::stop_token stop_token) { Worker(stop_token); });
}
void ChunkSourceSplitter::Stop() {
if (thread_pool_.stop_token().stop_requested()) return;
LOG(INFO) << "Stopping ChunkSourceSplitter.";
thread_pool_.Shutdown();
for (auto& out : outputs_) out->queue.Close();
}
QueueBase* ChunkSourceSplitter::GetOutput(std::string_view name) {
auto iter = absl::c_find_if(
outputs_, [&](const auto& out) { return out->name == name; });
if (iter == outputs_.end()) {
throw std::runtime_error(
absl::StrCat("Unknown output '", name, "' for ChunkSourceSplitter."));
}
return &(*iter)->queue;
}
StageMetricProto ChunkSourceSplitter::FlushMetrics() {
StageMetricProto metric;
for (auto& out : outputs_) {
*metric.add_queue_metrics() = MetricsFromQueue(out->name, out->queue);
}
return metric;
}
void ChunkSourceSplitter::Worker(std::stop_token stop_token) {
// Create producers for each output in this thread.
std::vector<Queue<OutputType>::Producer> producers;
producers.reserve(outputs_.size());
for (auto& out : outputs_) {
producers.emplace_back(out->queue.CreateProducer());
}
try {
while (true) {
InputType item = input_queue()->Get(stop_token);
if (item.message_type ==
FilePathProvider::MessageType::kInitialScanComplete) {
// Broadcast to all outputs.
for (auto& prod : producers) {
prod.Put(
OutputType{.source = nullptr, .message_type = item.message_type},
stop_token);
}
continue;
}
// Share ownership of the ChunkSource with any produced views.
std::shared_ptr<ChunkSource> shared_source(std::move(item.source));
auto per_output_indices = BuildAssignments(shared_source);
// Emit only non-empty views, preserving original message type (kFile).
for (size_t i = 0; i < outputs_.size(); ++i) {
if (per_output_indices[i].empty()) continue;
auto view = std::make_unique<ChunkSourceView>(
shared_source, std::move(per_output_indices[i]));
producers[i].Put(
OutputType{.source = std::move(view),
.message_type = FilePathProvider::MessageType::kFile},
stop_token);
}
}
} catch (const QueueClosedException&) {
// Input queue closed — producers will close queues automatically when
// destroyed if this thread holds the last producer.
LOG(INFO) << "ChunkSourceSplitter worker exiting: input closed.";
}
}
std::vector<std::vector<uint32_t>> ChunkSourceSplitter::BuildAssignments(
const std::shared_ptr<ChunkSource>& source) {
const std::string sort_key = source->GetChunkSortKey();
const size_t n = source->GetChunkCount();
// Prepare one index container per output.
std::vector<std::vector<uint32_t>> indices(outputs_.size());
for (size_t i = 0; i < n; ++i) {
const uint64_t h =
static_cast<uint64_t>(absl::Hash<std::pair<std::string, size_t>>{}(
std::make_pair(sort_key, i)));
const uint64_t r = h % cumulative_.back();
// Find the output where cumulative[j-1] <= r < cumulative[j].
const auto it = std::upper_bound(cumulative_.begin(), cumulative_.end(), r);
const size_t idx = it - cumulative_.begin();
indices[idx].push_back(static_cast<uint32_t>(i));
}
return indices;
}
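// Worked example (illustrative): with two outputs weighted {1, 2},
// cumulative_ = {1, 3}. For chunk index i, r = hash(sort_key, i) % 3;
// r == 0 routes the chunk to the first output and r == 1 or 2 to the second,
// so roughly one third / two thirds of a source's chunks land in each view,
// and the split depends only on (sort_key, index).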
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_source_splitter.h
================================================
#pragma once
#include <memory>
#include <stop_token>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/hash/hash.h"
#include "absl/log/log.h"
#include "loader/chunk_source/chunk_source.h"
#include "loader/chunk_source/chunk_source_view.h"
#include "loader/stages/chunk_source_loader.h"
#include "loader/stages/stage.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/queue.h"
#include "utils/thread_pool.h"
namespace lczero {
namespace training {
// Splits an incoming ChunkSource into several ChunkSourceViews based on a
// deterministic hash of (sort_key, index). Emits to multiple named outputs.
class ChunkSourceSplitter
: public SingleInputStage<ChunkSourceSplitterConfig, ChunkSourceWithPhase> {
public:
using InputType = ChunkSourceWithPhase;
using OutputType = ChunkSourceWithPhase;
explicit ChunkSourceSplitter(const ChunkSourceSplitterConfig& config);
~ChunkSourceSplitter();
void Start() override;
void Stop() override;
StageMetricProto FlushMetrics() override;
QueueBase* GetOutput(std::string_view name = "") override;
private:
struct Output {
std::string name;
uint64_t weight;
Queue<OutputType> queue;
Output(std::string_view name, uint64_t weight, size_t capacity,
OverflowBehavior overflow)
: name(name), weight(weight), queue(capacity, overflow) {}
};
void Worker(std::stop_token stop_token);
// Builds per-output indices given a source; uses absl::Hash on
// (sort_key, index) and weights to assign indices.
std::vector<std::vector<uint32_t>> BuildAssignments(
const std::shared_ptr<ChunkSource>& source);
std::vector<std::unique_ptr<Output>> outputs_;
std::vector<uint64_t> cumulative_;
ThreadPool thread_pool_{1};
};
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_source_splitter_test.cc
================================================
#include "loader/stages/chunk_source_splitter.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/hash/hash.h"
#include "gtest/gtest.h"
#include "loader/chunk_source/chunk_source.h"
#include "loader/stages/chunk_source_loader.h"
#include "loader/stages/stage.h"
#include "proto/data_loader_config.pb.h"
#include "utils/queue.h"
namespace lczero {
namespace training {
namespace {
// Simple fixed-count chunk source for testing.
class FixedCountChunkSource : public ChunkSource {
public:
FixedCountChunkSource(std::string sort_key, size_t count)
: key_(std::move(sort_key)), count_(count) {}
private:
std::string GetChunkSortKey() const override { return key_; }
size_t GetChunkCount() const override { return count_; }
std::optional<std::vector<FrameType>> GetChunkData(size_t) override {
return std::vector<FrameType>{FrameType{}};
}
std::string key_;
size_t count_;
};
template <typename T>
class PassthroughStage : public Stage {
public:
explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
void Start() override {}
void Stop() override {}
StageMetricProto FlushMetrics() override { return StageMetricProto(); }
QueueBase* GetOutput(std::string_view) override { return queue_; }
void SetInputs(absl::Span<QueueBase* const> inputs) override {
if (!inputs.empty()) {
throw std::runtime_error("PassthroughStage expects no inputs");
}
}
private:
Queue<T>* queue_;
};
} // namespace
TEST(ChunkSourceSplitterTest, SplitsByHashAndWeight) {
// Upstream queue.
auto input_queue = std::make_unique<Queue<ChunkSourceWithPhase>>(8);
// Configure splitter with two outputs A:1, B:2.
ChunkSourceSplitterConfig cfg;
auto* outA = cfg.add_output();
outA->set_name("A");
outA->set_queue_capacity(8);
cfg.add_weight(1);
auto* outB = cfg.add_output();
outB->set_name("B");
outB->set_queue_capacity(8);
cfg.add_weight(2);
ChunkSourceSplitter splitter(cfg);
splitter.SetInputs({input_queue.get()});
splitter.Start();
// Send a source with known key and count.
const std::string key = "skey";
const size_t count = 100;
ChunkSourceWithPhase item;
item.source = std::make_unique<FixedCountChunkSource>(key, count);
item.message_type = FilePathProvider::MessageType::kFile;
auto producer = input_queue->CreateProducer();
producer.Put(std::move(item));
producer.Close();
// Compute expected assignment counts using the same hash/weights.
const uint64_t total_weight = 3;
uint64_t cumA = 1;  // Cumulative weight boundary for output A.
size_t expectedA = 0;
size_t expectedB = 0;
for (size_t i = 0; i < count; ++i) {
const uint64_t h = static_cast<uint64_t>(
absl::Hash<std::pair<std::string, size_t>>{}(std::make_pair(key, i)));
const uint64_t r = h % total_weight;
if (r < cumA)
++expectedA;
else
++expectedB;
}
// Read outputs and verify view sizes.
auto* qa =
dynamic_cast<Queue<ChunkSourceWithPhase>*>(splitter.GetOutput("A"));
auto* qb =
dynamic_cast<Queue<ChunkSourceWithPhase>*>(splitter.GetOutput("B"));
ASSERT_NE(qa, nullptr);
ASSERT_NE(qb, nullptr);
auto msgA = qa->Get();
auto msgB = qb->Get();
ASSERT_NE(msgA.source, nullptr);
ASSERT_NE(msgB.source, nullptr);
EXPECT_EQ(msgA.source->GetChunkCount(), expectedA);
EXPECT_EQ(msgB.source->GetChunkCount(), expectedB);
splitter.Stop();
}
TEST(ChunkSourceSplitterTest, BroadcastsInitialScanComplete) {
auto input_queue = std::make_unique<Queue<ChunkSourceWithPhase>>(4);
ChunkSourceSplitterConfig cfg;
auto* outA = cfg.add_output();
outA->set_name("A");
cfg.add_weight(1);
auto* outB = cfg.add_output();
outB->set_name("B");
cfg.add_weight(1);
ChunkSourceSplitter splitter(cfg);
splitter.SetInputs({input_queue.get()});
splitter.Start();
ChunkSourceWithPhase marker;
marker.source = nullptr;
marker.message_type = FilePathProvider::MessageType::kInitialScanComplete;
auto producer = input_queue->CreateProducer();
producer.Put(std::move(marker));
producer.Close();
auto* qa =
dynamic_cast<Queue<ChunkSourceWithPhase>*>(splitter.GetOutput("A"));
auto* qb =
dynamic_cast<Queue<ChunkSourceWithPhase>*>(splitter.GetOutput("B"));
auto m1 = qa->Get();
auto m2 = qb->Get();
EXPECT_EQ(m1.message_type,
FilePathProvider::MessageType::kInitialScanComplete);
EXPECT_EQ(m2.message_type,
FilePathProvider::MessageType::kInitialScanComplete);
EXPECT_EQ(m1.source, nullptr);
EXPECT_EQ(m2.source, nullptr);
splitter.Stop();
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_unpacker.cc
================================================
#include "loader/stages/chunk_unpacker.h"
#include <absl/algorithm/container.h>
#include <absl/container/flat_hash_set.h>
#include <absl/log/check.h>
#include <absl/log/log.h>
#include <absl/random/random.h>
#include <absl/random/seed_sequences.h>
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <random>
#include <span>
#include <utility>
#include <vector>
#include "absl/numeric/int128.h"
#include "absl/random/bit_gen_ref.h"
#include "absl/random/random.h"
#include "loader/data_loader_metrics.h"
#include "loader/stages/position_sampling.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
namespace lczero {
namespace training {
// Deterministically partitions `n` positions into disjoint subsets of size
// `~p*n`, returning the subset for a given `iteration`. While each selection
// individually behaves like a Bernoulli sample with probability `p`, the
// samples are correlated to ensure all positions are selected exactly once over
// `1/p` iterations. To sample disjoint subsets from the same set of positions,
// `gen` must be seeded identically for each call with a different `iteration`.
std::vector<uint32_t> PickSampledPositions(int32_t n, double p,
int32_t iteration,
absl::BitGen& gen) {
assert(p > 0.0 && p <= 1.0);
double carried_prob = p;
std::vector<uint32_t> result;
absl::flat_hash_set<int32_t> skip_next_round;
while (true) {
int32_t num_this_round = (1.0 - carried_prob) / p + 1;
double last_partial_prob = 1 - (carried_prob + (num_this_round - 1) * p);
const bool return_this_round = iteration < num_this_round;
absl::flat_hash_set<int32_t> skip_this_round(skip_next_round);
skip_next_round.clear();
for (int32_t i = 0; i < n; ++i) {
if (skip_this_round.contains(i)) continue;
const double toss = absl::Uniform<double>(gen, 0.0, 1.0);
const int32_t value = (toss - carried_prob) / p + 1;
if (value == iteration) result.push_back(static_cast<uint32_t>(i));
if (value >= num_this_round) {
skip_next_round.insert(static_cast<int32_t>(i));
}
}
if (return_this_round) return result;
iteration -= num_this_round;
carried_prob = p - last_partial_prob;
}
}
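// Illustrative usage (a sketch; the seed value is arbitrary): partitioning
// 100 positions into disjoint subsets with p = 0.25. Re-seeding the generator
// identically for each call is what keeps the subsets disjoint.
//
//   absl::BitGen gen_a(absl::SeedSeq{123});
//   auto first = PickSampledPositions(100, 0.25, /*iteration=*/0, gen_a);
//   absl::BitGen gen_b(absl::SeedSeq{123});
//   auto second = PickSampledPositions(100, 0.25, /*iteration=*/1, gen_b);
//   // `first` and `second` share no indices; iterations 0..3 are intended to
//   // cover all 100 indices exactly once.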
// Samples `k` indices from an effectively infinite sequence, after skipping
// the first `skip` selections. The sequence is formed by repeatedly shuffling
// `[0..n-1]` with `gen` and selecting each index with its entry in
// `probabilities`. Determinism comes from the caller seeding `gen` identically
// on every call, so the skipped prefix reproduces the earlier draws.
std::vector<uint32_t> SampleProbabilisticSequence(
uint64_t k, uint64_t skip, std::span<const float> probabilities,
absl::BitGen& gen) {
const size_t n = probabilities.size();
if (n == 0 || k == 0) return {};
uint64_t skipped_so_far = 0;
std::vector<uint32_t> v(n);
std::iota(v.begin(), v.end(), 0u);
std::vector<uint32_t> result;
result.reserve(k);
while (true) {
std::shuffle(v.begin(), v.end(), gen);
for (const uint32_t candidate : v) {
if (!absl::Bernoulli(gen, probabilities[candidate])) continue;
if (skipped_so_far < skip) {
skipped_so_far++;
} else {
result.push_back(candidate);
if (result.size() == k) return result;
}
}
}
}
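// Illustrative usage (a sketch; `run_seed` and `chunk_index` are hypothetical
// names): drawing the next two sampled positions of a three-frame chunk with
// per-frame probabilities {1.0, 0.5, 0.25}, after skipping the four selections
// made on earlier visits. Seeding `gen` the same way on every visit makes the
// skipped prefix reproduce the earlier draws.
//
//   absl::BitGen gen(std::seed_seq{run_seed, chunk_index});
//   std::vector<float> probs = {1.0f, 0.5f, 0.25f};
//   auto picks = SampleProbabilisticSequence(/*k=*/2, /*skip=*/4, probs, gen);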
namespace {
uint32_t GenerateRunSeed() {
absl::BitGen gen(absl::MakeSeedSeq());
return absl::Uniform<uint32_t>(gen);
}
} // namespace
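// The constructor below accepts two mutually exclusive sampling modes:
//  * position_sampling_rate: each chunk visit deterministically emits roughly
//    that fraction of the chunk's frames to the primary output.
//  * position_count (optionally with prefetch_count/prefetch_output): each
//    visit emits position_count frame(s) to the primary output and, in
//    prefetch mode, the following prefetch_count frames as a CacheRequest.
// Illustrative prefetch-mode configuration (a textproto-style sketch; values
// are arbitrary, field names as used in this file):
//
//   threads: 4
//   position_count: 1
//   prefetch_count: 3
//   output { name: "frames" queue_capacity: 256 }
//   prefetch_output { name: "prefetch" queue_capacity: 256 }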
ChunkUnpacker::ChunkUnpacker(const ChunkUnpackerConfig& config)
: SingleInputStage<ChunkUnpackerConfig, InputType>(config),
config_(config),
run_seed_(GenerateRunSeed()),
primary_output_queue_(
config.output().queue_capacity(),
ToOverflowBehavior(config.output().overflow_behavior())),
thread_pool_(config.threads(), ThreadPoolOptions{}) {
const bool has_rate = config.has_position_sampling_rate();
const bool has_count = config.has_position_count();
const bool has_prefetch_count = config.has_prefetch_count();
const bool has_prefetch_output = config.has_prefetch_output();
CHECK(has_prefetch_count == has_prefetch_output)
<< "prefetch_count and prefetch_output must both be set or both unset.";
if (has_prefetch_count) {
CHECK(has_count) << "position_count must be set when using prefetch mode.";
CHECK(!has_rate)
<< "position_sampling_rate cannot be used in prefetch mode.";
CHECK(config.position_count() == 1)
<< "position_count must equal 1 in prefetch mode, got "
<< config.position_count();
prefetch_output_queue_.emplace(
config.prefetch_output().queue_capacity(),
ToOverflowBehavior(config.prefetch_output().overflow_behavior()));
if (config.output().name() == config.prefetch_output().name()) {
throw std::runtime_error(
absl::StrCat("ChunkUnpacker output names must be different, got: '",
config.output().name(), "'"));
}
} else {
CHECK(has_rate != has_count)
<< "Exactly one of position_sampling_rate or position_count must be "
"set.";
}
LOG(INFO) << "Initializing ChunkUnpacker with " << config.threads()
<< " worker threads";
// Initialize thread contexts but don't start worker threads yet.
thread_contexts_.reserve(config.threads());
for (size_t i = 0; i < config.threads(); ++i) {
thread_contexts_.push_back(std::make_unique<ThreadContext>());
}
}
ChunkUnpacker::~ChunkUnpacker() { Stop(); }
void ChunkUnpacker::Start() {
LOG(INFO) << "Starting ChunkUnpacker worker threads.";
for (size_t i = 0; i < thread_contexts_.size(); ++i) {
thread_pool_.Enqueue([this, i](std::stop_token stop_token) {
Worker(stop_token, thread_contexts_[i].get());
});
}
}
void ChunkUnpacker::Stop() {
if (thread_pool_.stop_token().stop_requested()) return;
LOG(INFO) << "Stopping ChunkUnpacker.";
thread_pool_.Shutdown();
primary_output_queue_.Close();
if (prefetch_output_queue_) prefetch_output_queue_->Close();
LOG(INFO) << "ChunkUnpacker stopped.";
}
QueueBase* ChunkUnpacker::GetOutput(std::string_view name) {
if (name == config_.output().name()) return &primary_output_queue_;
if (config_.has_prefetch_output() &&
name == config_.prefetch_output().name()) {
return &*prefetch_output_queue_;
}
std::string available = absl::StrCat("'", config_.output().name(), "'");
if (config_.has_prefetch_output()) {
absl::StrAppend(&available, ", '", config_.prefetch_output().name(), "'");
}
throw std::runtime_error(absl::StrCat("ChunkUnpacker unknown output '", name,
"'. Available outputs: ", available));
}
namespace {
std::vector<float> FramesToProbabilities(std::span<const FrameType> frames,
const PositionSamplingConfig& config) {
std::vector<float> probabilities;
probabilities.reserve(frames.size());
absl::c_transform(frames, std::back_inserter(probabilities),
[&](const FrameType& frame) {
return ComputePositionSamplingWeight(frame, config);
});
const float max_prob = *absl::c_max_element(probabilities);
if (max_prob > 0.0f) {
absl::c_transform(probabilities, probabilities.begin(),
[max_prob](float p) { return p / max_prob; });
} else {
absl::c_fill(probabilities, 1.0f);
}
return probabilities;
}
} // namespace
void ChunkUnpacker::Worker(std::stop_token stop_token, ThreadContext* context) {
// Create a local producer for this worker thread.
auto primary_producer = primary_output_queue_.CreateProducer();
std::optional<decltype(prefetch_output_queue_->CreateProducer())>
prefetch_producer;
if (prefetch_output_queue_.has_value()) {
prefetch_producer.emplace(prefetch_output_queue_->CreateProducer());
}
try {
while (true) {
auto chunk = [&]() {
LoadMetricPauser pauser(context->load_metric_updater);
return input_queue()->Get(stop_token);
}();
absl::BitGen gen(
std::seed_seq{run_seed_, static_cast<uint32_t>(chunk.global_index)});
std::vector<uint32_t> positions;
if (config_.has_position_sampling_rate()) {
positions = PickSampledPositions(
static_cast<int32_t>(chunk.frames.size()),
config_.position_sampling_rate(), chunk.use_count, gen);
} else {
auto probabilities =
FramesToProbabilities(chunk.frames, config_.position_sampling());
positions = SampleProbabilisticSequence(
config_.position_count() + config_.prefetch_count(),
config_.position_count() * chunk.use_count, probabilities, gen);
}
if (config_.has_prefetch_count()) {
// Prefetch mode: output first position to primary, rest to prefetch.
if (!positions.empty()) {
LoadMetricPauser pauser(context->load_metric_updater);
primary_producer.Put(std::move(chunk.frames[positions[0]]),
stop_token);
}
if (positions.size() > 1) {
CacheRequest cache_request;
cache_request.global_index = chunk.global_index;
cache_request.next_use = chunk.use_count + 1;
cache_request.items.reserve(positions.size() - 1);
for (size_t i = 1; i < positions.size(); ++i) {
cache_request.items.push_back(chunk.frames[positions[i]]);
}
LoadMetricPauser pauser(context->load_metric_updater);
prefetch_producer->Put(std::move(cache_request), stop_token);
}
} else {
// Normal mode: output all positions to primary.
for (uint32_t pos : positions) {
LoadMetricPauser pauser(context->load_metric_updater);
primary_producer.Put(std::move(chunk.frames[pos]), stop_token);
}
}
}
} catch (const QueueClosedException&) {
LOG(INFO) << "ChunkUnpacker worker stopping, queue closed.";
} catch (const QueueRequestCancelled&) {
LOG(INFO) << "ChunkUnpacker worker stopping, request cancelled.";
}
}
StageMetricProto ChunkUnpacker::FlushMetrics() {
StageMetricProto stage_metric;
LoadMetricProto aggregated_load;
aggregated_load.set_name("load");
for (const auto& context : thread_contexts_) {
UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());
}
*stage_metric.add_load_metrics() = std::move(aggregated_load);
*stage_metric.add_queue_metrics() =
MetricsFromQueue(config_.output().name(), primary_output_queue_);
if (prefetch_output_queue_.has_value()) {
*stage_metric.add_queue_metrics() = MetricsFromQueue(
config_.prefetch_output().name(), *prefetch_output_queue_);
}
return stage_metric;
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_unpacker.h
================================================
// ABOUTME: Stage that unpacks chunks into FrameType frames.
// ABOUTME: Converts a stream of TrainingChunk objects to a FrameType stream.
#pragma once
#include <atomic>
#include <cstdint>
#include <memory>
#include <stop_token>
#include <vector>
#include "absl/random/random.h"
#include "absl/types/optional.h"
#include "loader/data_loader_metrics.h"
#include "loader/frame_type.h"
#include "loader/stages/stage.h"
#include "loader/stages/training_chunk.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/queue.h"
#include "utils/thread_pool.h"
namespace lczero {
namespace training {
// Worker pool that unpacks chunks into frames.
// Takes parsed TrainingChunk objects as input and outputs individual
// FrameType frames.
class ChunkUnpacker
: public SingleInputStage<ChunkUnpackerConfig, TrainingChunk> {
public:
using InputType = TrainingChunk;
explicit ChunkUnpacker(const ChunkUnpackerConfig& config);
~ChunkUnpacker();
void Start() override;
void Stop() override;
QueueBase* GetOutput(std::string_view name) override;
StageMetricProto FlushMetrics() override;
Queue<FrameType>* output_queue() { return &primary_output_queue_; }
private:
struct ThreadContext {
LoadMetricUpdater load_metric_updater;
};
void Worker(std::stop_token stop_token, ThreadContext* context);
const ChunkUnpackerConfig config_;
const uint32_t run_seed_;
Queue<FrameType> primary_output_queue_;
std::optional<Queue<CacheRequest>> prefetch_output_queue_;
// thread_contexts_ must be declared before thread_pool_ to ensure
// thread_pool_ is destroyed first (stopping threads before contexts).
std::vector<std::unique_ptr<ThreadContext>> thread_contexts_;
ThreadPool thread_pool_;
};
std::vector<uint32_t> PickSampledPositions(int32_t n, double p,
int32_t iteration,
absl::BitGen& gen);
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/chunk_unpacker_test.cc
================================================
#include "loader/stages/chunk_unpacker.h"
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/random.h"
#include "absl/random/seed_sequences.h"
#include "gtest/gtest.h"
#include "libs/lc0/src/trainingdata/trainingdata_v6.h"
#include "loader/stages/training_chunk.h"
#include "proto/data_loader_config.pb.h"
#include "utils/queue.h"
namespace lczero {
namespace training {
namespace {
template <typename T>
class PassthroughStage : public Stage {
public:
explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
void Start() override {}
void Stop() override {}
StageMetricProto FlushMetrics() override { return StageMetricProto(); }
QueueBase* GetOutput(std::string_view name = "") override {
(void)name;
return queue_;
}
void SetInputs(absl::Span<QueueBase* const> inputs) override {
if (!inputs.empty()) {
throw std::runtime_error("PassthroughStage expects no inputs");
}
}
private:
Queue<T>* queue_;
};
} // namespace
class ChunkUnpackerTest : public ::testing::Test {
protected:
void SetUp() override {
input_queue_ = std::make_unique<Queue<TrainingChunk>>(10);
config_.set_threads(1);
config_.mutable_output()->set_queue_capacity(10);
config_.set_position_sampling_rate(1.0f);
}
FrameType CreateTestFrame(uint32_t version) {
FrameType frame{};
frame.version = version;
frame.input_format = 3;
frame.root_q = 0.5f;
return frame;
}
TrainingChunk MakeChunk(std::vector<FrameType> frames,
std::string sort_key = "source", size_t index = 0,
uint32_t use = 0) {
TrainingChunk chunk;
chunk.sort_key = std::move(sort_key);
chunk.index_within_sort_key = index;
chunk.use_count = use;
chunk.frames = std::move(frames);
return chunk;
}
std::unique_ptr<Queue<TrainingChunk>> input_queue_;
ChunkUnpackerConfig config_;
};
TEST_F(ChunkUnpackerTest, UnpacksSingleFrame) {
ChunkUnpacker unpacker(config_);
unpacker.SetInputs({input_queue_.get()});
unpacker.Start();
FrameType test_frame = CreateTestFrame(6);
auto producer = input_queue_->CreateProducer();
producer.Put(MakeChunk({test_frame}));
producer.Close();
auto output_frame = unpacker.output_queue()->Get();
EXPECT_EQ(output_frame.version, 6);
EXPECT_EQ(output_frame.input_format, 3);
EXPECT_EQ(output_frame.root_q, 0.5f);
}
TEST_F(ChunkUnpackerTest, UnpacksMultipleFrames) {
ChunkUnpacker unpacker(config_);
unpacker.SetInputs({input_queue_.get()});
unpacker.Start();
std::vector<FrameType> test_frames = {CreateTestFrame(6), CreateTestFrame(7),
CreateTestFrame(8)};
auto producer = input_queue_->CreateProducer();
producer.Put(MakeChunk(test_frames));
producer.Close();
std::vector<uint32_t> actual_versions;
actual_versions.reserve(test_frames.size());
for (size_t i = 0; i < test_frames.size(); ++i) {
auto output_frame = unpacker.output_queue()->Get();
actual_versions.push_back(output_frame.version);
EXPECT_EQ(output_frame.input_format, 3);
EXPECT_EQ(output_frame.root_q, 0.5f);
}
std::vector<uint32_t> expected_versions;
expected_versions.reserve(test_frames.size());
for (const auto& frame : test_frames) {
expected_versions.push_back(frame.version);
}
absl::c_sort(actual_versions);
absl::c_sort(expected_versions);
EXPECT_EQ(actual_versions, expected_versions);
}
TEST_F(ChunkUnpackerTest, UnpacksMultipleChunks) {
ChunkUnpacker unpacker(config_);
unpacker.SetInputs({input_queue_.get()});
unpacker.Start();
auto producer = input_queue_->CreateProducer();
// Send first chunk with 2 frames
std::vector<FrameType> chunk1_frames = {CreateTestFrame(10),
CreateTestFrame(11)};
producer.Put(MakeChunk(chunk1_frames, "source", 0));
// Send second chunk with 1 frame
std::vector<FrameType> chunk2_frames = {CreateTestFrame(12)};
producer.Put(MakeChunk(chunk2_frames, "source", 1));
producer.Close();
// Verify all frames are output
std::vector<uint32_t> expected_versions = {10, 11, 12};
std::vector<uint32_t> actual_versions;
actual_versions.reserve(expected_versions.size());
for (size_t i = 0; i < expected_versions.size(); ++i) {
auto output_frame = unpacker.output_queue()->Get();
actual_versions.push_back(output_frame.version);
EXPECT_EQ(output_frame.input_format, 3);
EXPECT_EQ(output_frame.root_q, 0.5f);
}
absl::c_sort(actual_versions);
absl::c_sort(expected_versions);
EXPECT_EQ(actual_versions, expected_versions);
}
TEST_F(ChunkUnpackerTest, HandlesEmptyChunk) {
ChunkUnpacker unpacker(config_);
unpacker.SetInputs({input_queue_.get()});
unpacker.Start();
auto producer = input_queue_->CreateProducer();
TrainingChunk empty_chunk;
empty_chunk.sort_key = "source";
empty_chunk.index_within_sort_key = 0;
producer.Put(std::move(empty_chunk));
producer.Close();
// Should not produce any output frames, queue should close
EXPECT_THROW(unpacker.output_queue()->Get(), QueueClosedException);
}
TEST_F(ChunkUnpackerTest, HandlesQueueClosure) {
ChunkUnpacker unpacker(config_);
unpacker.SetInputs({input_queue_.get()});
unpacker.Start();
// Close input queue without sending data
input_queue_->Close();
// Output queue should eventually close
EXPECT_THROW(unpacker.output_queue()->Get(), QueueClosedException);
}
TEST(PickSampledPositionsTest, Deterministic) {
absl::BitGen gen1(absl::SeedSeq{42});
std::vector<uint32_t> result1 = PickSampledPositions(1000, 0.1, 5, gen1);
absl::BitGen gen2(absl::SeedSeq{42});
std::vector<uint32_t> result2 = PickSampledPositions(1000, 0.1, 5, gen2);
EXPECT_EQ(result1, result2);
}
TEST(PickSampledPositionsTest, FullBucketFirstRound) {
absl::BitGen gen(absl::SeedSeq{42});
const uint32_t n = 10000;
const double p = 0.1;
std::vector<uint32_t> result = PickSampledPositions(n, p, 0, gen);
// Expect size to be around n*p.
EXPECT_NEAR(result.size(), n * p, n * p * 0.25);
}
TEST(PickSampledPositionsTest, DisjointBuckets) {
absl::BitGen gen(absl::SeedSeq{42});
const uint32_t n = 1000;
const double p = 0.1;
std::vector<uint32_t> bucket1 = PickSampledPositions(n, p, 0, gen);
absl::c_sort(bucket1);
// The generator state is now changed. For the next bucket, we need a fresh
// one with the same seed to test the logic for a different iteration.
absl::BitGen gen2(absl::SeedSeq{42});
std::vector<uint32_t> bucket2 = PickSampledPositions(n, p, 1, gen2);
absl::c_sort(bucket2);
std::vector<uint32_t> intersection;
absl::c_set_intersection(bucket1, bucket2, std::back_inserter(intersection));
EXPECT_TRUE(intersection.empty());
}
TEST(PickSampledPositionsTest, PartialBucketElementsAreReturned) {
absl::BitGen gen(absl::SeedSeq{42});
const uint32_t n = 1000;
const double p = 0.8; // remainder 0.2
// In round 1, elements with toss >= 0.8 are for iteration 1.
absl::BitGen gen1(absl::SeedSeq{42});
std::vector<uint32_t> expected_from_round1;
for (uint32_t i = 0; i < n; ++i) {
double toss = absl::Uniform<double>(gen1, 0.0, 1.0);
if (toss >= 0.8) {
expected_from_round1.push_back(i);
}
}
absl::c_sort(expected_from_round1);
std::vector<uint32_t> result = PickSampledPositions(n, p, 1, gen);
absl::c_sort(result);
// Check if all elements from round 1 are in the final result.
// This will fail with the current implementation because they are discarded.
std::vector<uint32_t> intersection;
absl::c_set_intersection(expected_from_round1, result,
std::back_inserter(intersection));
EXPECT_GT(expected_from_round1.size(), 50); // High probability for n=1000
EXPECT_EQ(intersection.size(), expected_from_round1.size());
}
TEST(PickSampledPositionsTest, PartialBucketCompletedSize) {
absl::BitGen gen(absl::SeedSeq{42});
const uint32_t n = 10000;
const double p = 0.8; // remainder 0.2
std::vector<uint32_t> result = PickSampledPositions(n, p, 1, gen);
// Expect size to be around n*p.
// This will fail due to incorrect probability calculation for completion.
EXPECT_NEAR(result.size(), n * p, n * p * 0.25);
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/file_path_provider.cc
================================================
#include "loader/stages/file_path_provider.h"
#include <absl/cleanup/cleanup.h>
#include <absl/container/flat_hash_set.h>
#include <absl/log/check.h>
#include <absl/log/log.h>
#include <absl/synchronization/mutex.h>
#include <sys/epoll.h>
#include <unistd.h>
#include <array>
#include <cerrno>
#include <chrono>
#include <cstring>
#include <filesystem>
#include <stdexcept>
#include <string_view>
#include <thread>
#include <utility>
#include "loader/data_loader_metrics.h"
#include "proto/data_loader_config.pb.h"
namespace lczero {
namespace training {
namespace {
bool ShouldSkipName(std::string_view name) {
return !name.empty() && name.front() == '.';
}
bool ShouldSkipPathEntry(const FilePathProvider::Path& path) {
return ShouldSkipName(path.filename().string());
}
} // namespace
FilePathProvider::FilePathProvider(const FilePathProviderConfig& config)
: SingleOutputStage<File>(config.output()),
directory_(config.directory()),
producer_(output_queue()->CreateProducer()),
load_metric_updater_() {
LOG(INFO) << "Initializing FilePathProvider for directory: "
<< config.directory();
inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK);
CHECK_NE(inotify_fd_, -1)
<< "Failed to initialize inotify: " << strerror(errno);
}
FilePathProvider::~FilePathProvider() {
LOG(INFO) << "FilePathProvider shutting down.";
Stop();
if (inotify_fd_ != -1) close(inotify_fd_);
LOG(INFO) << "FilePathProvider shutdown complete.";
}
void FilePathProvider::SetInputs(absl::Span<QueueBase* const> inputs) {
if (!inputs.empty()) {
throw std::runtime_error(
"FilePathProvider expects no inputs, but received " +
std::to_string(inputs.size()));
}
}
void FilePathProvider::Start() {
LOG(INFO) << "Starting FilePathProvider monitoring thread.";
thread_pool_.Enqueue(
[this](std::stop_token stop_token) { Worker(stop_token); });
}
void FilePathProvider::Stop() {
if (stop_source_.stop_requested()) return;
LOG(INFO) << "Stopping FilePathProvider.";
LOG(INFO) << "Stopping all watches...";
for (const auto& [wd, path] : watch_descriptors_) {
inotify_rm_watch(inotify_fd_, wd);
}
watch_descriptors_.clear();
stop_source_.request_stop();
thread_pool_.Shutdown();
producer_.Close();
}
StageMetricProto FilePathProvider::FlushMetrics() {
StageMetricProto stage_metric;
auto load_metrics = load_metric_updater_.FlushMetrics();
load_metrics.set_name("load");
*stage_metric.add_load_metrics() = std::move(load_metrics);
*stage_metric.add_queue_metrics() =
MetricsFromQueue("output", *output_queue());
return stage_metric;
}
void FilePathProvider::AddDirectory(const Path& directory,
std::stop_token stop_token) {
ScanDirectoryWithWatch(directory, stop_token);
LOG(INFO) << "FilePathProvider registered " << directory
<< "; active watch descriptors: " << watch_descriptors_.size();
// Signal that initial scan is complete
LOG(INFO) << "FilePathProvider initial scan complete";
producer_.Put(
{{.filepath = Path{}, .message_type = MessageType::kInitialScanComplete}},
stop_token);
}
void FilePathProvider::ScanDirectoryWithWatch(const Path& directory,
std::stop_token stop_token) {
// Step 1: Set up watch first
int wd = inotify_add_watch(inotify_fd_, directory.c_str(),
IN_CLOSE_WRITE | IN_MOVED_TO | IN_CREATE |
IN_DELETE | IN_DELETE_SELF | IN_MOVE);
CHECK_NE(wd, -1) << "Failed to add inotify watch for " << directory << ": "
<< strerror(errno);
watch_descriptors_[wd] = directory;
// Step 2: Scan directory non-recursively, remembering files and subdirs
std::vector<Path> files;
std::vector<Path> subdirectories;
std::error_code ec;
auto iterator = std::filesystem::directory_iterator(directory, ec);
CHECK(!ec) << "Failed to iterate directory " << directory << ": "
<< ec.message();
for (const auto& entry : iterator) {
const Path entry_path = entry.path();
if (ShouldSkipPathEntry(entry_path)) continue;
if (entry.is_regular_file(ec) && !ec) {
files.push_back(entry_path);
} else if (entry.is_directory(ec) && !ec) {
subdirectories.push_back(entry_path);
}
}
const size_t initial_file_count = files.size();
const size_t subdirectory_count = subdirectories.size();
LOG(INFO) << "FilePathProvider scanned " << directory << " discovering "
<< initial_file_count << " file(s) and " << subdirectory_count
<< " subdirectory(ies) before watch reconciliation.";
// Send notifications for discovered files
constexpr size_t kBatchSize = 10000;
std::vector<File> batch;
batch.reserve(kBatchSize);
auto flush_batch = [&]() {
if (batch.empty()) return;
producer_.Put(batch, stop_token);
batch.clear();
};
for (const auto& filepath : files) {
batch.push_back(
{.filepath = filepath.string(), .message_type = MessageType::kFile});
if (batch.size() >= kBatchSize) flush_batch();
}
if (initial_file_count > 0) {
LOG(INFO) << "FilePathProvider enqueued " << initial_file_count
<< " file(s) from initial scan of " << directory;
}
// Step 3: Read from watch descriptor, skipping already discovered files
ProcessWatchEventsForNewItems(files);
// Step 4: Clear the files vector to free memory
files.clear();
// Step 5: Recursively call for subdirectories
for (const auto& subdir : subdirectories) {
if (stop_token.stop_requested()) return;
ScanDirectoryWithWatch(subdir, stop_token);
}
// Flush any remaining files
flush_batch();
}
void FilePathProvider::ProcessWatchEventsForNewItems(
const std::vector<Path>& known_files) {
// Create a set for fast lookup of already discovered files
absl::flat_hash_set<std::string> known_file_set;
for (const auto& file : known_files) {
known_file_set.insert(file.string());
}
// Process any events that may have occurred during scanning
std::array<char, 4096> buffer;
std::vector<File> new_files;
while (true) {
ssize_t length = read(inotify_fd_, buffer.data(), buffer.size());
if (length <= 0) break; // No more events to process
ssize_t offset = 0;
while (offset < length) {
const struct inotify_event* event =
reinterpret_cast<const struct inotify_event*>(buffer.data() + offset);
const bool skip_entry = event->len > 0 && ShouldSkipName(event->name);
// Only process file creation/write events, skip already known files
if ((event->mask & (IN_CLOSE_WRITE | IN_MOVED_TO)) != 0 &&
event->len > 0 && !skip_entry) {
const Path directory(watch_descriptors_.at(event->wd));
Path filepath = directory / event->name;
std::string filepath_string = filepath.string();
// Only add if we haven't seen this file before
if (!known_file_set.contains(filepath_string)) {
new_files.push_back({.filepath = std::move(filepath_string),
.message_type = MessageType::kFile});
}
}
offset += sizeof(struct inotify_event) + event->len;
}
}
// Send notifications for any new files discovered through watch events
if (!new_files.empty()) {
LOG(INFO) << "FilePathProvider observed " << new_files.size()
<< " new file(s) while reconciling race events.";
producer_.Put(new_files);
}
}
void FilePathProvider::AddWatchRecursive(const Path& path) {
// Add watch for current directory
int wd = inotify_add_watch(inotify_fd_, path.c_str(),
IN_CLOSE_WRITE | IN_MOVED_TO | IN_CREATE |
IN_DELETE | IN_DELETE_SELF | IN_MOVE);
CHECK_NE(wd, -1) << "Failed to add inotify watch for " << path << ": "
<< strerror(errno);
watch_descriptors_[wd] = path;
// Recursively add watches for subdirectories
std::error_code ec;
auto iterator = std::filesystem::directory_iterator(path, ec);
CHECK(!ec) << "Failed to iterate directory " << path << ": " << ec.message();
for (const auto& entry : iterator) {
const Path entry_path = entry.path();
if (ShouldSkipPathEntry(entry_path)) continue;
if (!entry.is_directory(ec) || ec) continue;
AddWatchRecursive(entry_path);
}
}
void FilePathProvider::RemoveWatchRecursive(const Path& base) {
absl::erase_if(watch_descriptors_, [&](const auto& pair) {
const auto& [wd, path] = pair;
const auto mismatch_iter = absl::c_mismatch(base, path).first;
// If path is not a subdirectory (or equal) of base, skip.
if (mismatch_iter != base.end()) return false;
inotify_rm_watch(inotify_fd_, wd);
return true;
});
}
void FilePathProvider::Worker(std::stop_token stop_token) {
// Perform directory scanning in background thread
AddDirectory(directory_, stop_token);
int epoll_fd = epoll_create1(EPOLL_CLOEXEC);
CHECK_NE(epoll_fd, -1) << "Failed to create epoll fd: " << strerror(errno);
absl::Cleanup epoll_cleanup([epoll_fd]() { close(epoll_fd); });
struct epoll_event event;
event.events = EPOLLIN;
event.data.fd = inotify_fd_;
CHECK_EQ(epoll_ctl(epoll_fd, EPOLL_CTL_ADD, inotify_fd_, &event), 0)
<< "Failed to add inotify fd to epoll: " << strerror(errno);
while (!stop_token.stop_requested()) {
{
LoadMetricPauser pauser(load_metric_updater_);
std::this_thread::sleep_for(std::chrono::milliseconds(50));
if (stop_token.stop_requested()) {
pauser.DoNotResume();
break;
}
}
struct epoll_event event;
int nfds = epoll_wait(epoll_fd, &event, 1, 0); // Non-blocking check
CHECK_NE(nfds, -1) << "epoll_wait failed: " << strerror(errno);
if (nfds == 0) continue; // No events.
do {
assert(nfds == 1 && event.data.fd == inotify_fd_);
ProcessInotifyEvents(producer_, stop_token);
nfds = epoll_wait(epoll_fd, &event, 1, 0);
} while (nfds > 0);
}
}
void FilePathProvider::ProcessInotifyEvents(Queue<File>::Producer& producer,
std::stop_token stop_token) {
constexpr size_t kNotifyBatchSize = 10000;
std::vector<File> files;
std::array<char, 4096> buffer;
auto flush_batch = [&]() {
if (files.empty()) return;
producer.Put(files, stop_token);
files.clear();
};
while (true) {
ssize_t length = read(inotify_fd_, buffer.data(), buffer.size());
if (length <= 0) break; // No more events to process
ssize_t offset = 0;
while (offset < length) {
const struct inotify_event* event =
reinterpret_cast<const struct inotify_event*>(buffer.data() + offset);
auto file = ProcessInotifyEvent(*event, stop_token);
if (file) files.push_back(*file);
if (files.size() >= kNotifyBatchSize) flush_batch();
offset += sizeof(struct inotify_event) + event->len;
}
}
flush_batch(); // Flush any remaining files in the batch
}
auto FilePathProvider::ProcessInotifyEvent(const struct inotify_event& event,
std::stop_token stop_token)
-> std::optional<File> {
if (event.mask & IN_IGNORED) return std::nullopt;
const Path directory(watch_descriptors_.at(event.wd));
const bool has_name = event.len > 0 && event.name[0] != '\0';
const bool skip_entry = has_name && ShouldSkipName(event.name);
const Path filepath = has_name ? directory / event.name : directory;
// Handle different event types
if ((event.mask & (IN_CLOSE_WRITE | IN_MOVED_TO)) != 0 && has_name &&
!skip_entry) {
// File finished writing or moved into directory
return File{.filepath = filepath, .message_type = MessageType::kFile};
}
constexpr uint32_t kDirCreateMask = IN_CREATE | IN_ISDIR;
constexpr uint32_t kDirDeleteMask = IN_DELETE | IN_ISDIR;
if ((event.mask & kDirCreateMask) == kDirCreateMask) {
if (!has_name || skip_entry) return std::nullopt;
ScanDirectoryWithWatch(filepath, stop_token);
} else if ((event.mask & kDirDeleteMask) == kDirDeleteMask) {
if (!has_name || skip_entry) return std::nullopt;
// Directory deleted - remove all watches for it and subdirectories
RemoveWatchRecursive(filepath);
} else if (event.mask & IN_DELETE_SELF) {
RemoveWatchRecursive(directory);
}
return std::nullopt;
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/file_path_provider.h
================================================
#pragma once
#include <absl/base/thread_annotations.h>
#include <absl/container/flat_hash_map.h>
#include <absl/log/log.h>
#include <absl/synchronization/mutex.h>
#include <sys/inotify.h>
#include <filesystem>
#include <functional>
#include <span>
#include <stop_token>
#include <string>
#include <vector>
#include "loader/stages/stage.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/load_metric.h"
#include "utils/metrics/printer.h"
#include "utils/metrics/statistics_metric.h"
#include "utils/queue.h"
#include "utils/thread_pool.h"
namespace lczero {
namespace training {
// Message types for FilePathProvider output.
enum class FilePathProviderMessageType {
kFile, // File discovered (initial scan or inotify)
kInitialScanComplete // Initial scan is complete (empty filepath)
};
// Output type for FilePathProvider.
struct FilePathProviderFile {
std::filesystem::path filepath;
FilePathProviderMessageType message_type;
};
// This class watches for new files in a directory (recursively) and notifies
// registered observers when new files are either closed after writing or
// renamed (moved) into it.
// Uses a background thread to monitor the directory.
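// Typical usage (see file_path_provider_main.cc): construct with a
// FilePathProviderConfig, call Start(), then consume FilePathProviderFile
// messages from output_queue(); a kInitialScanComplete message marks the end
// of the initial recursive scan, and Stop() closes the output queue.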
class FilePathProvider : public SingleOutputStage<FilePathProviderFile> {
public:
using Path = std::filesystem::path;
using MessageType = FilePathProviderMessageType;
using File = FilePathProviderFile;
explicit FilePathProvider(const FilePathProviderConfig& config);
~FilePathProvider();
// Starts monitoring the directory
void Start() override;
// Closes the output queue, signaling completion
void Stop() override;
// Returns current metrics and clears them.
StageMetricProto FlushMetrics() override;
// FilePathProvider has no inputs.
void SetInputs(absl::Span<QueueBase* const> inputs) override;
private:
// Starts monitoring the directory.
void AddDirectory(const Path& directory, std::stop_token stop_token);
void Worker(std::stop_token stop_token);
void AddWatchRecursive(const Path& path);
void RemoveWatchRecursive(const Path& path);
void ScanDirectoryWithWatch(const Path& directory,
std::stop_token stop_token);
void ProcessWatchEventsForNewItems(const std::vector<Path>& known_files);
void ProcessInotifyEvents(Queue<File>::Producer& producer,
std::stop_token stop_token);
std::optional<File> ProcessInotifyEvent(const struct inotify_event& event,
std::stop_token stop_token);
int inotify_fd_;
// Watch descriptor to directory path.
absl::flat_hash_map<int, Path> watch_descriptors_;
Path directory_; // Directory to monitor
Queue<File>::Producer producer_;
LoadMetricUpdater load_metric_updater_;
std::stop_source stop_source_;
ThreadPool thread_pool_{1, ThreadPoolOptions{}, stop_source_};
};
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/file_path_provider_main.cc
================================================
#include <absl/flags/flag.h>
#include <absl/flags/parse.h>
#include <absl/log/globals.h>
#include <absl/log/initialize.h>
#include <absl/log/log.h>
#include <iostream>
#include <string>
#include <thread>
#include "loader/stages/file_path_provider.h"
ABSL_FLAG(std::string, directory, "", "Directory to monitor for files");
int main(int argc, char* argv[]) {
absl::ParseCommandLine(argc, argv);
absl::InitializeLog();
absl::SetStderrThreshold(absl::LogSeverityAtLeast::kInfo);
std::string directory = absl::GetFlag(FLAGS_directory);
if (directory.empty()) {
std::cerr << "Usage: " << argv[0] << " --directory=<directory>"
<< std::endl;
return 1;
}
LOG(INFO) << "Starting to monitor directory: " << directory;
lczero::training::FilePathProviderConfig config;
config.mutable_output()->set_queue_capacity(16);
config.set_directory(directory);
lczero::training::FilePathProvider file_path_provider(config);
// Consumer thread to read from the queue
std::thread consumer_thread([&file_path_provider]() {
auto* queue = file_path_provider.output_queue();
try {
while (true) {
auto file = queue->Get();
const char* type_str =
(file.message_type ==
lczero::training::FilePathProvider::MessageType::kFile)
? "File"
: "Initial scan complete";
LOG(INFO) << "File " << type_str << ": " << file.filepath;
}
} catch (const lczero::QueueClosedException&) {
LOG(INFO) << "Queue closed, consumer thread exiting";
}
});
LOG(INFO) << "Monitoring for files... Press Enter to exit.";
std::cin.get();
// Close the queue and wait for consumer to finish
file_path_provider.Stop();
consumer_thread.join();
return 0;
}
================================================
FILE: csrc/loader/stages/file_path_provider_test.cc
================================================
#include "loader/stages/file_path_provider.h"
#include <gtest/gtest.h>
#include <chrono>
#include <filesystem>
#include <fstream>
#include <string>
#include <unordered_set>
#include <vector>
namespace lczero {
namespace training {
namespace {
FilePathProviderConfig MakeConfig(const std::filesystem::path& directory) {
FilePathProviderConfig config;
config.mutable_output()->set_queue_capacity(128);
config.set_directory(directory.string());
return config;
}
std::string RelativeTo(const std::filesystem::path& base,
const std::filesystem::path& target) {
return target.lexically_relative(base).generic_string();
}
} // namespace
class FilePathProviderTest : public ::testing::Test {
protected:
void SetUp() override {
test_dir_ =
std::filesystem::temp_directory_path() /
("file_path_provider_test_" +
std::to_string(
std::chrono::steady_clock::now().time_since_epoch().count()));
std::filesystem::create_directories(test_dir_);
}
void TearDown() override {
if (std::filesystem::exists(test_dir_)) {
std::filesystem::remove_all(test_dir_);
}
}
void CreateFile(const std::filesystem::path& path,
const std::string& content = "payload") {
std::filesystem::create_directories(path.parent_path());
std::ofstream file(path);
file << content;
}
void CreateDirectory(const std::filesystem::path& path) {
std::filesystem::create_directories(path);
}
std::vector<std::filesystem::path> DrainInitialScan(
Queue<FilePathProvider::File>* queue) {
std::vector<std::filesystem::path> files;
while (true) {
auto message = queue->Get();
if (message.message_type ==
FilePathProvider::MessageType::kInitialScanComplete) {
EXPECT_TRUE(message.filepath.empty());
break;
}
if (message.message_type != FilePathProvider::MessageType::kFile) {
ADD_FAILURE() << "Unexpected message type in initial scan.";
continue;
}
files.push_back(message.filepath);
}
return files;
}
FilePathProvider::File AwaitNextFile(Queue<FilePathProvider::File>* queue) {
while (true) {
auto message = queue->Get();
if (message.message_type == FilePathProvider::MessageType::kFile) {
return message;
}
if (message.message_type !=
FilePathProvider::MessageType::kInitialScanComplete) {
ADD_FAILURE()
<< "Unexpected message type while waiting for file notification.";
}
}
}
FilePathProviderConfig Config() const { return MakeConfig(test_dir_); }
std::filesystem::path test_dir_;
};
TEST_F(FilePathProviderTest, ConstructorCreatesQueue) {
FilePathProvider provider(Config());
provider.Start();
auto* queue = provider.output_queue();
ASSERT_NE(queue, nullptr);
EXPECT_EQ(queue->Capacity(), 128);
auto message = queue->Get();
EXPECT_EQ(message.message_type,
FilePathProvider::MessageType::kInitialScanComplete);
EXPECT_TRUE(message.filepath.empty());
provider.Stop();
}
TEST_F(FilePathProviderTest, InitialScanFindsVisibleFiles) {
CreateFile(test_dir_ / "file1.txt");
CreateFile(test_dir_ / "file2.txt");
CreateFile(test_dir_ / "sub" / "nested.txt");
FilePathProvider provider(Config());
provider.Start();
auto* queue = provider.output_queue();
auto discovered = DrainInitialScan(queue);
std::unordered_set<std::string> relative_paths;
for (const auto& path : discovered) {
relative_paths.insert(RelativeTo(test_dir_, path));
}
EXPECT_EQ(relative_paths.size(), 3u);
EXPECT_TRUE(relative_paths.count("file1.txt"));
EXPECT_TRUE(relative_paths.count("file2.txt"));
EXPECT_TRUE(relative_paths.count("sub/nested.txt"));
provider.Stop();
}
TEST_F(FilePathProviderTest, InitialScanSkipsHiddenEntries) {
CreateFile(test_dir_ / "visible.txt");
CreateFile(test_dir_ / ".hidden_file");
CreateFile(test_dir_ / ".hidden_dir" / "nested.txt");
CreateFile(test_dir_ / "visible_dir" / "child.txt");
FilePathProvider provider(Config());
provider.Start();
auto* queue = provider.output_queue();
auto discovered = DrainInitialScan(queue);
std::unordered_set<std::string> relative_paths;
for (const auto& path : discovered) {
relative_paths.insert(RelativeTo(test_dir_, path));
}
EXPECT_TRUE(relative_paths.count("visible.txt"));
EXPECT_TRUE(relative_paths.count("visible_dir/child.txt"));
EXPECT_FALSE(relative_paths.count(".hidden_file"));
EXPECT_FALSE(relative_paths.count(".hidden_dir/nested.txt"));
provider.Stop();
}
TEST_F(FilePathProviderTest, DetectsNewVisibleFile) {
FilePathProvider provider(Config());
provider.Start();
auto* queue = provider.output_queue();
DrainInitialScan(queue);
CreateFile(test_dir_ / "new_file.txt");
auto message = AwaitNextFile(queue);
EXPECT_EQ(RelativeTo(test_dir_, message.filepath), "new_file.txt");
provider.Stop();
}
TEST_F(FilePathProviderTest, DetectsFilesInPreExistingSubdirectory) {
auto subdir = test_dir_ / "subdir";
CreateDirectory(subdir);
FilePathProvider provider(Config());
provider.Start();
auto* queue = provider.output_queue();
DrainInitialScan(queue);
CreateFile(subdir / "from_subdir.txt");
auto message = AwaitNextFile(queue);
EXPECT_EQ(RelativeTo(test_dir_, message.filepath), "subdir/from_subdir.txt");
provider.Stop();
}
TEST_F(FilePathProviderTest, IgnoresHiddenFileEvents) {
FilePathProvider provider(Config());
provider.Start();
auto* queue = provider.output_queue();
DrainInitialScan(queue);
CreateFile(test_dir_ / ".hidden_event.txt");
CreateFile(test_dir_ / "visible_after_hidden.txt");
auto message = AwaitNextFile(queue);
EXPECT_EQ(RelativeTo(test_dir_, message.filepath),
"visible_after_hidden.txt");
provider.Stop();
}
TEST_F(FilePathProviderTest, SkipsHiddenDirectoryRecursion) {
FilePathProvider provider(Config());
provider.Start();
auto* queue = provider.output_queue();
DrainInitialScan(queue);
CreateDirectory(test_dir_ / ".hidden_dir");
CreateFile(test_dir_ / ".hidden_dir" / "inner.txt");
CreateFile(test_dir_ / "outer.txt");
auto message = AwaitNextFile(queue);
EXPECT_EQ(RelativeTo(test_dir_, message.filepath), "outer.txt");
provider.Stop();
}
TEST_F(FilePathProviderTest, HandlesEmptyDirectory) {
FilePathProvider provider(Config());
provider.Start();
auto* queue = provider.output_queue();
auto discovered = DrainInitialScan(queue);
EXPECT_TRUE(discovered.empty());
provider.Stop();
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/join_stage.cc
================================================
#include "loader/stages/join_stage.h"
#include <absl/log/log.h>
#include "loader/data_loader_metrics.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/load_metric.h"
namespace lczero {
namespace training {
template <typename T>
JoinStage<T>::JoinStage(const JoinPositionsConfig& config)
: SingleOutputStage<T>(config.output()) {}
template <typename T>
JoinStage<T>::~JoinStage() {
Stop();
}
template <typename T>
void JoinStage<T>::SetInputs(absl::Span<QueueBase* const> inputs) {
input_queues_.clear();
for (QueueBase* base_queue : inputs) {
auto* typed_queue = dynamic_cast<Queue<T>*>(base_queue);
if (!typed_queue) throw std::runtime_error("Input queue type mismatch");
input_queues_.push_back(typed_queue);
}
}
template <typename T>
void JoinStage<T>::Start() {
thread_contexts_.clear();
thread_pool_ = std::make_unique<ThreadPool>(input_queues_.size());
for (size_t i = 0; i < input_queues_.size(); ++i) {
thread_contexts_.push_back(std::make_unique<ThreadContext>());
}
for (size_t i = 0; i < input_queues_.size(); ++i) {
thread_pool_->Enqueue([this, i](std::stop_token stop_token) {
Worker(stop_token, input_queues_[i], thread_contexts_[i].get());
});
}
}
template <typename T>
void JoinStage<T>::Worker(std::stop_token stop_token, Queue<T>* input_queue,
ThreadContext* context) {
auto producer = this->output_queue()->CreateProducer();
try {
while (true) {
auto item = [&]() {
LoadMetricPauser pauser(context->load_metric_updater);
return input_queue->Get(stop_token);
}();
producer.Put(std::move(item), stop_token);
}
} catch (const QueueClosedException&) {
}
}
template <typename T>
void JoinStage<T>::Stop() {
if (!thread_pool_ || thread_pool_->stop_token().stop_requested()) return;
LOG(INFO) << "Stopping JoinStage.";
thread_pool_->Shutdown();
this->output_queue()->Close();
}
template <typename T>
StageMetricProto JoinStage<T>::FlushMetrics() {
StageMetricProto metrics;
LoadMetricProto aggregated_load;
aggregated_load.set_name("load");
for (const auto& context : thread_contexts_) {
UpdateFrom(aggregated_load, context->load_metric_updater.FlushMetrics());
}
*metrics.add_load_metrics() = std::move(aggregated_load);
*metrics.add_queue_metrics() =
MetricsFromQueue("output", *this->output_queue());
return metrics;
}
// Explicit template instantiation for FrameType.
template class JoinStage<FrameType>;
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/join_stage.h
================================================
// ABOUTME: Stage that joins multiple input queues into a single output.
// ABOUTME: Spawns one thread per input to read and forward items.
#pragma once
#include <atomic>
#include <memory>
#include <stop_token>
#include <vector>
#include "loader/data_loader_metrics.h"
#include "loader/frame_type.h"
#include "loader/stages/stage.h"
#include "proto/data_loader_config.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/queue.h"
#include "utils/thread_pool.h"
namespace lczero {
namespace training {
// Template stage that joins multiple input queues into a single output.
// Spawns one thread per input to consume and forward items.
template <typename T>
class JoinStage : public SingleOutputStage<T> {
public:
using OutputType = T;
explicit JoinStage(const JoinPositionsConfig& config);
~JoinStage();
void Start() override;
void Stop() override;
StageMetricProto FlushMetrics() override;
void SetInputs(absl::Span<QueueBase* const> inputs) override;
private:
struct ThreadContext {
LoadMetricUpdater load_metric_updater;
};
void Worker(std::stop_token stop_token, Queue<T>* input_queue,
ThreadContext* context);
std::vector<Queue<T>*> input_queues_;
// thread_contexts_ must be declared before thread_pool_ to ensure
// thread_pool_ is destroyed first (stopping threads before contexts).
std::vector<std::unique_ptr<ThreadContext>> thread_contexts_;
std::unique_ptr<ThreadPool> thread_pool_;
};
using JoinPositions = JoinStage<FrameType>;
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/join_stage_test.cc
================================================
#include "loader/stages/join_stage.h"
#include <memory>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "gtest/gtest.h"
#include "libs/lc0/src/trainingdata/trainingdata_v6.h"
#include "proto/data_loader_config.pb.h"
#include "utils/queue.h"
namespace lczero {
namespace training {
class JoinStageTest : public ::testing::Test {
protected:
void SetUp() override { config_.mutable_output()->set_queue_capacity(100); }
FrameType CreateTestFrame(uint32_t version) {
FrameType frame{};
frame.version = version;
frame.input_format = 3;
frame.root_q = 0.5f;
return frame;
}
JoinPositionsConfig config_;
};
TEST_F(JoinStageTest, JoinsTwoInputs) {
auto input_queue_1 = std::make_unique<Queue<FrameType>>(10);
auto input_queue_2 = std::make_unique<Queue<FrameType>>(10);
JoinPositions join_stage(config_);
join_stage.SetInputs({input_queue_1.get(), input_queue_2.get()});
join_stage.Start();
auto producer_1 = input_queue_1->CreateProducer();
auto producer_2 = input_queue_2->CreateProducer();
producer_1.Put(CreateTestFrame(1));
producer_1.Put(CreateTestFrame(2));
producer_2.Put(CreateTestFrame(3));
producer_2.Put(CreateTestFrame(4));
absl::flat_hash_set<uint32_t> received_versions;
for (int i = 0; i < 4; ++i) {
auto frame = join_stage.output_queue()->Get();
received_versions.insert(frame.version);
}
producer_1.Close();
producer_2.Close();
join_stage.Stop();
EXPECT_EQ(received_versions.size(), 4u);
EXPECT_TRUE(received_versions.contains(1));
EXPECT_TRUE(received_versions.contains(2));
EXPECT_TRUE(received_versions.contains(3));
EXPECT_TRUE(received_versions.contains(4));
}
TEST_F(JoinStageTest, JoinsThreeInputs) {
auto input_queue_1 = std::make_unique<Queue<FrameType>>(10);
auto input_queue_2 = std::make_unique<Queue<FrameType>>(10);
auto input_queue_3 = std::make_unique<Queue<FrameType>>(10);
JoinPositions join_stage(config_);
join_stage.SetInputs(
{input_queue_1.get(), input_queue_2.get(), input_queue_3.get()});
join_stage.Start();
auto producer_1 = input_queue_1->CreateProducer();
auto producer_2 = input_queue_2->CreateProducer();
auto producer_3 = input_queue_3->CreateProducer();
producer_1.Put(CreateTestFrame(10));
producer_2.Put(CreateTestFrame(20));
producer_3.Put(CreateTestFrame(30));
absl::flat_hash_set<uint32_t> received_versions;
for (int i = 0; i < 3; ++i) {
auto frame = join_stage.output_queue()->Get();
received_versions.insert(frame.version);
}
producer_1.Close();
producer_2.Close();
producer_3.Close();
join_stage.Stop();
EXPECT_EQ(received_versions.size(), 3u);
EXPECT_TRUE(received_versions.contains(10));
EXPECT_TRUE(received_versions.contains(20));
EXPECT_TRUE(received_versions.contains(30));
}
TEST_F(JoinStageTest, HandlesEmptyInputs) {
auto input_queue_1 = std::make_unique<Queue<FrameType>>(10);
auto input_queue_2 = std::make_unique<Queue<FrameType>>(10);
JoinPositions join_stage(config_);
join_stage.SetInputs({input_queue_1.get(), input_queue_2.get()});
join_stage.Start();
auto producer_1 = input_queue_1->CreateProducer();
auto producer_2 = input_queue_2->CreateProducer();
producer_1.Close();
producer_2.Close();
auto maybe_frame = join_stage.output_queue()->MaybeGet();
EXPECT_FALSE(maybe_frame.has_value());
join_stage.Stop();
}
TEST_F(JoinStageTest, FlushesMetrics) {
auto input_queue = std::make_unique<Queue<FrameType>>(10);
JoinPositions join_stage(config_);
join_stage.SetInputs({input_queue.get()});
join_stage.Start();
auto producer = input_queue->CreateProducer();
producer.Put(CreateTestFrame(1));
auto frame = join_stage.output_queue()->Get();
EXPECT_EQ(frame.version, 1u);
producer.Close();
join_stage.Stop();
auto metrics = join_stage.FlushMetrics();
EXPECT_EQ(metrics.load_metrics_size(), 1);
EXPECT_EQ(metrics.queue_metrics_size(), 1);
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/position_sampling.cc
================================================
#include "loader/stages/position_sampling.h"
#include <cmath>
namespace lczero {
namespace training {
float ComputePositionSamplingWeight(const FrameType& frame,
const PositionSamplingConfig& config) {
if (!config.has_diff_focus_q_weight() && !config.has_diff_focus_pol_scale()) {
return config.default_weight();
}
if (std::isnan(frame.orig_q)) return config.default_weight();
const float diff_q = std::abs(frame.best_q - frame.orig_q);
const float q_weight = config.diff_focus_q_weight();
const float pol_scale = config.diff_focus_pol_scale();
const float total =
(q_weight * diff_q + frame.policy_kld) / (q_weight + pol_scale);
return std::min(
std::pow(total * config.diff_focus_alpha() + config.diff_focus_beta(),
config.diff_focus_gamma()),
config.diff_focus_tau());
}
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/position_sampling.h
================================================
#pragma once
#include "loader/frame_type.h"
#include "proto/data_loader_config.pb.h"
namespace lczero {
namespace training {
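// Returns the sampling weight of a single position. With diff focus disabled
// (or when orig_q is NaN) this is simply config.default_weight(); otherwise
// the weight is derived from |best_q - orig_q| and policy_kld via the
// diff-focus alpha/beta/gamma parameters and capped at diff_focus_tau().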
float ComputePositionSamplingWeight(const FrameType& frame,
const PositionSamplingConfig& config);
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/shuffling_chunk_pool.cc
================================================
#include "loader/stages/shuffling_chunk_pool.h"
#include <absl/algorithm/container.h>
#include <absl/base/thread_annotations.h>
#include <absl/log/log.h>
#include <absl/random/random.h>
#include <absl/synchronization/mutex.h>
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstring>
#include <filesystem>
#include <limits>
#include <stdexcept>
#include <thread>
#include <utility>
#include "loader/chunk_source/chunk_source.h"
#include "loader/data_loader_metrics.h"
#include "loader/stages/chunk_source_loader.h"
#include "loader/stages/position_sampling.h"
#include "proto/data_loader_config.pb.h"
#include "utils/thread_pool.h"
namespace lczero {
namespace training {
thread_local absl::BitGen ShufflingChunkPool::bitgen_{absl::MakeSeedSeq()};
ShufflingChunkPool::ShufflingChunkPool(const ShufflingChunkPoolConfig& config)
: primary_output_name_(config.output().name()),
primary_output_queue_(
config.output().queue_capacity(),
ToOverflowBehavior(config.output().overflow_behavior())),
chunk_pool_size_(config.chunk_pool_size()),
config_(config),
source_ingestion_pool_(config.source_ingestion_threads(),
ThreadPoolOptions{}, stop_source_),
chunk_loading_pool_(config.chunk_loading_threads(), ThreadPoolOptions{},
stop_source_),
caching_pool_(config.has_cachehit_output() ? config.caching_threads() : 0,
ThreadPoolOptions{}, stop_source_) {
if (config.has_cachehit_output()) {
cachehit_output_name_ = config.cachehit_output().name();
cachehit_output_queue_.emplace(
config.cachehit_output().queue_capacity(),
ToOverflowBehavior(config.cachehit_output().overflow_behavior()));
if (primary_output_name_ == *cachehit_output_name_) {
throw std::runtime_error(absl::StrCat(
"ShufflingChunkPool output names must be different, got: '",
primary_output_name_, "'"));
}
}
LOG(INFO) << "Initializing ShufflingChunkPool with pool size "
<< config.chunk_pool_size();
}
ShufflingChunkPool::~ShufflingChunkPool() { Stop(); }
void ShufflingChunkPool::SetInputs(absl::Span<QueueBase* const> inputs) {
if (inputs.size() != 1 && inputs.size() != 2) {
throw std::runtime_error(absl::StrCat(
"ShufflingChunkPool expects 1 or 2 inputs, got ", inputs.size()));
}
if (inputs.size() == 2 && !cachehit_output_queue_.has_value()) {
throw std::runtime_error(
"ShufflingChunkPool received 2 inputs but cachehit_output is not "
"configured");
}
if (inputs.size() == 1 && cachehit_output_queue_.has_value()) {
throw std::runtime_error(
"ShufflingChunkPool has cachehit_output configured but received only "
"1 input");
}
primary_input_queue_ = dynamic_cast<Queue<ChunkSourceWithPhase>*>(inputs[0]);
if (!primary_input_queue_) {
throw std::runtime_error("ShufflingChunkPool primary input type mismatch");
}
if (inputs.size() == 2) {
cache_request_queue_ = dynamic_cast<Queue<CacheRequest>*>(inputs[1]);
if (!cache_request_queue_) {
throw std::runtime_error(
"ShufflingChunkPool cache request input type mismatch");
}
}
}
QueueBase* ShufflingChunkPool::GetOutput(std::string_view name) {
if (name == primary_output_name_) return &primary_output_queue_;
if (cachehit_output_name_.has_value() && name == *cachehit_output_name_) {
return &*cachehit_output_queue_;
}
std::string available = absl::StrCat("'", primary_output_name_, "'");
if (cachehit_output_name_.has_value()) {
absl::StrAppend(&available, ", '", *cachehit_output_name_, "'");
}
throw std::runtime_error(absl::StrCat("ShufflingChunkPool unknown output '",
name,
"'. Available outputs: ", available));
}
void ShufflingChunkPool::Start() {
LOG(INFO) << "Starting ShufflingChunkPool initialization thread.";
initialization_thread_ = std::jthread([this]() {
try {
LOG(INFO) << "Starting ShufflingChunkPool with pool size "
<< config_.chunk_pool_size();
std::vector<std::unique_ptr<ChunkSource>> uninitialized_sources =
InitializeChunkSources();
ProcessInputFiles(std::move(uninitialized_sources));
// Start input processing worker that continuously processes new files.
for (size_t i = 0; i < source_ingestion_pool_.num_threads(); ++i) {
auto* context =
source_ingestion_thread_contexts_
.emplace_back(std::make_unique<SourceIngestionThreadContext>())
.get();
source_ingestion_pool_.Enqueue(
[this, context](std::stop_token stop_token) {
SourceIngestionWorker(stop_token, context);
});
}
// Start output workers after everything is fully initialized.
LOG(INFO) << "ShufflingChunkPool initialization done, starting workers";
for (size_t i = 0; i < chunk_loading_pool_.num_threads(); ++i) {
auto* context =
chunk_loading_thread_contexts_
.emplace_back(std::make_unique<ChunkLoadingThreadContext>())
.get();
chunk_loading_pool_.Enqueue(
[this, context](std::stop_token stop_token) {
OutputWorker(stop_token, context);
});
}
// Start caching workers if configured.
if (cachehit_output_queue_.has_value()) {
for (size_t i = 0; i < caching_pool_.num_threads(); ++i) {
auto* context =
caching_thread_contexts_
.emplace_back(std::make_unique<CachingThreadContext>())
.get();
caching_pool_.Enqueue([this, context](std::stop_token stop_token) {
CachingWorker(stop_token, context);
});
}
}
} catch (const QueueClosedException&) {
LOG(INFO) << "ShufflingChunkPool initialization interrupted, input "
"queue closed.";
output_queue()->Close();
} catch (const std::exception& e) {
LOG(ERROR) << "ShufflingChunkPool initialization failed: " << e.what();
output_queue()->Close();
}
});
}
void ShufflingChunkPool::Stop() {
if (stop_source_.stop_requested()) return;
LOG(INFO) << "Stopping ShufflingChunkPool.";
stop_source_.request_stop();
if (initialization_thread_.joinable()) {
initialization_thread_.request_stop();
initialization_thread_.join();
}
source_ingestion_pool_.Shutdown();
chunk_loading_pool_.Shutdown();
if (cachehit_output_queue_) caching_pool_.Shutdown();
output_queue()->Close();
if (cachehit_output_queue_) cachehit_output_queue_->Close();
LOG(INFO) << "ShufflingChunkPool stopped.";
}
std::vector<std::unique_ptr<ChunkSource>>
ShufflingChunkPool::InitializeChunkSources() {
std::vector<std::unique_ptr<ChunkSource>> uninitialized_sources;
// Read from input queue until kInitialScanComplete.
while (true) {
auto chunk_source_with_phase = input_queue()->Get();
if (chunk_source_with_phase.message_type ==
FilePathProvider::MessageType::kInitialScanComplete) {
LOG(INFO)
<< "ShufflingChunkPool received initial scan completion marker.";
break;
}
if (chunk_source_with_phase.message_type ==
FilePathProvider::MessageType::kFile) {
// Add ChunkSource to uninitialized sources.
uninitialized_sources.push_back(
std::move(chunk_source_with_phase.source));
}
}
LOG(INFO) << "ShufflingChunkPool initial directory walk produced "
<< uninitialized_sources.size() << " chunk source candidate(s).";
// Sort in descending order (newest first).
std::sort(uninitialized_sources.begin(), uninitialized_sources.end(),
[](const auto& a, const auto& b) {
return a->GetChunkSortKey() > b->GetChunkSortKey();
});
std::atomic<size_t> total_chunks = 0;
size_t sources_to_keep = 0;
// Process sources sequentially until we have enough chunks.
std::string current_anchor;
{
absl::MutexLock lock(&anchor_mutex_);
current_anchor = anchor_;
}
for (auto& source : uninitialized_sources) {
if (output_queue()->IsClosed()) {
LOG(INFO) << "Output queue closed, stopping source ingestion.";
break;
}
if (total_chunks >= chunk_pool_size_) break;
// Count chunks immediately; constructors have already prepared metadata.
const size_t chunk_count = source->GetChunkCount();
total_chunks += chunk_count;
// Count chunks since anchor during initial load.
if (source->GetChunkSortKey() > current_anchor) {
chunks_since_anchor_ += chunk_count;
}
LOG_EVERY_N_SEC(INFO, 4) << "Loaded so far: " << total_chunks.load()
<< "; new: " << chunks_since_anchor_;
++sources_to_keep;
}
LOG(INFO) << "ShufflingChunkPool indexed " << total_chunks.load()
<< " chunk(s) across " << sources_to_keep
<< " source(s) during startup.";
if (total_chunks < chunk_pool_size_ && !output_queue()->IsClosed()) {
LOG(ERROR) << "ShufflingChunkPool startup chunk requirement not met: "
<< total_chunks.load() << " < " << chunk_pool_size_;
}
// Trim the vector to only keep the sources we need.
uninitialized_sources.resize(sources_to_keep);
return uninitialized_sources;
}
void ShufflingChunkPool::ProcessInputFiles(
std::vector<std::unique_ptr<ChunkSource>> uninitialized_sources) {
// Initialize chunk sources from the initial scan.
size_t initial_window_sources = 0;
size_t initial_total_chunks = 0;
{
absl::MutexLock lock(&chunk_sources_mutex_);
size_t start_chunk_index = 0;
// Newest sources first, so we add in reverse order.
std::for_each(uninitialized_sources.rbegin(), uninitialized_sources.rend(),
[this, &start_chunk_index](auto& source) {
const size_t count = source->GetChunkCount();
auto item = std::make_shared<ChunkSourceItem>();
item->start_chunk_index = start_chunk_index;
item->source = std::move(source);
item->use_counts = std::vector<uint16_t>(count, 0);
item->weight = std::vector<float>(count, -1.0f);
item->cache = std::vector<std::unique_ptr<CacheNode>>(
cachehit_output_queue_.has_value() ? count : 0);
chunk_sources_.push_back(std::move(item));
start_chunk_index +=
chunk_sources_.back()->source->GetChunkCount();
});
// Initialize stream shuffler with the initial bounds.
if (!chunk_sources_.empty()) {
size_t total_chunks = chunk_sources_.back()->start_chunk_index +
chunk_sources_.back()->source->GetChunkCount();
// Set bounds to provide the last chunk_pool_size_ chunks.
size_t lower_bound =
total_chunks > chunk_pool_size_ ? total_chunks - chunk_pool_size_ : 0;
stream_shuffler_.SetLowerBound(lower_bound);
stream_shuffler_.SetUpperBound(total_chunks);
initial_total_chunks = total_chunks;
}
initial_window_sources = chunk_sources_.size();
}
LOG(INFO) << "ShufflingChunkPool initial window ready with "
<< initial_window_sources << " source(s) totaling "
<< initial_total_chunks << " chunk(s).";
// Log anchor and sources after initial scan completion.
{
absl::MutexLock anchor_lock(&anchor_mutex_);
LOG(INFO) << "Current anchor: '" << anchor_ << "'";
absl::MutexLock sources_lock(&chunk_sources_mutex_);
std::vector<std::shared_ptr<ChunkSourceItem>> sources_after_anchor;
for (const auto& item : chunk_sources_) {
if (item->source->GetChunkSortKey() > anchor_) {
sources_after_anchor.push_back(item);
}
}
LOG(INFO) << sources_after_anchor.size()
<< " chunk source(s) after anchor, " << chunks_since_anchor_
<< " total chunks since anchor";
const size_t to_log = std::min(sources_after_anchor.size(), size_t(20));
for (size_t i = 0; i < to_log; ++i) {
LOG(INFO) << " Source [" << (i + 1) << "/" << sources_after_anchor.size()
<< "]: key='"
<< sources_after_anchor[i]->source->GetChunkSortKey()
<< "', chunks="
<< sources_after_anchor[i]->source->GetChunkCount();
}
}
if (initial_total_chunks == 0) {
throw std::runtime_error(
"ShufflingChunkPool requires at least one chunk during startup.");
}
}
void ShufflingChunkPool::SourceIngestionWorker(
std::stop_token stop_token, SourceIngestionThreadContext* context) {
try {
while (true) {
auto chunk_source_with_phase = [&]() {
LoadMetricPauser pauser(context->load_metric_updater);
return input_queue()->Get(stop_token);
}();
if (chunk_source_with_phase.message_type ==
FilePathProvider::MessageType::kFile) {
// Ingest the new chunk source.
auto source = std::move(chunk_source_with_phase.source);
size_t chunk_count = source->GetChunkCount();
absl::MutexLock lock(&chunk_sources_mutex_);
chunks_since_anchor_ += chunk_count;
AddNewChunkSource(std::move(source));
}
}
} catch (const QueueClosedException&) {
LOG(INFO) << "SourceIngestionWorker stopping, queue closed.";
} catch (const QueueRequestCancelled&) {
LOG(INFO) << "SourceIngestionWorker stopping, request cancelled.";
}
}
void ShufflingChunkPool::OutputWorker(std::stop_token stop_token,
ChunkLoadingThreadContext* context) {
// Create a local producer for this worker
auto primary_producer = output_queue()->CreateProducer();
std::optional<decltype(cachehit_output_queue_->CreateProducer())>
cachehit_producer;
if (cachehit_output_queue_.has_value()) {
cachehit_producer.emplace(cachehit_output_queue_->CreateProducer());
}
try {
while (true) {
auto result = GetNextChunkData();
if (!result) {
if (output_queue()->IsClosed()) break;
continue;
}
LoadMetricPauser pauser(context->load_metric_updater);
if (std::holds_alternative<TrainingChunk>(*result)) {
primary_producer.Put(std::move(std::get<TrainingChunk>(*result)),
stop_token);
} else {
cachehit_producer->Put(std::move(std::get<FrameType>(*result)),
stop_token);
}
}
} catch (const QueueClosedException&) {
LOG(INFO) << "OutputWorker stopping, queue closed.";
} catch (const QueueRequestCancelled&) {
LOG(INFO) << "OutputWorker stopping, request cancelled.";
} catch (const std::exception& e) {
LOG(FATAL) << "OutputWorker encountered an error: " << e.what();
}
}
void ShufflingChunkPool::CachingWorker(std::stop_token stop_token,
CachingThreadContext* context) {
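// exponential_avg_probability keeps an exponentially weighted average of the
// Hanse acceptance probability (heavily weighted toward the latest sample,
// since kTheta = 0.99); it normalizes the per-chunk caching quota computed
// below. remainder carries the fractional part of that quota between
// requests so that rounding does not lose cache capacity.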
constexpr double kTheta = 0.99;
double remainder = 0.0;
double exponential_avg_probability = 1.0;
try {
while (true) {
auto cache_request = [&]() {
LoadMetricPauser pauser(context->load_metric_updater);
return cache_request_queue_->Get(stop_token);
}();
std::shared_ptr<ChunkSourceItem> source_item;
float max_weight = 0.0f;
size_t local_index = 0;
{
absl::MutexLock lock(&chunk_sources_mutex_);
// Find the chunk source containing this global index.
auto it = absl::c_lower_bound(
chunk_sources_, cache_request.global_index,
[](const auto& item, size_t chunk_idx) {
return item->start_chunk_index + item->source->GetChunkCount() <=
chunk_idx;
});
if (it == chunk_sources_.end() ||
cache_request.global_index < (*it)->start_chunk_index) {
chunk_source_not_found_.fetch_add(1, std::memory_order_acq_rel);
continue;
}
source_item = *it;
max_weight = max_weight_;
local_index = cache_request.global_index - source_item->start_chunk_index;
}
absl::MutexLock item_lock(&source_item->mutex);
assert(local_index < source_item->use_counts.size());
// Check use_count match.
if (source_item->use_counts[local_index] != cache_request.next_use) {
mismatched_use_counts_.fetch_add(1, std::memory_order_acq_rel);
continue;
}
// Compute how many positions to cache.
const float weight = source_item->weight[local_index];
assert(weight >= 0.0f);
const double probability = ComputeHanseProbability(weight, max_weight);
exponential_avg_probability =
exponential_avg_probability * (1.0 - kTheta) + probability * kTheta;
const double n = (probability * config_.position_cache_size() /
chunk_pool_size_ / exponential_avg_probability) +
remainder;
remainder = n - std::floor(n);
const size_t positions_to_cache = static_cast<size_t>(std::floor(n));
// Traverse and extend the cache chain.
std::unique_ptr<CacheNode>* current = &source_item->cache[local_index];
for (size_t i = 0; i < positions_to_cache; ++i) {
if (*current) {
current = &(*current)->next;
continue;
}
if (i >= cache_request.items.size()) break;
auto node = std::make_unique<CacheNode>();
node->frame = cache_request.items[i];
*current = std::move(node);
current = &(*current)->next;
newly_cached_.fetch_add(1, std::memory_order_acq_rel);
cached_positions_.fetch_add(1, std::memory_order_acq_rel);
}
const size_t dropped =
cache_request.items.size() > positions_to_cache
? cache_request.items.size() - positions_to_cache
: 0;
dropped_cache_positions_.fetch_add(dropped, std::memory_order_acq_rel);
}
} catch (const QueueClosedException&) {
LOG(INFO) << "CachingWorker stopping, queue closed.";
} catch (const QueueRequestCancelled&) {
LOG(INFO) << "CachingWorker stopping, request cancelled.";
}
}
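// Per-selection bookkeeping: the chosen chunk's frames, sort key, index
// within its source and within the pool, use count, and owning
// ChunkSourceItem. GetChunkInfo() fills the indices; frame data is loaded
// lazily by LoadChunkData() and use_count is assigned in GetNextChunkData().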
struct ShufflingChunkPool::ChunkData {
std::vector<FrameType> data;
std::string sort_key;
size_t local_index = 0;
size_t global_index = 0;
uint32_t use_count = 0;
std::shared_ptr<ChunkSourceItem> source_item;
};
std::optional<std::variant<TrainingChunk, FrameType>>
ShufflingChunkPool::GetNextChunkData() {
while (true) {
ChunkData chunk_data;
const ChunkStatus status = GetChunkInfo(chunk_data);
if (status == ChunkStatus::kEnd) return std::nullopt;
if (status == ChunkStatus::kRetry) continue;
const bool hanse_enabled = config_.hanse_sampling_threshold() > 0;
if (hanse_enabled && !HanseAccept(chunk_data)) continue;
{
absl::MutexLock lock(&chunk_data.source_item->mutex);
assert(chunk_data.source_item->use_counts.size() > chunk_data.local_index);
chunk_data.use_count =
chunk_data.source_item->use_counts[chunk_data.local_index]++;
if (cachehit_output_queue_.has_value()) {
auto& cache_chain = chunk_data.source_item->cache[chunk_data.local_index];
if (cache_chain) {
cache_hits_.fetch_add(1, std::memory_order_acq_rel);
cached_positions_.fetch_sub(1, std::memory_order_acq_rel);
FrameType cached_frame = cache_chain->frame;
cache_chain = std::move(cache_chain->next);
return cached_frame;
}
cache_misses_.fetch_add(1, std::memory_order_acq_rel);
}
}
if (chunk_data.data.empty() && !LoadChunkData(chunk_data)) continue;
TrainingChunk chunk;
chunk.sort_key = std::move(chunk_data.sort_key);
chunk.index_within_sort_key = chunk_data.local_index;
chunk.use_count = chunk_data.use_count;
chunk.global_index = chunk_data.global_index;
chunk.frames = std::move(chunk_data.data);
return chunk;
}
}
bool ShufflingChunkPool::LoadChunkData(ChunkData& chunk_data) {
std::optional<std::vector<FrameType>> data =
chunk_data.source_item->source->GetChunkData(chunk_data.local_index);
if (!data || data->empty()) {
absl::MutexLock lock(&chunk_data.source_item->mutex);
chunk_data.source_item->dropped_chunks.insert(chunk_data.local_index);
dropped_chunks_metric_.fetch_add(1, std::memory_order_acq_rel);
return false;
}
chunk_data.data = std::move(*data);
return true;
}
ShufflingChunkPool::ChunkStatus ShufflingChunkPool::GetChunkInfo(
ChunkData& out_chunk_data) {
std::shared_ptr<ChunkSourceItem> source_item;
{
absl::MutexLock lock(&chunk_sources_mutex_);
std::optional<size_t> chunk_index = stream_shuffler_.GetNextItem();
if (!chunk_index && !chunk_sources_.empty()) {
size_t total_chunks = chunk_sources_.back()->start_chunk_index +
chunk_sources_.back()->source->GetChunkCount();
size_t lower_bound = total_chunks > chunk_pool_size_
? total_chunks - chunk_pool_size_
: chunk_sources_.front()->start_chunk_index;
stream_shuffler_.Reset(lower_bound, total_chunks);
reshuffles_.fetch_add(1, std::memory_order_acq_rel);
chunk_index = stream_shuffler_.GetNextItem();
}
if (!chunk_index) return ChunkStatus::kEnd;
auto it =
absl::c_lower_bound(chunk_sources_, *chunk_index,
[](const auto& item, size_t chunk_idx) {
return item->start_chunk_index +
item->source->GetChunkCount() <=
chunk_idx;
});
if (ABSL_PREDICT_FALSE(it == chunk_sources_.end() ||
*chunk_index < (*it)->start_chunk_index)) {
LOG(WARNING) << "Chunk index " << *chunk_index
<< " out of range for available chunk sources.";
return ChunkStatus::kRetry;
}
source_item = *it;
out_chunk_data.local_index = *chunk_index - source_item->start_chunk_index;
out_chunk_data.sort_key = source_item->source->GetChunkSortKey();
out_chunk_data.global_index = *chunk_index;
}
{
absl::MutexLock lock(&source_item->mutex);
if (source_item->dropped_chunks.contains(out_chunk_data.local_index)) {
return ChunkStatus::kRetry;
}
}
out_chunk_data.source_item = std::move(source_item);
return ChunkStatus::kOk;
}
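// Acceptance probability for Hanse sampling: (weight / max_weight) raised to
// hanse_sampling_gamma(), or 1.0 while no positive maximum weight has been
// observed yet.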
double ShufflingChunkPool::ComputeHanseProbability(float weight,
float max_weight) const {
if (max_weight <= 0.0f) return 1.0;
return std::pow(weight / max_weight, config_.hanse_sampling_gamma());
}
float ShufflingChunkPool::ComputeChunkWeight(
absl::Span<const FrameType> frames) const {
return absl::c_accumulate(frames, 0.0f, [this](float sum, const auto& frame) {
return sum +
ComputePositionSamplingWeight(frame, config_.position_sampling());
});
}
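// Decides whether a shuffled chunk is kept under Hanse sampling: computes
// (or reuses the cached) chunk weight, updates the pool-wide maximum, and
// accepts with probability ComputeHanseProbability(weight, max_weight).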
bool ShufflingChunkPool::HanseAccept(ChunkData& chunk_data) {
assert(chunk_data.source_item);
float weight = -1.0f;
{
absl::MutexLock lock(&chunk_data.source_item->mutex);
assert(chunk_data.source_item->weight.size() > chunk_data.local_index);
weight = chunk_data.source_item->weight[chunk_data.local_index];
}
if (weight < 0.0f) {
if (chunk_data.data.empty() && !LoadChunkData(chunk_data)) return false;
weight = ComputeChunkWeight(chunk_data.data);
{
absl::MutexLock pool_lock(&chunk_sources_mutex_);
max_weight_ = std::max(max_weight_, weight);
AddSample(chunk_weight_stats_, static_cast<double>(weight));
}
{
absl::MutexLock lock(&chunk_data.source_item->mutex);
const float cached_weight =
chunk_data.source_item->weight[chunk_data.local_index];
if (cached_weight < 0.0f) {
chunk_data.source_item->weight[chunk_data.local_index] = weight;
} else {
weight = cached_weight;
}
}
hanse_cache_misses_.fetch_add(1, std::memory_order_acq_rel);
} else {
hanse_cache_hits_.fetch_add(1, std::memory_order_acq_rel);
}
float max_weight = 0.0f;
{
absl::MutexLock lock(&chunk_sources_mutex_);
max_weight = max_weight_;
}
const double p = ComputeHanseProbability(weight, max_weight);
const double u = absl::Uniform<double>(bitgen_, 0.0, 1.0);
if (u >= p) {
hanse_rejected_.fetch_add(1, std::memory_order_acq_rel);
return false;
}
return true;
}
void ShufflingChunkPool::AddNewChunkSource(std::unique_ptr<ChunkSource> source)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(chunk_sources_mutex_) {
// Add new chunk source to the end of the deque.
size_t old_upper_bound = 0;
if (!chunk_sources_.empty()) {
const auto& last_source = chunk_sources_.back();
old_upper_bound =
last_source->start_chunk_index + last_source->source->GetChunkCount();
}
size_t count = source->GetChunkCount();
auto item = std::make_shared<ChunkSourceItem>();
item->start_chunk_index = old_upper_bound;
item->source = std::move(source);
item->use_counts = std::vector<uint16_t>(count, 0);
item->weight = std::vector<float>(count, -1.0f);
item->cache =
std::vector<std::unique_ptr<CacheNode>>(cachehit_output_queue_.has_value()
? count
: 0);
chunk_sources_.push_back(std::move(item));
// Calculate current window bounds.
size_t new_upper_bound = chunk_sources_.back()->start_chunk_index +
chunk_sources_.back()->source->GetChunkCount();
// Remove old chunks if window exceeds chunk_pool_size_.
while (!chunk_sources_.empty() && chunk_sources_.size() > 1) {
size_t window_start = chunk_sources_.front()->start_chunk_index +
chunk_sources_.front()->source->GetChunkCount();
size_t window_size = new_upper_bound - window_start;
if (window_size < chunk_pool_size_) break;
// Count cached positions in the evicted source.
if (cachehit_output_queue_.has_value()) {
size_t evicted_cached = 0;
absl::MutexLock item_lock(&chunk_sources_.front()->mutex);
for (const auto& cache_chain : chunk_sources_.front()->cache) {
const CacheNode* node = cache_chain.get();
while (node) {
++evicted_cached;
node = node->next.get();
}
}
cached_positions_.fetch_sub(evicted_cached, std::memory_order_acq_rel);
}
// Remove the oldest chunk source (front of deque).
chunk_sources_.pop_front();
}
// Update stream shuffler bounds with the sliding window.
size_t window_start = chunk_sources_.front()->start_chunk_index;
size_t new_lower_bound = new_upper_bound > chunk_pool_size_
? new_upper_bound - chunk_pool_size_
: window_start;
stream_shuffler_.SetUpperBound(new_upper_bound);
stream_shuffler_.SetLowerBound(new_lower_bound);
}
StageMetricProto ShufflingChunkPool::FlushMetrics() {
StageMetricProto stage_metric;
// Aggregate source ingestion load metrics from all ingestion threads.
LoadMetricProto ingestion_load;
ingestion_load.set_name("source_ingestion");
for (const auto& context : source_ingestion_thread_contexts_) {
UpdateFrom(ingestion_load, context->load_metric_updater.FlushMetrics());
}
*stage_metric.add_load_metrics() = std::move(ingestion_load);
// Aggregate chunk loading load metrics from all chunk loading threads.
LoadMetricProto chunk_loading_load;
chunk_loading_load.set_name("chunk_loading");
for (const auto& context : chunk_loading_thread_contexts_) {
UpdateFrom(chunk_loading_load, context->load_metric_updater.FlushMetrics());
}
*stage_metric.add_load_metrics() = std::move(chunk_loading_load);
// Get chunk sources statistics and pool state.
{
absl::MutexLock lock(&chunk_sources_mutex_);
auto* chunk_sources_metric = stage_metric.add_gauge_metrics();
chunk_sources_metric->set_name("chunk_sources");
chunk_sources_metric->set_value(
static_cast<uint64_t>(chunk_sources_.size()));
size_t upper = 0;
size_t current = 0;
if (!chunk_sources_.empty()) {
const auto& first = chunk_sources_.front();
const auto& last = chunk_sources_.back();
upper = last->start_chunk_index + last->source->GetChunkCount();
current = upper - first->start_chunk_index;
}
auto* current_chunks_metric = stage_metric.add_gauge_metrics();
current_chunks_metric->set_name("chunks_current");
current_chunks_metric->set_value(static_cast<uint64_t>(current));
current_chunks_metric->set_capacity(
static_cast<uint64_t>(chunk_pool_size_));
auto* total_chunks_metric = stage_metric.add_gauge_metrics();
total_chunks_metric->set_name("chunks_total");
total_chunks_metric->set_value(static_cast<uint64_t>(upper));
}
// Get anchor-related metrics.
{
absl::MutexLock lock(&anchor_mutex_);
auto* chunks_since_anchor_metric = stage_metric.add_gauge_metrics();
chunks_since_anchor_metric->set_name("chunks_since_anchor");
chunks_since_anchor_metric->set_value(chunks_since_anchor_);
stage_metric.set_anchor(anchor_);
}
auto* dropped_metric = stage_metric.add_count_metrics();
dropped_metric->set_name("dropped");
dropped_metric->set_count(
dropped_chunks_metric_.exchange(0, std::memory_order_acq_rel));
// Hanse sampling and shuffler metrics.
{
auto* hits = stage_metric.add_count_metrics();
hits->set_name("hanse_cache_hits");
hits->set_count(hanse_cache_hits_.exchange(0, std::memory_order_acq_rel));
auto* misses = stage_metric.add_count_metrics();
misses->set_name("hanse_cache_misses");
misses->set_count(
hanse_cache_misses_.exchange(0, std::memory_order_acq_rel));
auto* rejected = stage_metric.add_count_metrics();
rejected->set_name("hanse_rejected");
rejected->set_count(hanse_rejected_.exchange(0, std::memory_order_acq_rel));
auto* resh = stage_metric.add_count_metrics();
resh->set_name("reshuffles");
resh->set_count(reshuffles_.exchange(0, std::memory_order_acq_rel));
}
// Position cache metrics.
if (cachehit_output_queue_.has_value()) {
LoadMetricProto caching_load;
caching_load.set_name("caching");
for (const auto& context : caching_thread_contexts_) {
UpdateFrom(caching_load, context->load_metric_updater.FlushMetrics());
}
*stage_metric.add_load_metrics() = std::move(caching_load);
auto* cache_hits = stage_metric.add_count_metrics();
cache_hits->set_name("cache_hits");
cache_hits->set_count(cache_hits_.exchange(0, std::memory_order_acq_rel));
auto* cache_misses = stage_metric.add_count_metrics();
cache_misses->set_name("cache_misses");
cache_misses->set_count(
cache_misses_.exchange(0, std::memory_order_acq_rel));
auto* mismatched = stage_metric.add_count_metrics();
mismatched->set_name("mismatched_use_counts");
mismatched->set_count(
mismatched_use_counts_.exchange(0, std::memory_order_acq_rel));
auto* newly_cached = stage_metric.add_count_metrics();
newly_cached->set_name("newly_cached");
newly_cached->set_count(
newly_cached_.exchange(0, std::memory_order_acq_rel));
auto* dropped = stage_metric.add_count_metrics();
dropped->set_name("dropped_cache_positions");
dropped->set_count(
dropped_cache_positions_.exchange(0, std::memory_order_acq_rel));
auto* not_found = stage_metric.add_count_metrics();
not_found->set_name("chunk_source_not_found");
not_found->set_count(
chunk_source_not_found_.exchange(0, std::memory_order_acq_rel));
auto* cached = stage_metric.add_gauge_metrics();
cached->set_name("cached_positions");
cached->set_value(cached_positions_.load(std::memory_order_acquire));
cached->set_capacity(config_.position_cache_size());
}
{
absl::MutexLock lock(&chunk_sources_mutex_);
if (chunk_weight_stats_.count() > 0) {
chunk_weight_stats_.set_name("chunk_weight");
UpdateFrom(*stage_metric.add_statistics_metrics(), chunk_weight_stats_);
}
chunk_weight_stats_.Clear();
}
*stage_metric.add_queue_metrics() =
MetricsFromQueue(primary_output_name_, *output_queue());
if (cachehit_output_queue_.has_value()) {
*stage_metric.add_queue_metrics() =
MetricsFromQueue(*cachehit_output_name_, *cachehit_output_queue_);
}
return stage_metric;
}
std::pair<std::string, int> ShufflingChunkPool::ResetAnchor() {
absl::MutexLock anchor_lock(&anchor_mutex_);
absl::MutexLock sources_lock(&chunk_sources_mutex_);
if (chunk_sources_.empty()) {
int previous_count = chunks_since_anchor_.exchange(0);
return {anchor_, previous_count};
}
anchor_ = chunk_sources_.back()->source->GetChunkSortKey();
int previous_count = chunks_since_anchor_.exchange(0);
return {anchor_, previous_count};
}
int ShufflingChunkPool::ChunksSinceAnchor() { return chunks_since_anchor_; }
std::string ShufflingChunkPool::CurrentAnchor() {
absl::MutexLock lock(&anchor_mutex_);
return anchor_;
}
void ShufflingChunkPool::SetAnchor(std::string_view anchor) {
absl::MutexLock lock(&anchor_mutex_);
anchor_ = anchor;
}
std::optional<StageControlResponse> ShufflingChunkPool::Control(
const StageControlRequest& request) {
if (!request.has_chunk_pool_request()) {
return std::nullopt;
}
const auto& chunk_request = request.chunk_pool_request();
StageControlResponse response;
auto* chunk_response = response.mutable_chunk_pool_response();
if (chunk_request.reset_chunk_anchor()) {
auto [anchor, chunks] = ResetAnchor();
chunk_response->set_chunk_anchor(anchor);
chunk_response->set_chunks_since_anchor(chunks);
return response;
}
if (chunk_request.has_set_chunk_anchor()) {
SetAnchor(chunk_request.set_chunk_anchor());
chunk_response->set_chunk_anchor(chunk_request.set_chunk_anchor());
chunk_response->set_chunks_since_anchor(ChunksSinceAnchor());
return response;
}
chunk_response->set_chunk_anchor(CurrentAnchor());
chunk_response->set_chunks_since_anchor(ChunksSinceAnchor());
return response;
}
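// Illustrative sketch (not part of the original file): how a caller could
// drive the anchor protocol through Control(). The setters below follow the
// standard protobuf C++ codegen for the fields read above; the helper name is
// made up for illustration only.
namespace {
[[maybe_unused]] void ExampleResetAnchorViaControl(ShufflingChunkPool& pool) {
  StageControlRequest request;
  // Ask the pool to move the anchor to the newest chunk source and to report
  // how many chunks were ingested since the previous anchor.
  request.mutable_chunk_pool_request()->set_reset_chunk_anchor(true);
  std::optional<StageControlResponse> response = pool.Control(request);
  if (response.has_value()) {
    const auto& chunk_response = response->chunk_pool_response();
    // chunk_response.chunk_anchor() is the new anchor (sort key of the newest
    // source); chunk_response.chunks_since_anchor() is the pre-reset count.
    (void)chunk_response;
  }
}
}  // namespace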
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/shuffling_chunk_pool.h
================================================
#pragma once
#include <atomic>
#include <cstdint>
#include <deque>
#include <filesystem>
#include <memory>
#include <optional>
#include <stop_token>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <variant>
#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/random.h"
#include "absl/synchronization/mutex.h"
#include "loader/chunk_source/chunk_source.h"
#include "loader/data_loader_metrics.h"
#include "loader/stages/chunk_source_loader.h"
#include "loader/stages/stage.h"
#include "loader/stages/training_chunk.h"
#include "proto/data_loader_config.pb.h"
#include "proto/stage_control.pb.h"
#include "proto/training_metrics.pb.h"
#include "utils/metrics/load_metric.h"
#include "utils/queue.h"
#include "utils/stream_shuffler.h"
#include "utils/thread_pool.h"
namespace lczero {
namespace training {
class ShufflingChunkPool : public Stage {
public:
explicit ShufflingChunkPool(const ShufflingChunkPoolConfig& config);
~ShufflingChunkPool();
void Start() override;
void Stop() override;
void SetInputs(absl::Span<QueueBase* const> inputs) override;
QueueBase* GetOutput(std::string_view name) override;
StageMetricProto FlushMetrics() override;
std::optional<StageControlResponse> Control(
const StageControlRequest& request) override;
// Anchor management methods for tracking chunks since a specific point.
std::pair<std::string, int> ResetAnchor();
int ChunksSinceAnchor();
std::string CurrentAnchor();
void SetAnchor(std::string_view anchor);
Queue<ChunkSourceWithPhase>* input_queue() { return primary_input_queue_; }
Queue<TrainingChunk>* output_queue() { return &primary_output_queue_; }
private:
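  // Node of a singly linked list of cached positions (frames) for one chunk;
  // each entry of ChunkSourceItem::cache is the head of such a chain, which
  // is walked e.g. when counting cached positions evicted with a source.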
struct CacheNode {
FrameType frame;
std::unique_ptr<CacheNode> next;
};
struct ChunkSourceItem {
mutable absl::Mutex mutex;
size_t start_chunk_index;
std::unique_ptr<ChunkSource> source;
absl::flat_hash_set<size_t> dropped_chunks ABSL_GUARDED_BY(mutex);
// Per-chunk counters and cached weights.
std::vector<uint16_t> use_counts ABSL_GUARDED_BY(mutex);
std::vector<float> weight ABSL_GUARDED_BY(mutex);
std::vector<std::unique_ptr<CacheNode>> cache ABSL_GUARDED_BY(mutex);
};
struct SourceIngestionThreadContext {
LoadMetricUpdater load_metric_updater;
};
struct ChunkLoadingThreadContext {
LoadMetricUpdater load_metric_updater;
};
struct CachingThreadContext {
LoadMetricUpdater load_metric_updater;
};
std::vector<std::unique_ptr<ChunkSource>> InitializeChunkSources();
void ProcessInputFiles(
std::vector<std::unique_ptr<ChunkSource>> uninitialized_sources);
void SourceIngestionWorker(std::stop_token stop_token,
SourceIngestionThreadContext* context);
void OutputWorker(std::stop_token stop_token,
ChunkLoadingThreadContext* context);
void CachingWorker(std::stop_token stop_token, CachingThreadContext* context);
void AddNewChunkSource(std::unique_ptr<ChunkSource> source)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(chunk_sources_mutex_);
std::optional<std::variant<TrainingChunk, FrameType>> GetNextChunkData();
enum class ChunkStatus { kOk, kRetry, kEnd };
struct ChunkData;
ChunkStatus GetChunkInfo(ChunkData& out_chunk_data);
bool LoadChunkData(ChunkData& chunk_data);
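  // Hanse sampling: decides whether to keep a loaded chunk based on its
  // weight relative to the current maximum weight (see ComputeChunkWeight and
  // ComputeHanseProbability, and docs/shuffling_pool_hanse_sampling.md).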
bool HanseAccept(ChunkData& chunk_data);
float ComputeChunkWeight(absl::Span<const FrameType> frames) const;
double ComputeHanseProbability(float weight, float max_weight) const;
Queue<ChunkSourceWithPhase>* primary_input_queue_ = nullptr;
Queue<CacheRequest>* cache_request_queue_ = nullptr;
std::string primary_output_name_;
Queue<TrainingChunk> primary_output_queue_;
std::optional<std::string> cachehit_output_name_;
std::optional<Queue<FrameType>> cachehit_output_queue_;
const size_t chunk_pool_size_;
const ShufflingChunkPoolConfig config_;
// stop_source_ must be declared before ThreadPools that reference it.
std::stop_source stop_source_;
ThreadPool source_ingestion_pool_;
ThreadPool chunk_loading_pool_;
ThreadPool caching_pool_;
std::atomic<int64_t> dropped_chunks_metric_{0};
absl::Mutex chunk_sources_mutex_;
std::deque<std::shared_ptr<ChunkSourceItem>> chunk_sources_
ABSL_GUARDED_BY(chunk_sources_mutex_);
StreamShuffler stream_shuffler_ ABSL_GUARDED_BY(chunk_sources_mutex_);
float max_weight_ ABSL_GUARDED_BY(chunk_sources_mutex_) = 0.0f;
std::jthread initialization_thread_;
std::vector<std::unique_ptr<SourceIngestionThreadContext>>
source_ingestion_thread_contexts_;
std::vector<std::unique_ptr<ChunkLoadingThreadContext>>
chunk_loading_thread_contexts_;
std::vector<std::unique_ptr<CachingThreadContext>> caching_thread_contexts_;
// Anchor-related members for tracking chunks since a specific point.
absl::Mutex anchor_mutex_;
std::string anchor_ ABSL_GUARDED_BY(anchor_mutex_);
std::atomic<int> chunks_since_anchor_{0};
// Thread-local RNG for Hanse sampling.
static thread_local absl::BitGen bitgen_;
// Metrics counters.
std::atomic<uint64_t> hanse_cache_hits_{0};
std::atomic<uint64_t> hanse_cache_misses_{0};
std::atomic<uint64_t> hanse_rejected_{0};
std::atomic<uint64_t> reshuffles_{0};
std::atomic<uint64_t> cache_hits_{0};
std::atomic<uint64_t> cache_misses_{0};
std::atomic<uint64_t> mismatched_use_counts_{0};
std::atomic<uint64_t> newly_cached_{0};
std::atomic<uint64_t> dropped_cache_positions_{0};
std::atomic<uint64_t> chunk_source_not_found_{0};
std::atomic<uint64_t> cached_positions_{0};
StatisticsProtoDouble chunk_weight_stats_
ABSL_GUARDED_BY(chunk_sources_mutex_);
};
} // namespace training
} // namespace lczero
================================================
FILE: csrc/loader/stages/shuffling_chunk_pool_test.cc
================================================
// ABOUTME: Comprehensive unit tests for the ShufflingChunkPool class
// ABOUTME: Tests chunk source management, output workers, and dynamic windowing
#include "loader/stages/shuffling_chunk_pool.h"
#include <absl/cleanup/cleanup.h>
#include <absl/log/log.h>
#include <gtest/gtest.h>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <memory>
#include <set>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "loader/stages/training_chunk.h"
namespace lczero {
namespace training {
namespace {
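// Minimal Stage implementation that exposes an externally owned queue as its
// only output and performs no work of its own; useful as a stand-in upstream
// stage in tests.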
template <typename T>
class PassthroughStage : public Stage {
public:
explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
void Start() override {}
void Stop() override {}
StageMetricProto FlushMetrics() override { return StageMetricProto(); }
QueueBase* GetOutput(std::string_view name = "") override {
(void)name;
return queue_;
}
void SetInputs(absl::Span<QueueBase* const> inputs) override {
if (!inputs.empty()) {
throw std::runtime_error("PassthroughStage expects no inputs");
}
}
private:
Queue<T>* queue_;
};
} // namespace
// Mock ChunkSource for testing
class MockChunkSource : public ChunkSource {
public:
MockChunkSource(const std::string& sort_key, size_t chunk_count)
: sort_key_(sort_key), chunk_count_(chunk_count) {}
std::string GetChunkSortKey() const override { return sort_key_; }
size_t GetChunkCount() const override { return chunk_count_; }
std::optional<std::vector<FrameType>> GetChunkData(size_t index) override {
if (index >= chunk_count_) {
throw std::out_of_range("Chunk index out of range");
}
FrameType frame{};
frame.version = static_cast<uint32_t>(index);
frame.input_format = 3;
return std::vector<FrameType>{frame};
}
private:
std::string sort_key_;
size_t chunk_count_;
};
class InvalidChunkSource : public ChunkSource {
public:
explicit InvalidChunkSource(std::string sort_key)
: sort_key_(std::move(sort_key)) {}
std::string GetChunkSortKey() const override { return sort_key_; }
size_t GetChunkCount() const override { return 2; }
std::optional<std::vector<FrameType>> GetChunkData(size_t index) override {
if (index >= 2) {
throw std::out_of_range("Chunk index out of range");
}
if (index == 0) {
return std::nullopt;
}
FrameType frame{};
frame.version = 42;
return std::vector<FrameType>{frame};
}
private:
std::string sort_key_;
};
class ShufflingChunkPoolTest : public ::testing::Test {
protected:
void SetUp() override {
input_queue_ = std::make_unique<Queue<ChunkSourceWithPhase>>(100);
input_producer_ = std::make_unique<Queue<ChunkSourceWithPhase>::Producer>(
input_queue_->CreateProducer());
}
void TearDown() override {
// Close the producer to close the queue
if (input_producer_) input_producer_.reset();
}
// Helper to add a mock chunk source to the input queue
void AddMockChunkSourceToQueue(const std::string& sort_key,
size_t chunk_count,
FilePathProvider::MessageType message_type =
FilePathProvider::MessageType::kFile) {
ChunkSourceWithPhase item;
item.source = std::make_unique<MockChunkSource>(sort_key, chunk_count);
item.message_type = message_type;
input_producer_->Put(std::move(item));
}
void MarkInitialScanComplete() {
ChunkSourceWithPhase item;
item.source = nullptr; // No source for completion marker
item.message_type = FilePathProvider::MessageType::kInitialScanComplete;
input_producer_->Put(std::move(item));
}
void CloseInputQueue() {
if (input_producer_) input_producer_.reset();
}
ShufflingChunkPoolConfig MakeConfig(int chunk_pool_size,
int source_ingestion_threads = 1,
int loading_threads = 1,
int queue_capacity = 100) const {
ShufflingChunkPoolConfig config;
config.set_chunk_pool_size(chunk_pool_size);
config.set_source_ingestion_threads(source_ingestion_threads);
config.set_chunk_loading_threads(loading_threads);
config.mutable_output()->set_queue_capacity(queue_capacity);
return config;
}
std::unique_ptr<Queue<ChunkSourceWithPhase>> input_queue_;
std::unique_ptr<Queue<ChunkSourceWithPhase>::Producer> input_producer_;
};
TEST_F(ShufflingChunkPoolTest, ConstructorCreatesOutputQueue) {
// Add some mock chunk sources with enough chunks
AddMockChunkSourceToQueue("source1", 50);
AddMockChunkSourceToQueue("source2", 60);
MarkInitialScanComplete();
auto config = MakeConfig(20);
ShufflingChunkPool shuffling_chunk_pool(config);
shuffling_chunk_pool.SetInputs({input_queue_.get()});
auto* output_queue = shuffling_chunk_pool.output_queue();
// Close input queue to stop input worker from waiting
CloseInputQueue();
EXPECT_NE(output_queue, nullptr);
EXPECT_EQ(output_queue->Capacity(), 100);
// Drain output queue to prevent workers from blocking
try {
while (output_queue->Size() > 0) {
output_queue->Get();
}
} catch (const QueueClosedException&) {
// Queue closed, that's fine
}
}
TEST_F(ShufflingChunkPoolTest, HandlesEmptyInputQueue) {
// Only mark scan complete, no chunk sources
MarkInitialScanComplete();
auto config = MakeConfig(20);
// Constructor should now succeed (initialization is asynchronous)
ShufflingChunkPool shuffling_chunk_pool(config);
shuffling_chunk_pool.SetInputs({input_queue_.get()});
shuffling_chunk_pool.Start();
// The initialization thread should handle the error case
auto* output_queue = shuffling_chunk_pool.output_queue();
// Give the initialization thread time to complete and discover the error
std::this_thread::sleep_for(std::chrono::milliseconds(100));
// Close input queue to clean up
CloseInputQueue();
// Output queue should exist but be closed to signal startup failure when no
// chunks were found.
EXPECT_NE(output_queue, nullptr);
EXPECT_TRUE(output_queue->IsClosed());
EXPECT_EQ(output_queue->Size(), 0u);
}
TEST_F(ShufflingChunkPoolTest, FlushMetricsHandlesEmptyChunkSources) {
const int chunk_pool_size = 32;
auto config = MakeConfig(chunk_pool_size);
ShufflingChunkPool shuffling_chunk_pool(config);
shuffling_chunk_pool.SetInputs({input_queue_.get()});
auto metrics = shuffling_chunk_pool.FlushMetrics();
bool found_current = false;
bool found_total = false;
  for (const auto& metric : metrics.gauge_metrics()) {
    if (metric.name() == "chunks_current") {
      found_current = true;
      EXPECT_EQ(metric.value(), 0u);
    }
    if (metric.name() == "chunks_total") {
      found_total = true;
      EXPECT_EQ(metric.value(), 0u);
    }
  }
  EXPECT_TRUE(found_current);
  EXPECT_TRUE(found_total);
}
SYMBOL INDEX (1253 symbols across 163 files)
FILE: csrc/loader/chunk_source/chunk_source.h
function namespace (line 10) | namespace lczero {
FILE: csrc/loader/chunk_source/chunk_source_view.h
function namespace (line 12) | namespace lczero {
FILE: csrc/loader/chunk_source/debug_chunk_source.cc
type lczero (line 13) | namespace lczero {
type training (line 14) | namespace training {
FILE: csrc/loader/chunk_source/debug_chunk_source.h
function namespace (line 11) | namespace lczero {
FILE: csrc/loader/chunk_source/rawfile_chunk_source.cc
type lczero (line 12) | namespace lczero {
type training (line 13) | namespace training {
FILE: csrc/loader/chunk_source/rawfile_chunk_source.h
function namespace (line 9) | namespace lczero {
FILE: csrc/loader/chunk_source/tar_chunk_source.cc
type lczero (line 19) | namespace lczero {
type training (line 20) | namespace training {
type TarHeader (line 22) | struct TarHeader {
function ParseOctal (line 43) | uint64_t ParseOctal(const std::array<uint8_t, 12>& octal) {
function ReadExact (line 52) | bool ReadExact(int fd, off_t offset, void* buffer, size_t size) {
function ReadGzipPrefix (line 64) | std::optional<std::string> ReadGzipPrefix(int fd, off_t offset, size...
FILE: csrc/loader/chunk_source/tar_chunk_source.h
function namespace (line 12) | namespace lczero {
FILE: csrc/loader/data_loader.cc
type lczero (line 15) | namespace lczero {
type training (line 16) | namespace training {
function DataLoaderConfig (line 17) | DataLoaderConfig DataLoader::ParseConfig(const std::string& serializ...
function TensorTuple (line 102) | TensorTuple DataLoader::GetNext(std::string_view alias) {
FILE: csrc/loader/data_loader.h
function namespace (line 18) | namespace lczero {
FILE: csrc/loader/data_loader_metrics.cc
type lczero (line 11) | namespace lczero {
type training (line 12) | namespace training {
function ProtoT (line 16) | ProtoT* FindByName(std::vector<ProtoT>* entries, absl::string_view n...
function UpdateFrom (line 25) | void UpdateFrom(QueueMetricProto& dest, const QueueMetricProto& src) {
function UpdateFrom (line 34) | void UpdateFrom(CountMetricProto& dest, const CountMetricProto& src) {
function UpdateFrom (line 39) | void UpdateFrom(GaugeMetricProto& dest, const GaugeMetricProto& src) {
function UpdateFrom (line 45) | void UpdateFrom(StageMetricProto& dest, const StageMetricProto& src) {
function UpdateFrom (line 108) | void UpdateFrom(DataLoaderMetricsProto& dest,
FILE: csrc/loader/data_loader_metrics.h
function namespace (line 13) | namespace lczero {
FILE: csrc/loader/data_loader_test.cc
type lczero (line 7) | namespace lczero {
type training (line 8) | namespace training {
function TEST (line 10) | TEST(DataLoaderTest, AllowsNoOutputsConfigured) {
function TEST (line 19) | TEST(DataLoaderTest, ThrowsOnDuplicateStageName) {
FILE: csrc/loader/frame_type.h
function namespace (line 32) | namespace lczero {
FILE: csrc/loader/loader_main.cpp
type lczero (line 25) | namespace lczero {
type training (line 26) | namespace training {
function Run (line 28) | void Run() {
function main (line 115) | int main(int argc, char* argv[]) {
FILE: csrc/loader/pybind_module.cc
type lczero (line 27) | namespace lczero {
type training (line 28) | namespace training {
function SerializePyProto (line 32) | std::string SerializePyProto(const py::handle& obj, const char* expe...
function ProtoT (line 45) | ProtoT ParsePyProto(const py::handle& obj, const char* expected_type) {
function MakePythonProto (line 51) | py::object MakePythonProto(const char* module_name, const char* mess...
function tensor_to_numpy (line 62) | py::array tensor_to_numpy(std::unique_ptr<TensorBase> tensor) {
function tensor_tuple_to_numpy_tuple (line 75) | py::tuple tensor_tuple_to_numpy_tuple(TensorTuple tensor_tuple) {
function PYBIND11_MODULE (line 83) | PYBIND11_MODULE(_lczero_training, m) {
FILE: csrc/loader/stages/chunk_rescorer.cc
type lczero (line 11) | namespace lczero {
type training (line 12) | namespace training {
function V6ToV7 (line 15) | void V6ToV7(std::span<FrameType> data, float theta = 5.0f / 6.0f) {
function StageMetricProto (line 158) | StageMetricProto ChunkRescorer::FlushMetrics() {
FILE: csrc/loader/stages/chunk_rescorer.h
function namespace (line 23) | namespace lczero {
FILE: csrc/loader/stages/chunk_rescorer_test.cc
type lczero (line 12) | namespace lczero {
type training (line 13) | namespace training {
class PassthroughStage (line 18) | class PassthroughStage : public Stage {
method PassthroughStage (line 20) | explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
method Start (line 22) | void Start() override {}
method Stop (line 23) | void Stop() override {}
method StageMetricProto (line 24) | StageMetricProto FlushMetrics() override { return StageMetricProto...
method QueueBase (line 25) | QueueBase* GetOutput(std::string_view name = "") override {
method SetInputs (line 29) | void SetInputs(absl::Span<QueueBase* const> inputs) override {
class ChunkRescorerTest (line 41) | class ChunkRescorerTest : public ::testing::Test {
method SetUp (line 43) | void SetUp() override {
method TrainingChunk (line 54) | TrainingChunk MakeChunk(std::vector<FrameType> frames,
function TEST_F (line 69) | TEST_F(ChunkRescorerTest, HandlesInputQueueClosure) {
FILE: csrc/loader/stages/chunk_source_loader.cc
type lczero (line 12) | namespace lczero {
type training (line 13) | namespace training {
function CreateChunkSourceFromFile (line 15) | std::unique_ptr<ChunkSource> CreateChunkSourceFromFile(
function StageMetricProto (line 163) | StageMetricProto ChunkSourceLoader::FlushMetrics() {
FILE: csrc/loader/stages/chunk_source_loader.h
function namespace (line 19) | namespace lczero {
FILE: csrc/loader/stages/chunk_source_loader_test.cc
type lczero (line 10) | namespace lczero {
type training (line 11) | namespace training {
class PassthroughStage (line 16) | class PassthroughStage : public Stage {
method PassthroughStage (line 18) | explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
method Start (line 20) | void Start() override {}
method Stop (line 21) | void Stop() override {}
method StageMetricProto (line 22) | StageMetricProto FlushMetrics() override { return StageMetricProto...
method QueueBase (line 23) | QueueBase* GetOutput(std::string_view name = "") override {
method SetInputs (line 27) | void SetInputs(absl::Span<QueueBase* const> inputs) override {
function TEST (line 39) | TEST(ChunkSourceLoaderTest, ProcessesFiles) {
function TEST (line 73) | TEST(ChunkSourceLoaderTest, HandlesPhases) {
function TEST (line 105) | TEST(ChunkSourceLoaderTest, PassesThroughInitialScanComplete) {
function TEST (line 136) | TEST(ChunkSourceLoaderTest, SentinelBarrierWithMultipleThreads) {
FILE: csrc/loader/stages/chunk_source_splitter.cc
type lczero (line 11) | namespace lczero {
type training (line 12) | namespace training {
function QueueBase (line 83) | QueueBase* ChunkSourceSplitter::GetOutput(std::string_view name) {
function StageMetricProto (line 93) | StageMetricProto ChunkSourceSplitter::FlushMetrics() {
FILE: csrc/loader/stages/chunk_source_splitter.h
function namespace (line 22) | namespace lczero {
FILE: csrc/loader/stages/chunk_source_splitter_test.cc
type lczero (line 16) | namespace lczero {
type training (line 17) | namespace training {
class FixedCountChunkSource (line 21) | class FixedCountChunkSource : public ChunkSource {
method FixedCountChunkSource (line 23) | FixedCountChunkSource(std::string sort_key, size_t count)
method GetChunkSortKey (line 27) | std::string GetChunkSortKey() const override { return key_; }
method GetChunkCount (line 28) | size_t GetChunkCount() const override { return count_; }
method GetChunkData (line 29) | std::optional<std::vector<FrameType>> GetChunkData(size_t) override {
class PassthroughStage (line 38) | class PassthroughStage : public Stage {
method PassthroughStage (line 40) | explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
method Start (line 42) | void Start() override {}
method Stop (line 43) | void Stop() override {}
method StageMetricProto (line 44) | StageMetricProto FlushMetrics() override { return StageMetricProto...
method QueueBase (line 45) | QueueBase* GetOutput(std::string_view) override { return queue_; }
method SetInputs (line 46) | void SetInputs(absl::Span<QueueBase* const> inputs) override {
function TEST (line 58) | TEST(ChunkSourceSplitterTest, SplitsByHashAndWeight) {
function TEST (line 122) | TEST(ChunkSourceSplitterTest, BroadcastsInitialScanComplete) {
FILE: csrc/loader/stages/chunk_unpacker.cc
type lczero (line 26) | namespace lczero {
type training (line 27) | namespace training {
function PickSampledPositions (line 35) | std::vector<uint32_t> PickSampledPositions(int32_t n, double p,
function SampleProbabilisticSequence (line 70) | std::vector<uint32_t> SampleProbabilisticSequence(
function GenerateRunSeed (line 98) | uint32_t GenerateRunSeed() {
function QueueBase (line 173) | QueueBase* ChunkUnpacker::GetOutput(std::string_view name) {
function FramesToProbabilities (line 188) | std::vector<float> FramesToProbabilities(std::span<const FrameType> ...
function StageMetricProto (line 271) | StageMetricProto ChunkUnpacker::FlushMetrics() {
FILE: csrc/loader/stages/chunk_unpacker.h
function namespace (line 22) | namespace lczero {
FILE: csrc/loader/stages/chunk_unpacker_test.cc
type lczero (line 18) | namespace lczero {
type training (line 19) | namespace training {
class PassthroughStage (line 24) | class PassthroughStage : public Stage {
method PassthroughStage (line 26) | explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
method Start (line 28) | void Start() override {}
method Stop (line 29) | void Stop() override {}
method StageMetricProto (line 30) | StageMetricProto FlushMetrics() override { return StageMetricProto...
method QueueBase (line 31) | QueueBase* GetOutput(std::string_view name = "") override {
method SetInputs (line 35) | void SetInputs(absl::Span<QueueBase* const> inputs) override {
class ChunkUnpackerTest (line 47) | class ChunkUnpackerTest : public ::testing::Test {
method SetUp (line 49) | void SetUp() override {
method FrameType (line 56) | FrameType CreateTestFrame(uint32_t version) {
method TrainingChunk (line 64) | TrainingChunk MakeChunk(std::vector<FrameType> frames,
function TEST_F (line 79) | TEST_F(ChunkUnpackerTest, UnpacksSingleFrame) {
function TEST_F (line 95) | TEST_F(ChunkUnpackerTest, UnpacksMultipleFrames) {
function TEST_F (line 126) | TEST_F(ChunkUnpackerTest, UnpacksMultipleChunks) {
function TEST_F (line 160) | TEST_F(ChunkUnpackerTest, HandlesEmptyChunk) {
function TEST_F (line 176) | TEST_F(ChunkUnpackerTest, HandlesQueueClosure) {
function TEST (line 188) | TEST(PickSampledPositionsTest, Deterministic) {
function TEST (line 198) | TEST(PickSampledPositionsTest, FullBucketFirstRound) {
function TEST (line 207) | TEST(PickSampledPositionsTest, DisjointBuckets) {
function TEST (line 227) | TEST(PickSampledPositionsTest, PartialBucketElementsAreReturned) {
function TEST (line 256) | TEST(PickSampledPositionsTest, PartialBucketCompletedSize) {
FILE: csrc/loader/stages/file_path_provider.cc
type lczero (line 24) | namespace lczero {
type training (line 25) | namespace training {
function ShouldSkipName (line 29) | bool ShouldSkipName(std::string_view name) {
function ShouldSkipPathEntry (line 33) | bool ShouldSkipPathEntry(const FilePathProvider::Path& path) {
function StageMetricProto (line 85) | StageMetricProto FilePathProvider::FlushMetrics() {
type inotify_event (line 200) | struct inotify_event
type inotify_event (line 201) | struct inotify_event
type inotify_event (line 219) | struct inotify_event
type epoll_event (line 272) | struct epoll_event
type epoll_event (line 288) | struct epoll_event
type inotify_event (line 319) | struct inotify_event
type inotify_event (line 320) | struct inotify_event
type inotify_event (line 324) | struct inotify_event
type inotify_event (line 331) | struct inotify_event
FILE: csrc/loader/stages/file_path_provider.h
function FilePathProviderMessageType (line 29) | enum class FilePathProviderMessageType {
FILE: csrc/loader/stages/file_path_provider_main.cc
function main (line 15) | int main(int argc, char* argv[]) {
FILE: csrc/loader/stages/file_path_provider_test.cc
type lczero (line 12) | namespace lczero {
type training (line 13) | namespace training {
function FilePathProviderConfig (line 17) | FilePathProviderConfig MakeConfig(const std::filesystem::path& direc...
function RelativeTo (line 24) | std::string RelativeTo(const std::filesystem::path& base,
class FilePathProviderTest (line 31) | class FilePathProviderTest : public ::testing::Test {
method SetUp (line 33) | void SetUp() override {
method TearDown (line 42) | void TearDown() override {
method CreateFile (line 48) | void CreateFile(const std::filesystem::path& path,
method CreateDirectory (line 55) | void CreateDirectory(const std::filesystem::path& path) {
method DrainInitialScan (line 59) | std::vector<std::filesystem::path> DrainInitialScan(
method AwaitNextFile (line 78) | FilePathProvider::File AwaitNextFile(Queue<FilePathProvider::File>...
method FilePathProviderConfig (line 92) | FilePathProviderConfig Config() const { return MakeConfig(test_dir...
function TEST_F (line 97) | TEST_F(FilePathProviderTest, ConstructorCreatesQueue) {
function TEST_F (line 113) | TEST_F(FilePathProviderTest, InitialScanFindsVisibleFiles) {
function TEST_F (line 136) | TEST_F(FilePathProviderTest, InitialScanSkipsHiddenEntries) {
function TEST_F (line 160) | TEST_F(FilePathProviderTest, DetectsNewVisibleFile) {
function TEST_F (line 174) | TEST_F(FilePathProviderTest, DetectsFilesInPreExistingSubdirectory) {
function TEST_F (line 191) | TEST_F(FilePathProviderTest, IgnoresHiddenFileEvents) {
function TEST_F (line 207) | TEST_F(FilePathProviderTest, SkipsHiddenDirectoryRecursion) {
function TEST_F (line 223) | TEST_F(FilePathProviderTest, HandlesEmptyDirectory) {
FILE: csrc/loader/stages/join_stage.cc
type lczero (line 10) | namespace lczero {
type training (line 11) | namespace training {
function StageMetricProto (line 71) | StageMetricProto JoinStage<T>::FlushMetrics() {
class JoinStage<FrameType> (line 86) | class JoinStage<FrameType>
FILE: csrc/loader/stages/join_stage.h
function namespace (line 18) | namespace lczero {
FILE: csrc/loader/stages/join_stage_test.cc
type lczero (line 12) | namespace lczero {
type training (line 13) | namespace training {
class JoinStageTest (line 15) | class JoinStageTest : public ::testing::Test {
method SetUp (line 17) | void SetUp() override { config_.mutable_output()->set_queue_capaci...
method FrameType (line 19) | FrameType CreateTestFrame(uint32_t version) {
function TEST_F (line 30) | TEST_F(JoinStageTest, JoinsTwoInputs) {
function TEST_F (line 63) | TEST_F(JoinStageTest, JoinsThreeInputs) {
function TEST_F (line 98) | TEST_F(JoinStageTest, HandlesEmptyInputs) {
function TEST_F (line 118) | TEST_F(JoinStageTest, FlushesMetrics) {
FILE: csrc/loader/stages/position_sampling.cc
type lczero (line 5) | namespace lczero {
type training (line 6) | namespace training {
function ComputePositionSamplingWeight (line 8) | float ComputePositionSamplingWeight(const FrameType& frame,
FILE: csrc/loader/stages/position_sampling.h
function namespace (line 6) | namespace lczero {
FILE: csrc/loader/stages/shuffling_chunk_pool.cc
type lczero (line 27) | namespace lczero {
type training (line 28) | namespace training {
function QueueBase (line 90) | QueueBase* ShufflingChunkPool::GetOutput(std::string_view name) {
type ShufflingChunkPool::ChunkData (line 478) | struct ShufflingChunkPool::ChunkData {
function ABSL_EXCLUSIVE_LOCKS_REQUIRED (line 659) | ABSL_EXCLUSIVE_LOCKS_REQUIRED(chunk_sources_mutex_) {
function StageMetricProto (line 719) | StageMetricProto ShufflingChunkPool::FlushMetrics() {
FILE: csrc/loader/stages/shuffling_chunk_pool.h
function namespace (line 31) | namespace lczero {
FILE: csrc/loader/stages/shuffling_chunk_pool_test.cc
type lczero (line 22) | namespace lczero {
type training (line 23) | namespace training {
class PassthroughStage (line 28) | class PassthroughStage : public Stage {
method PassthroughStage (line 30) | explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
method Start (line 32) | void Start() override {}
method Stop (line 33) | void Stop() override {}
method StageMetricProto (line 34) | StageMetricProto FlushMetrics() override { return StageMetricProto...
method QueueBase (line 35) | QueueBase* GetOutput(std::string_view name = "") override {
method SetInputs (line 39) | void SetInputs(absl::Span<QueueBase* const> inputs) override {
class MockChunkSource (line 52) | class MockChunkSource : public ChunkSource {
method MockChunkSource (line 54) | MockChunkSource(const std::string& sort_key, size_t chunk_count)
method GetChunkSortKey (line 57) | std::string GetChunkSortKey() const override { return sort_key_; }
method GetChunkCount (line 58) | size_t GetChunkCount() const override { return chunk_count_; }
method GetChunkData (line 60) | std::optional<std::vector<FrameType>> GetChunkData(size_t index) o...
class InvalidChunkSource (line 75) | class InvalidChunkSource : public ChunkSource {
method InvalidChunkSource (line 77) | explicit InvalidChunkSource(std::string sort_key)
method GetChunkSortKey (line 80) | std::string GetChunkSortKey() const override { return sort_key_; }
method GetChunkCount (line 81) | size_t GetChunkCount() const override { return 2; }
method GetChunkData (line 83) | std::optional<std::vector<FrameType>> GetChunkData(size_t index) o...
class ShufflingChunkPoolTest (line 99) | class ShufflingChunkPoolTest : public ::testing::Test {
method SetUp (line 101) | void SetUp() override {
method TearDown (line 107) | void TearDown() override {
method AddMockChunkSourceToQueue (line 113) | void AddMockChunkSourceToQueue(const std::string& sort_key,
method MarkInitialScanComplete (line 123) | void MarkInitialScanComplete() {
method CloseInputQueue (line 130) | void CloseInputQueue() {
method ShufflingChunkPoolConfig (line 134) | ShufflingChunkPoolConfig MakeConfig(int chunk_pool_size,
function TEST_F (line 150) | TEST_F(ShufflingChunkPoolTest, ConstructorCreatesOutputQueue) {
function TEST_F (line 180) | TEST_F(ShufflingChunkPoolTest, HandlesEmptyInputQueue) {
function TEST_F (line 208) | TEST_F(ShufflingChunkPoolTest, FlushMetricsHandlesEmptyChunkSources) {
function TEST_F (line 236) | TEST_F(ShufflingChunkPoolTest, FlushMetricsReportsWindowAndTotalCoun...
function TEST_F (line 282) | TEST_F(ShufflingChunkPoolTest, ProcessesInitialScanChunkSources) {
function TEST_F (line 306) | TEST_F(ShufflingChunkPoolTest, OutputWorkerProducesChunks) {
function TEST_F (line 342) | TEST_F(ShufflingChunkPoolTest, DropsInvalidChunks) {
function TEST_F (line 390) | TEST_F(ShufflingChunkPoolTest, NewChunkSourceProcessing) {
function TEST_F (line 421) | TEST_F(ShufflingChunkPoolTest, ChunkWindowManagement) {
function TEST_F (line 446) | TEST_F(ShufflingChunkPoolTest, ChunkSorting) {
function TEST_F (line 470) | TEST_F(ShufflingChunkPoolTest, StreamShufflerResetWhenExhausted) {
function TEST_F (line 530) | TEST_F(ShufflingChunkPoolTest, HanseMetrics_NoRejection_CacheAndResh...
function TEST_F (line 576) | TEST_F(ShufflingChunkPoolTest, ExplicitClose) {
function TEST_F (line 610) | TEST_F(ShufflingChunkPoolTest, CloseStopsOutputWorkers) {
function TEST_F (line 646) | TEST_F(ShufflingChunkPoolTest, CloseIsIdempotent) {
function TEST_F (line 666) | TEST_F(ShufflingChunkPoolTest, DestructorCallsClose) {
function TEST_F (line 696) | TEST_F(ShufflingChunkPoolTest, InputQueueClosureDoesNotCloseOutputQu...
function TEST_F (line 726) | TEST_F(ShufflingChunkPoolTest, BasicAnchorFunctionality) {
function TEST_F (line 753) | TEST_F(ShufflingChunkPoolTest, ResetAnchor) {
function TEST_F (line 776) | TEST_F(ShufflingChunkPoolTest, AnchorCounterIncrement) {
function TEST_F (line 810) | TEST_F(ShufflingChunkPoolTest, AnchorCounterResetDuringInitialLoad) {
FILE: csrc/loader/stages/shuffling_frame_sampler.cc
type lczero (line 12) | namespace lczero {
type training (line 13) | namespace training {
function StageMetricProto (line 94) | StageMetricProto ShufflingFrameSampler::FlushMetrics() {
FILE: csrc/loader/stages/shuffling_frame_sampler.h
function namespace (line 21) | namespace lczero {
FILE: csrc/loader/stages/shuffling_frame_sampler_test.cc
type lczero (line 10) | namespace lczero {
type training (line 11) | namespace training {
class PassthroughStage (line 16) | class PassthroughStage : public Stage {
method PassthroughStage (line 18) | explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
method Start (line 20) | void Start() override {}
method Stop (line 21) | void Stop() override {}
method StageMetricProto (line 22) | StageMetricProto FlushMetrics() override { return StageMetricProto...
method QueueBase (line 23) | QueueBase* GetOutput(std::string_view name = "") override {
method SetInputs (line 27) | void SetInputs(absl::Span<QueueBase* const> inputs) override {
class ShufflingFrameSamplerTest (line 39) | class ShufflingFrameSamplerTest : public ::testing::Test {
method SetUp (line 41) | void SetUp() override {
method FrameType (line 47) | FrameType CreateTestFrame(uint32_t version) {
function TEST_F (line 59) | TEST_F(ShufflingFrameSamplerTest, OutputsNoFramesWithSmallInput) {
function TEST_F (line 88) | TEST_F(ShufflingFrameSamplerTest, OutputsFramesWithLargeInput) {
function TEST_F (line 124) | TEST_F(ShufflingFrameSamplerTest, HandlesEmptyInput) {
function TEST_F (line 136) | TEST_F(ShufflingFrameSamplerTest, HandlesExactReservoirSize) {
function TEST_F (line 166) | TEST_F(ShufflingFrameSamplerTest, PreservesFrameData) {
FILE: csrc/loader/stages/simple_chunk_extractor.cc
type lczero (line 10) | namespace lczero {
type training (line 11) | namespace training {
function StageMetricProto (line 89) | StageMetricProto SimpleChunkExtractor::FlushMetrics() {
FILE: csrc/loader/stages/simple_chunk_extractor.h
function namespace (line 20) | namespace lczero {
FILE: csrc/loader/stages/simple_chunk_extractor_test.cc
type lczero (line 17) | namespace lczero {
type training (line 18) | namespace training {
class MockChunkSource (line 22) | class MockChunkSource : public ChunkSource {
method MockChunkSource (line 24) | MockChunkSource(std::string sort_key, size_t chunk_count)
method GetChunkSortKey (line 32) | std::string GetChunkSortKey() const override { return sort_key_; }
method GetChunkCount (line 33) | size_t GetChunkCount() const override { return chunk_count_; }
method GetChunkData (line 35) | std::optional<std::vector<FrameType>> GetChunkData(size_t index) o...
class SimpleChunkExtractorTest (line 46) | class SimpleChunkExtractorTest : public ::testing::Test {
method SetUp (line 48) | void SetUp() override {
method TearDown (line 67) | void TearDown() override {
class DummyStage (line 75) | class DummyStage : public Stage {
method DummyStage (line 77) | explicit DummyStage(QueueBase* queue) : queue_(queue) {}
method Start (line 78) | void Start() override {}
method Stop (line 79) | void Stop() override {}
method StageMetricProto (line 80) | StageMetricProto FlushMetrics() override { return {}; }
method QueueBase (line 81) | QueueBase* GetOutput(std::string_view name = "") override {
function TEST_F (line 95) | TEST_F(SimpleChunkExtractorTest, ProcessesSingleSource) {
function TEST_F (line 137) | TEST_F(SimpleChunkExtractorTest, ProcessesMultipleSources) {
function TEST_F (line 179) | TEST_F(SimpleChunkExtractorTest, SkipsNonFileMessages) {
function TEST_F (line 210) | TEST_F(SimpleChunkExtractorTest, MetricsAreRecorded) {
FILE: csrc/loader/stages/stage.cc
type lczero (line 5) | namespace lczero {
type training (line 6) | namespace training {
function QueueBase (line 20) | QueueBase* StageRegistry::GetStageOutput(std::string_view stage_name...
FILE: csrc/loader/stages/stage.h
function namespace (line 17) | namespace lczero {
FILE: csrc/loader/stages/stage_factory.cc
type lczero (line 17) | namespace lczero {
type training (line 18) | namespace training {
function CountStageConfigs (line 22) | int CountStageConfigs(const StageConfig& config) {
function CreateStage (line 37) | std::unique_ptr<Stage> CreateStage(const StageConfig& config) {
FILE: csrc/loader/stages/stage_factory.h
function namespace (line 8) | namespace lczero {
FILE: csrc/loader/stages/stage_factory_test.cc
type lczero (line 7) | namespace lczero {
type training (line 8) | namespace training {
function TEST (line 10) | TEST(StageFactoryTest, CreatesFilePathProviderStage) {
function TEST (line 20) | TEST(StageFactoryTest, ThrowsWhenNoStageConfigSet) {
function TEST (line 26) | TEST(StageFactoryTest, ThrowsWhenMultipleStageConfigsSet) {
FILE: csrc/loader/stages/tensor_generator.cc
type lczero (line 18) | namespace lczero {
type training (line 19) | namespace training {
function TensorTuple (line 85) | TensorTuple TensorGenerator::ConvertFramesToTensors(
function StageMetricProto (line 202) | StageMetricProto TensorGenerator::FlushMetrics() {
FILE: csrc/loader/stages/tensor_generator.h
function namespace (line 21) | namespace lczero {
FILE: csrc/loader/stages/tensor_generator_test.cc
type lczero (line 15) | namespace lczero {
type training (line 16) | namespace training {
class PassthroughStage (line 21) | class PassthroughStage : public Stage {
method PassthroughStage (line 23) | explicit PassthroughStage(Queue<T>* queue) : queue_(queue) {}
method Start (line 25) | void Start() override {}
method Stop (line 26) | void Stop() override {}
method StageMetricProto (line 27) | StageMetricProto FlushMetrics() override { return StageMetricProto...
method QueueBase (line 28) | QueueBase* GetOutput(std::string_view name = "") override {
method SetInputs (line 32) | void SetInputs(absl::Span<QueueBase* const> inputs) override {
class TensorGeneratorTest (line 44) | class TensorGeneratorTest : public ::testing::Test {
method SetUp (line 46) | void SetUp() override {
method FrameType (line 53) | FrameType CreateTestFrame() {
method VerifyTensorTuple (line 91) | void VerifyTensorTuple(const TensorTuple& tensors,
method VerifyTensorData (line 126) | void VerifyTensorData(const TensorTuple& tensors,
function TEST_F (line 200) | TEST_F(TensorGeneratorTest, GeneratesCorrectTensorShapes) {
function TEST_F (line 217) | TEST_F(TensorGeneratorTest, GeneratesCorrectTensorData) {
function TEST_F (line 235) | TEST_F(TensorGeneratorTest, HandlesMultipleBatches) {
function TEST_F (line 270) | TEST_F(TensorGeneratorTest, HandlesDifferentBatchSizes) {
function TEST_F (line 288) | TEST_F(TensorGeneratorTest, HandlesEmptyInput) {
function TEST_F (line 300) | TEST_F(TensorGeneratorTest, VerifiesPlanesConversion) {
function TEST_F (line 339) | TEST_F(TensorGeneratorTest, VerifiesQDConversion) {
FILE: csrc/loader/stages/training_chunk.h
function namespace (line 10) | namespace lczero {
FILE: csrc/tools/dump_chunk_main.cc
type lczero (line 28) | namespace lczero {
type training (line 29) | namespace training {
function DumpChunk (line 36) | void DumpChunk(const std::string& path, int64_t max_entries,
function main (line 80) | int main(int argc, char** argv) {
FILE: csrc/tools/filter_chunks_main.cc
function CollectTarFiles (line 48) | std::vector<fs::path> CollectTarFiles(const fs::path& directory) {
function ParsePlaneValues (line 62) | std::vector<uint64_t> ParsePlaneValues(absl::string_view value_list) {
function PlanesMatch (line 89) | bool PlanesMatch(const FrameType& entry, absl::Span<const uint64_t> expe...
function FindMatchingFrameIndex (line 95) | std::optional<size_t> FindMatchingFrameIndex(
function WriteChunk (line 103) | void WriteChunk(const fs::path& output_dir, absl::string_view base_name,
function ProcessTar (line 138) | void ProcessTar(const fs::path& tar_path, const fs::path& output_dir,
function main (line 170) | int main(int argc, char** argv) {
FILE: csrc/tools/position_weight_stats_main.cc
type WeightedPosition (line 43) | struct WeightedPosition {
function CollectTarFiles (line 48) | std::vector<fs::path> CollectTarFiles(const fs::path& directory) {
function CollectWeights (line 62) | std::vector<float> CollectWeights(const fs::path& tar_path,
function PrintHistogram (line 99) | void PrintHistogram(const std::vector<float>& sorted_weights) {
function PrintPercentiles (line 139) | void PrintPercentiles(const std::vector<float>& sorted_weights) {
function PrintStatistics (line 149) | void PrintStatistics(const std::vector<float>& weights) {
function main (line 165) | int main(int argc, char** argv) {
FILE: csrc/tools/rescore_chunk_main.cc
function ReadChunkFrames (line 44) | std::vector<lczero::V6TrainingData> ReadChunkFrames(const fs::path& path) {
function WriteChunkFrames (line 54) | void WriteChunkFrames(const fs::path& path,
function BuildOutputPath (line 63) | fs::path BuildOutputPath(const fs::path& input_path) {
function main (line 73) | int main(int argc, char** argv) {
FILE: csrc/tools/result_distribution_main.cc
type ChunkResult (line 41) | enum class ChunkResult { kWin, kDraw, kLoss }
type ResultCounts (line 43) | struct ResultCounts {
class CsvWriter (line 49) | class CsvWriter {
method CsvWriter (line 51) | CsvWriter(std::ostream* output, absl::Mutex* mutex)
method Write (line 54) | void Write(absl::string_view basename, const ResultCounts& counts) con...
function DetermineChunkResult (line 79) | std::optional<ChunkResult> DetermineChunkResult(absl::string_view chunk_...
function ResultCounts (line 113) | ResultCounts CountResultsInTar(const fs::path& tar_path) {
function main (line 150) | int main(int argc, char** argv) {
FILE: csrc/tools/startpos_policy_distribution_main.cc
function MatchesStartPosition (line 56) | bool MatchesStartPosition(const FrameType& data) {
function CollectTarFiles (line 62) | std::vector<fs::path> CollectTarFiles(const fs::path& directory) {
function WriteHeader (line 76) | void WriteHeader(std::ostream& output) {
function WriteRow (line 82) | void WriteRow(std::ostream& output, absl::string_view sort_key, size_t i...
function ProcessTarFile (line 91) | void ProcessTarFile(const fs::path& tar_path, std::ostream& output) {
function main (line 119) | int main(int argc, char** argv) {
FILE: csrc/utils/gz.cc
type lczero (line 9) | namespace lczero {
type training (line 10) | namespace training {
function GunzipBuffer (line 12) | std::string GunzipBuffer(std::string_view buffer) {
FILE: csrc/utils/gz.h
function namespace (line 7) | namespace lczero {
FILE: csrc/utils/metrics/exponential_aggregator.h
type class (line 18) | enum class
function GetBucketIndex (line 147) | static size_t GetBucketIndex(TimePeriod period) {
function one_tick (line 318) | auto one_tick = [&](Metric& carry) {
FILE: csrc/utils/metrics/group.h
function namespace (line 7) | namespace lczero {
FILE: csrc/utils/metrics/load_metric.h
function namespace (line 11) | namespace lczero {
function UpdateFrom (line 81) | inline void UpdateFrom(LoadMetricProto& dest, const LoadMetricProto& src) {
function class (line 88) | class LoadMetricPauser {
function DoNotResume (line 101) | void DoNotResume() { should_resume_ = false; }
FILE: csrc/utils/metrics/load_metric_test.cc
type lczero (line 11) | namespace lczero {
type training (line 12) | namespace training {
class LoadMetricTest (line 14) | class LoadMetricTest : public ::testing::Test {
method SetUp (line 18) | void SetUp() override { start_time_ = Clock::now(); }
function TEST_F (line 23) | TEST_F(LoadMetricTest, BasicLoadMetricProto) {
function TEST_F (line 44) | TEST_F(LoadMetricTest, LoadMetricUpdaterBasic) {
function TEST_F (line 75) | TEST_F(LoadMetricTest, LoadMetricUpdaterFlush) {
function TEST_F (line 97) | TEST_F(LoadMetricTest, LoadMetricProtoMerging) {
function TEST_F (line 126) | TEST_F(LoadMetricTest, LoadMetricProtoMoveSemantics) {
function TEST_F (line 166) | TEST_F(LoadMetricTest, LoadUtilizationTracking) {
class LoadMetricProtoIntegrationTest (line 214) | class LoadMetricProtoIntegrationTest : public ::testing::Test {
method SetUp (line 220) | void SetUp() override {
function TEST_F (line 233) | TEST_F(LoadMetricProtoIntegrationTest, RecordMetricsWithUpdater) {
function TEST_F (line 253) | TEST_F(LoadMetricProtoIntegrationTest, MultipleRecordMetrics) {
function TEST_F (line 276) | TEST_F(LoadMetricProtoIntegrationTest, AdvanceTest) {
function TEST_F (line 300) | TEST_F(LoadMetricTest, LoadStartStopReturnValues) {
function main (line 322) | int main(int argc, char** argv) {
FILE: csrc/utils/metrics/printer.h
function namespace (line 8) | namespace lczero {
FILE: csrc/utils/metrics/statistics_metric.h
function namespace (line 7) | namespace lczero {
FILE: csrc/utils/metrics/stats_test.cc
type lczero (line 12) | namespace lczero {
class CounterMetric (line 14) | class CounterMetric {
method CounterMetric (line 16) | CounterMetric() : count_(0) {}
method CounterMetric (line 17) | CounterMetric(int count) : count_(count) {}
method Reset (line 19) | void Reset() { count_ = 0; }
method MergeFrom (line 21) | void MergeFrom(const CounterMetric& other) { count_ += other.count_; }
method Print (line 23) | void Print(MetricPrinter& printer) const {
method count (line 29) | int count() const { return count_; }
method set_count (line 30) | void set_count(int count) { count_ = count; }
class AverageMetric (line 36) | class AverageMetric {
method AverageMetric (line 38) | AverageMetric() : sum_(0), count_(0) {}
method AverageMetric (line 39) | AverageMetric(double sum, int count) : sum_(sum), count_(count) {}
method Reset (line 41) | void Reset() {
method MergeFrom (line 46) | void MergeFrom(const AverageMetric& other) {
method Print (line 51) | void Print(MetricPrinter& printer) const {
method average (line 61) | double average() const { return count_ > 0 ? sum_ / count_ : 0.0; }
method add_sample (line 62) | void add_sample(double value) {
method sum (line 67) | double sum() const { return sum_; }
method count (line 68) | int count() const { return count_; }
class MaxMetric (line 75) | class MaxMetric {
method MaxMetric (line 77) | MaxMetric() : max_value_(0), has_value_(false) {}
method MaxMetric (line 78) | MaxMetric(double max_value) : max_value_(max_value), has_value_(true...
method Reset (line 80) | void Reset() {
method MergeFrom (line 85) | void MergeFrom(const MaxMetric& other) {
method Print (line 94) | void Print(MetricPrinter& printer) const {
method max_value (line 105) | double max_value() const { return max_value_; }
method has_value (line 106) | bool has_value() const { return has_value_; }
method set_value (line 107) | void set_value(double value) {
class OptionalValueMetric (line 120) | class OptionalValueMetric {
method OptionalValueMetric (line 122) | OptionalValueMetric() : value_(std::nullopt) {}
method OptionalValueMetric (line 123) | OptionalValueMetric(int value) : value_(value) {}
method Reset (line 125) | void Reset() { value_ = std::nullopt; }
method MergeFrom (line 127) | void MergeFrom(const OptionalValueMetric& other) {
method Print (line 134) | void Print(MetricPrinter& printer) const {
method value (line 145) | std::optional<int> value() const { return value_; }
method has_value (line 146) | bool has_value() const { return value_.has_value(); }
method set_value (line 147) | void set_value(int value) { value_ = value; }
class MetricGroupTest (line 154) | class MetricGroupTest : public ::testing::Test {
function TEST_F (line 160) | TEST_F(MetricGroupTest, InitialState) {
function TEST_F (line 167) | TEST_F(MetricGroupTest, GetMutable) {
function TEST_F (line 183) | TEST_F(MetricGroupTest, Reset) {
function TEST_F (line 197) | TEST_F(MetricGroupTest, MergeFromGroup) {
function TEST_F (line 218) | TEST_F(MetricGroupTest, MergeFromSingleMetric) {
function TEST_F (line 232) | TEST_F(MetricGroupTest, Print) {
class MetricPrinterTest (line 251) | class MetricPrinterTest : public ::testing::Test {}
function TEST_F (line 253) | TEST_F(MetricPrinterTest, StringMetricPrinter) {
function TEST_F (line 265) | TEST_F(MetricPrinterTest, MultipleGroups) {
function TEST_F (line 280) | TEST_F(MetricPrinterTest, EmptyGroup) {
function TEST_F (line 290) | TEST_F(MetricPrinterTest, SizeTOverload) {
function TEST_F (line 302) | TEST_F(MetricPrinterTest, MetricToStringFunction) {
class ExponentialAggregatorTest (line 310) | class ExponentialAggregatorTest : public ::testing::Test {
method SetUp (line 317) | void SetUp() override {
function TEST_F (line 328) | TEST_F(ExponentialAggregatorTest, RecordMetrics) {
function TEST_F (line 347) | TEST_F(ExponentialAggregatorTest, MultipleUpdatesLiveMetrics) {
function TEST_F (line 364) | TEST_F(ExponentialAggregatorTest, Advance) {
function TEST_F (line 383) | TEST_F(ExponentialAggregatorTest, MultipleAdvances) {
function TEST_F (line 407) | TEST_F(ExponentialAggregatorTest, MultipleAdvancesThreeTicks) {
function TEST_F (line 427) | TEST_F(ExponentialAggregatorTest, AggregationTest) {
function TEST_F (line 514) | TEST_F(ExponentialAggregatorTest, ActualVsRequestedTimeCoverage) {
function TEST_F (line 559) | TEST_F(ExponentialAggregatorTest, ExactDurationTest) {
function main (line 591) | int main(int argc, char** argv) {
FILE: csrc/utils/queue.h
function namespace (line 14) | namespace lczero {
FILE: csrc/utils/queue_test.cc
type lczero (line 14) | namespace lczero {
class QueueTest (line 16) | class QueueTest : public ::testing::Test {
method SetUp (line 18) | void SetUp() override {}
function TEST_F (line 23) | TEST_F(QueueTest, ConstructorCreatesEmptyQueue) {
function TEST_F (line 29) | TEST_F(QueueTest, SinglePutGet) {
function TEST_F (line 42) | TEST_F(QueueTest, MovePutGet) {
function TEST_F (line 56) | TEST_F(QueueTest, MultiplePutGet) {
function TEST_F (line 74) | TEST_F(QueueTest, CircularBufferBehavior) {
function TEST_F (line 95) | TEST_F(QueueTest, BatchPutConstSpan) {
function TEST_F (line 110) | TEST_F(QueueTest, BatchPutMoveSpan) {
function TEST_F (line 129) | TEST_F(QueueTest, BatchPutEmptySpan) {
function TEST_F (line 140) | TEST_F(QueueTest, BatchGet) {
function TEST_F (line 158) | TEST_F(QueueTest, BatchGetZeroCount) {
function TEST_F (line 172) | TEST_F(QueueTest, CapacityOne) {
function TEST_F (line 187) | TEST_F(QueueTest, CreateProducerOnClosedQueue) {
function TEST_F (line 199) | TEST_F(QueueTest, GetOnClosedQueue) {
function TEST_F (line 209) | TEST_F(QueueTest, BatchGetOnClosedQueue) {
function TEST_F (line 221) | TEST_F(QueueTest, SingleProducerSingleConsumer) {
function TEST_F (line 256) | TEST_F(QueueTest, MultipleProducersMultipleConsumers) {
function TEST_F (line 301) | TEST_F(QueueTest, BlockingBehaviorOnFullQueue) {
function TEST_F (line 330) | TEST_F(QueueTest, BlockingBehaviorOnEmptyQueue) {
function TEST_F (line 356) | TEST_F(QueueTest, ProducerDestructionUnblocksWaitingGet) {
function TEST_F (line 389) | TEST_F(QueueTest, GetFromClosedQueueWithElements) {
function TEST_F (line 411) | TEST_F(QueueTest, BatchGetFromClosedQueueWithElements) {
function TEST_F (line 439) | TEST_F(QueueTest, ProducerTokenMechanism) {
function TEST_F (line 472) | TEST_F(QueueTest, ProducerMoveSemantics) {
function TEST_F (line 502) | TEST_F(QueueTest, PutOnClosedQueueThrowsException) {
function TEST_F (line 519) | TEST_F(QueueTest, PutOnClosedQueueAfterProducerDestruction) {
function TEST_F (line 533) | TEST_F(QueueTest, BatchPutOnClosedQueueThrowsException) {
function TEST_F (line 549) | TEST_F(QueueTest, PublicCloseMethod) {
function TEST_F (line 570) | TEST_F(QueueTest, CloseUnblocksWaitingSinglePut) {
function TEST_F (line 602) | TEST_F(QueueTest, CloseUnblocksWaitingBatchPut) {
function TEST_F (line 636) | TEST_F(QueueTest, WaitForRoomAtLeast) {
function TEST_F (line 685) | TEST_F(QueueTest, WaitForRoomAtMost) {
function TEST_F (line 721) | TEST_F(QueueTest, WaitForSizeAtLeast) {
function TEST_F (line 753) | TEST_F(QueueTest, WaitForSizeAtMost) {
function TEST_F (line 796) | TEST_F(QueueTest, WaitFunctionsEdgeCases) {
function TEST_F (line 825) | TEST_F(QueueTest, BatchPutAtCapacityWorks) {
function TEST_F (line 835) | TEST_F(QueueTest, BatchGetAtCapacityWorks) {
function TEST_F (line 851) | TEST_F(QueueTest, LargeRangePutGetGradual) {
function TEST_F (line 878) | TEST_F(QueueTest, LargeRangePutMove) {
function TEST_F (line 908) | TEST_F(QueueTest, LargeRangeGetGradual) {
function TEST_F (line 933) | TEST_F(QueueTest, LargeRangePutGetConcurrent) {
function TEST_F (line 991) | TEST_F(QueueTest, GradualOperationsWithQueueClosure) {
function TEST_F (line 1025) | TEST_F(QueueTest, GetTotalPutCountBasic) {
function TEST_F (line 1048) | TEST_F(QueueTest, GetTotalPutCountBatch) {
function TEST_F (line 1066) | TEST_F(QueueTest, GetTotalPutCountReset) {
function TEST_F (line 1088) | TEST_F(QueueTest, GetTotalPutCountThreadSafe) {
function TEST_F (line 1116) | TEST_F(QueueTest, GetTotalPutCountBatchThreadSafe) {
function TEST_F (line 1145) | TEST_F(QueueTest, GetTotalPutCountWithMoveSemantics) {
function TEST_F (line 1163) | TEST_F(QueueTest, GetTotalPutCountEmptyBatch) {
function TEST_F (line 1180) | TEST_F(QueueTest, DropNewBasicBehavior) {
function TEST_F (line 1205) | TEST_F(QueueTest, DropNewBatchBehavior) {
function TEST_F (line 1228) | TEST_F(QueueTest, DropNewThreadSafety) {
function TEST_F (line 1264) | TEST_F(QueueTest, KeepNewestBasicBehavior) {
function TEST_F (line 1293) | TEST_F(QueueTest, KeepNewestBatchBehavior) {
function TEST_F (line 1317) | TEST_F(QueueTest, KeepNewestLargeBatch) {
function TEST_F (line 1336) | TEST_F(QueueTest, GetTotalGetCountBasic) {
function TEST_F (line 1355) | TEST_F(QueueTest, GetTotalGetCountBatch) {
function TEST_F (line 1375) | TEST_F(QueueTest, GetTotalGetCountReset) {
function TEST_F (line 1390) | TEST_F(QueueTest, GetTotalDropCountBasic) {
function TEST_F (line 1407) | TEST_F(QueueTest, GetTotalDropCountKeepNewest) {
function TEST_F (line 1422) | TEST_F(QueueTest, GetTotalDropCountReset) {
function TEST_F (line 1437) | TEST_F(QueueTest, MaybeGetOnEmptyQueue) {
function TEST_F (line 1443) | TEST_F(QueueTest, MaybeGetOnNonEmptyQueue) {
function TEST_F (line 1456) | TEST_F(QueueTest, MaybeGetMultipleValues) {
function TEST_F (line 1481) | TEST_F(QueueTest, MaybeGetWithMoveOnlyType) {
function TEST_F (line 1494) | TEST_F(QueueTest, MaybeGetUpdatesGetCount) {
function TEST_F (line 1516) | TEST_F(QueueTest, StopTokenCancelsPut) {
function TEST_F (line 1532) | TEST_F(QueueTest, StopTokenCancelsGet) {
function TEST_F (line 1544) | TEST_F(QueueTest, StopTokenCancelsBatchPut) {
function TEST_F (line 1562) | TEST_F(QueueTest, StopTokenCancelsBatchGet) {
function TEST_F (line 1573) | TEST_F(QueueTest, StopTokenCancelsWaitForRoomAtLeast) {
function TEST_F (line 1590) | TEST_F(QueueTest, StopTokenCancelsWaitForSizeAtLeast) {
FILE: csrc/utils/stream_shuffler.cc
type lczero (line 3) | namespace lczero {
type training (line 4) | namespace training {
FILE: csrc/utils/stream_shuffler.h
function namespace (line 11) | namespace lczero {
FILE: csrc/utils/stream_shuffler_test.cc
type lczero (line 10) | namespace lczero {
type training (line 11) | namespace training {
class StreamShufflerTest (line 13) | class StreamShufflerTest : public ::testing::Test {
method SetUp (line 15) | void SetUp() override { shuffler_.SetBucketSize(4); }
function TEST_F (line 20) | TEST_F(StreamShufflerTest, EmptyRangeReturnsNullopt) {
function TEST_F (line 26) | TEST_F(StreamShufflerTest, SingleItemRange) {
function TEST_F (line 37) | TEST_F(StreamShufflerTest, BasicRangeGeneration) {
function TEST_F (line 54) | TEST_F(StreamShufflerTest, HeadAdvancesByBucketMultiples) {
function TEST_F (line 78) | TEST_F(StreamShufflerTest, HeadAdvancesByNonMultiples) {
function TEST_F (line 101) | TEST_F(StreamShufflerTest, TailAdvancesByBucketMultiples) {
function TEST_F (line 129) | TEST_F(StreamShufflerTest, TailAdvancesByNonMultiples) {
function TEST_F (line 157) | TEST_F(StreamShufflerTest, BothBoundsSlideSimultaneously) {
function TEST_F (line 187) | TEST_F(StreamShufflerTest, ComplexSlidingWindow) {
function TEST_F (line 223) | TEST_F(StreamShufflerTest, UniquenessAcrossMultipleBuckets) {
function TEST_F (line 238) | TEST_F(StreamShufflerTest, TailCatchesUpToHead) {
function TEST_F (line 251) | TEST_F(StreamShufflerTest, ResetAllowsIterationRestart) {
FILE: csrc/utils/tensor.h
function namespace (line 12) | namespace lczero {
FILE: csrc/utils/tensor_test.cc
type lczero (line 8) | namespace lczero {
function TEST (line 11) | TEST(TypedTensorTest, ConstructorAndBasicProperties) {
function TEST (line 36) | TEST(TypedTensorTest, PyFormatForDifferentTypes) {
function TEST (line 50) | TEST(TypedTensorTest, ElementAccess) {
function TEST (line 70) | TEST(TypedTensorTest, ConstElementAccess) {
function TEST (line 84) | TEST(TypedTensorTest, SliceAccess) {
function TEST (line 115) | TEST(TypedTensorTest, ConstSliceAccess) {
function TEST (line 129) | TEST(TypedTensorTest, ElementAccessWrongDimensions) {
function TEST (line 136) | TEST(TypedTensorTest, SliceAccessTooManyDimensions) {
function TEST (line 142) | TEST(TypedTensorTest, OneDimensionalTensor) {
FILE: csrc/utils/thread_pool.h
function namespace (line 15) | namespace lczero {
function num_pending_tasks (line 83) | size_t num_pending_tasks() const;
function ThreadPool (line 121) | inline ThreadPool::ThreadPool(size_t initial_threads,
function ThreadPool (line 131) | inline ThreadPool::~ThreadPool() { Shutdown(); }
function WorkerLoop (line 137) | inline void ThreadPool::WorkerLoop() {
function WorkerEntryPoint (line 161) | inline void ThreadPool::WorkerEntryPoint() {
function WaitAll (line 175) | inline void ThreadPool::WaitAll() {
function WaitForAvailableThread (line 182) | inline void ThreadPool::WaitForAvailableThread() {
function WaitForPendingTasksBelow (line 189) | inline void ThreadPool::WaitForPendingTasksBelow(size_t threshold) {
function StartWorkerThread (line 196) | inline void ThreadPool::StartWorkerThread()
function Shutdown (line 216) | inline void ThreadPool::Shutdown() {
FILE: csrc/utils/training_data_printer.cc
type lczero (line 12) | namespace lczero {
type training (line 13) | namespace training {
function PrintFloatArray (line 15) | void PrintFloatArray(const float* data, size_t size, absl::string_vi...
function PrintUint64Array (line 32) | void PrintUint64Array(const uint64_t* data, size_t size, absl::strin...
function DecodeInvarianceInfo (line 49) | std::string DecodeInvarianceInfo(uint8_t invariance_info) {
function TrainingDataToFen (line 59) | std::string TrainingDataToFen(const FrameType& entry) {
function PrintTrainingDataEntry (line 73) | void PrintTrainingDataEntry(const FrameType& entry,
FILE: csrc/utils/training_data_printer.h
function namespace (line 13) | namespace lczero {
FILE: scripts/diff.py
function get_sorted_chunk_ids (line 8) | def get_sorted_chunk_ids(dirs):
function main (line 17) | def main(argv):
FILE: scripts/fixorder.py
function get_sorted_chunk_ids (line 8) | def get_sorted_chunk_ids(dirs):
function main (line 17) | def main(argv):
FILE: scripts/initsplit.py
function get_sorted_chunk_ids (line 8) | def get_sorted_chunk_ids(dirs):
function main (line 17) | def main(argv):
FILE: scripts/inittrainingname.py
function get_sorted_chunk_ids (line 8) | def get_sorted_chunk_ids(dirs):
function main (line 22) | def main(argv):
FILE: scripts/pack.py
function get_uncompressed_size (line 15) | def get_uncompressed_size(filename):
function get_sorted_chunk_ids (line 21) | def get_sorted_chunk_ids(dirs):
function pack (line 30) | def pack(ids):
function main (line 50) | def main():
FILE: scripts/purge.py
function get_sorted_chunk_ids (line 8) | def get_sorted_chunk_ids(dirs):
function main (line 17) | def main(argv):
FILE: scripts/shuffle.py
function split (line 16) | def split(a, n):
function positions (line 21) | def positions(chunk):
function shuffle (line 30) | def shuffle(files):
FILE: scripts/unpack.py
function unpack (line 13) | def unpack(filepath):
function main (line 42) | def main():
FILE: src/lczero_training/_lczero_training.pyi
class TensorBase (line 11) | class TensorBase:
method shape (line 12) | def shape(self) -> List[int]: ...
method strides (line 13) | def strides(self) -> List[int]: ...
method element_size (line 14) | def element_size(self) -> int: ...
method py_format (line 15) | def py_format(self) -> str: ...
class DataLoader (line 17) | class DataLoader:
method __init__ (line 18) | def __init__(self, config: DataLoaderConfig | bytes) -> None: ...
method add_stages (line 19) | def add_stages(self, config: DataLoaderConfig | bytes) -> None: ...
method send_control_message (line 20) | def send_control_message(
method start (line 23) | def start(self) -> None: ...
method get_next (line 24) | def get_next(self, alias: str = "") -> Tuple[np.ndarray, ...]: ...
method maybe_get_next (line 25) | def maybe_get_next(
method stop (line 28) | def stop(self) -> None: ...
method get_bucket_metrics (line 29) | def get_bucket_metrics(
method get_aggregate_ending_now (line 32) | def get_aggregate_ending_now(
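The stub above defines the whole Python surface of the pybind DataLoader. A minimal hedged sketch of driving it, assuming the config is passed as a serialized DataLoaderConfig proto (the bytes overload in the stub) and that the generated proto module path used below is correct (an assumption, not shown in the stub):

```python
# Hedged sketch of the DataLoader lifecycle, based only on the .pyi stub above.
# The proto module path and the config contents are assumptions.
from lczero_training._lczero_training import DataLoader
from lczero_training.proto import data_loader_config_pb2  # assumed module path

config = data_loader_config_pb2.DataLoaderConfig()
# ... populate the stage pipeline here (fields live in proto/data_loader_config.proto) ...

loader = DataLoader(config.SerializeToString())  # bytes overload from the stub
loader.start()
try:
    batch = loader.get_next()  # Tuple[np.ndarray, ...] per the stub
    print([t.shape for t in batch])
finally:
    loader.stop()
```

`make_dataloader()` in `src/lczero_training/dataloader/__init__.py` wraps the same construction behind a typed `DataLoaderConfig` argument.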
FILE: src/lczero_training/commands/backfill_metrics.py
function _build_parser (line 9) | def _build_parser() -> argparse.ArgumentParser:
function main (line 45) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/common.py
function configure_root_logging (line 13) | def configure_root_logging(level: int | str = logging.INFO) -> None:
function parse_log_level (line 32) | def parse_log_level(level: int | str) -> int:
function add_logging_arguments (line 46) | def add_logging_arguments(parser: argparse.ArgumentParser) -> None:
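Every module under `src/lczero_training/commands/` follows the same `_build_parser()` / `main(argv)` shape listed throughout this section, with `common.py` contributing the shared logging helpers. A hedged sketch of that shape (the `--config` flag and the body are illustrative, not taken from any real command):

```python
# Illustrative sketch of the _build_parser()/main(argv) pattern shared by the
# command modules; the --config flag and the body are hypothetical.
import argparse


def _build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Example lczero-training command")
    parser.add_argument("--config", required=True, help="Path to a config file.")
    return parser


def main(argv: list[str] | None = None) -> int:
    args = _build_parser().parse_args(argv)
    print(f"would run with config: {args.config}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```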
FILE: src/lczero_training/commands/daemon.py
function _build_parser (line 9) | def _build_parser() -> argparse.ArgumentParser:
function main (line 15) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/dataloader_viz.py
function _build_parser (line 11) | def _build_parser() -> argparse.ArgumentParser:
function main (line 30) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/describe_training.py
function _build_parser (line 8) | def _build_parser() -> argparse.ArgumentParser:
function main (line 34) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/jax2leela.py
function _build_parser (line 19) | def _build_parser() -> argparse.ArgumentParser:
function jax2leela (line 49) | def jax2leela(
function main (line 116) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/leela2jax.py
function _build_parser (line 5) | def _build_parser() -> argparse.ArgumentParser:
function main (line 50) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/migrate_checkpoint.py
function _build_parser (line 8) | def _build_parser() -> argparse.ArgumentParser:
function main (line 64) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/overfit.py
function _build_parser (line 8) | def _build_parser() -> argparse.ArgumentParser:
function main (line 39) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/test_dataloader.py
function _build_parser (line 8) | def _build_parser() -> argparse.ArgumentParser:
function main (line 34) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/train.py
function _build_parser (line 27) | def _build_parser() -> argparse.ArgumentParser:
function train (line 38) | def train(config_filename: str) -> None:
function main (line 124) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/training_eval.py
function _build_parser (line 8) | def _build_parser() -> argparse.ArgumentParser:
function main (line 59) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/training_init.py
function _build_parser (line 8) | def _build_parser() -> argparse.ArgumentParser:
function main (line 63) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/tui.py
function _build_parser (line 11) | def _build_parser() -> argparse.ArgumentParser:
function _amain (line 17) | async def _amain(args: argparse.Namespace) -> None:
function main (line 22) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/tune_lr.py
function _build_parser (line 8) | def _build_parser() -> argparse.ArgumentParser:
function main (line 76) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/commands/weights_tool.py
function _build_parser (line 11) | def _build_parser() -> argparse.ArgumentParser:
function main (line 49) | def main(argv: list[str] | None = None) -> int:
FILE: src/lczero_training/convert/jax_to_leela.py
class JaxToLeela (line 19) | class JaxToLeela(LeelaPytreeWeightsVisitor):
method embedding_block (line 20) | def embedding_block(
method tensor (line 33) | def tensor(
method encoder_tower (line 55) | def encoder_tower(
class LeelaExportOptions (line 64) | class LeelaExportOptions:
function jax_to_leela (line 71) | def jax_to_leela(
function _split_version (line 96) | def _split_version(version_str: str) -> tuple[int, int, int]:
function _make_format (line 102) | def _make_format() -> net_pb2.Format:
FILE: src/lczero_training/convert/leela_pytree_visitor.py
class LeelaPytreeWeightsVisitor (line 9) | class LeelaPytreeWeightsVisitor:
method __init__ (line 10) | def __init__(self, nnx_state: nnx.State, leela_net: net_pb2.Net) -> None:
method run (line 14) | def run(self) -> None:
method embedding_block (line 32) | def embedding_block(
method encoder_tower (line 63) | def encoder_tower(
method encoder_block (line 81) | def encoder_block(
method mha (line 89) | def mha(self, nnx_dict: nnx.State, weights: net_pb2.Weights.MHA) -> None:
method smolgen (line 96) | def smolgen(
method layernorm (line 105) | def layernorm(
method policy_heads (line 114) | def policy_heads(
method policy_head (line 130) | def policy_head(
method value_head (line 139) | def value_head(
method movesleft_head (line 156) | def movesleft_head(
method ffn (line 163) | def ffn(self, nnx_dict: nnx.State, ffn: net_pb2.Weights.FFN) -> None:
method matmul (line 167) | def matmul(
method tensor (line 179) | def tensor(
FILE: src/lczero_training/convert/leela_to_jax.py
class LeelaImportOptions (line 25) | class LeelaImportOptions:
function fix_older_weights_file (line 30) | def fix_older_weights_file(file: net_pb2.Net) -> None:
class LeelaToJax (line 84) | class LeelaToJax(LeelaPytreeWeightsVisitor):
method embedding_block (line 85) | def embedding_block(
method tensor (line 96) | def tensor(
function leela_to_jax (line 113) | def leela_to_jax(
function leela_to_jax_files (line 130) | def leela_to_jax_files(
FILE: src/lczero_training/convert/leela_to_modelconfig.py
function _defaultactivation_to_activation (line 4) | def _defaultactivation_to_activation(
function leela_to_modelconfig (line 13) | def leela_to_modelconfig(
FILE: src/lczero_training/daemon/daemon.py
class TrainingDaemon (line 20) | class TrainingDaemon:
method __init__ (line 25) | def __init__(self, memory_profile_dir: str | None = None) -> None:
method _setup_logging (line 44) | def _setup_logging(self) -> None:
method _setup_signal_handling (line 56) | def _setup_signal_handling(self) -> None:
method _signal_handler_thread (line 62) | def _signal_handler_thread(self) -> None:
method _shutdown (line 67) | def _shutdown(self, signum: int) -> None:
method _metrics_main (line 72) | async def _metrics_main(self) -> None:
method _metrics_task (line 76) | async def _metrics_task(self) -> None:
method run (line 116) | def run(self) -> None:
method on_start_training (line 128) | def on_start_training(self, payload: StartTrainingPayload) -> None:
method on_start_training_immediately (line 131) | def on_start_training_immediately(
FILE: src/lczero_training/daemon/metrics.py
class CachedBatch (line 31) | class CachedBatch:
function load_batch_from_npz (line 38) | def load_batch_from_npz(npz_filename: str) -> BatchTuple:
class _TrainingBatchMetric (line 59) | class _TrainingBatchMetric(_Metric):
method __init__ (line 62) | def __init__(self, config: MetricConfig, logger: TensorboardLogger):
method log (line 70) | def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
class _EvaluatingMetric (line 74) | class _EvaluatingMetric(_Metric, ABC):
method __init__ (line 77) | def __init__(
method get_batch (line 89) | def get_batch(self) -> BatchTuple:
method log (line 92) | def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
method _evaluate (line 97) | def _evaluate(
function _make_eval_jit (line 122) | def _make_eval_jit(graphdef: nnx.GraphDef, loss_fn: LczeroLoss) -> _Eval...
function evaluate_batch (line 144) | def evaluate_batch(
class _DataLoaderMetric (line 176) | class _DataLoaderMetric(_EvaluatingMetric):
method __init__ (line 179) | def __init__(
method get_batch (line 195) | def get_batch(self) -> BatchTuple:
method log (line 198) | def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
class _NpzMetric (line 216) | class _NpzMetric(_EvaluatingMetric):
method __init__ (line 219) | def __init__(
method get_batch (line 231) | def get_batch(self) -> BatchTuple:
method log (line 234) | def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
class Metrics (line 245) | class Metrics:
method __init__ (line 248) | def __init__(
method on_step (line 294) | def on_step(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> ...
method close (line 304) | def close(self) -> None:
FILE: src/lczero_training/daemon/metrics_base.py
class _Metric (line 12) | class _Metric(ABC):
method __init__ (line 15) | def __init__(self, config: MetricConfig, logger: TensorboardLogger):
method should_log (line 19) | def should_log(
method log (line 29) | def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
FILE: src/lczero_training/daemon/pipeline.py
function _read_config_file (line 44) | def _read_config_file(config_filepath: str) -> RootConfig:
function _make_dataloader (line 53) | def _make_dataloader(config: DataLoaderConfig) -> DataLoader:
function _configure_file_logging (line 58) | def _configure_file_logging(config: RootConfig) -> None:
function _log_jax_system_info (line 73) | def _log_jax_system_info() -> None:
class _TrainingCycleState (line 93) | class _TrainingCycleState:
class TrainingPipeline (line 106) | class TrainingPipeline:
method __init__ (line 116) | def __init__(
method start_training_immediately (line 211) | def start_training_immediately(self) -> None:
method run (line 217) | def run(self) -> None:
method _export_network (line 239) | def _export_network(self) -> bytes | None:
method _save_network (line 263) | def _save_network(self, network_bytes: bytes) -> None:
method _upload_network (line 277) | def _upload_network(self, network_bytes: bytes) -> None:
method _step_hook (line 319) | def _step_hook(self, hook_data: StepHookData) -> None:
method _train_one_network (line 325) | def _train_one_network(self) -> None:
method _save_checkpoint (line 360) | def _save_checkpoint(self) -> None:
method stop (line 368) | def stop(self) -> None:
method get_data_loader (line 373) | def get_data_loader(self) -> DataLoader:
method _wait_for_chunks (line 376) | def _wait_for_chunks(self) -> None:
method get_training_schedule_data (line 397) | def get_training_schedule_data(
method _send_chunk_pool_control (line 434) | def _send_chunk_pool_control(
method _reset_chunk_anchor (line 443) | def _reset_chunk_anchor(self) -> tuple[str, int]:
method _chunks_since_anchor (line 452) | def _chunks_since_anchor(self) -> int:
method _set_chunk_anchor (line 460) | def _set_chunk_anchor(self, anchor: str) -> None:
method _load_config (line 465) | def _load_config(self, config_filepath: str) -> RootConfig:
FILE: src/lczero_training/daemon/protocol/communicator.py
function _to_serializable (line 17) | def _to_serializable(obj: Any) -> Any:
function _unwrap_optional (line 39) | def _unwrap_optional(t: Any) -> Any:
function _is_protobuf (line 47) | def _is_protobuf(cls: type) -> bool:
function _from_serializable (line 55) | def _from_serializable(cls: type, data: Any) -> Any:
class Communicator (line 88) | class Communicator:
method __init__ (line 89) | def __init__(
method send (line 104) | def send(self, payload_instance: Any) -> None:
method _dispatch (line 123) | def _dispatch(self, line: str) -> None:
method run (line 140) | def run(self) -> None:
class AsyncCommunicator (line 152) | class AsyncCommunicator:
method __init__ (line 153) | def __init__(
method send (line 175) | async def send(self, payload_instance: Any) -> None:
method _dispatch (line 196) | async def _dispatch(self, line: str) -> None:
method run (line 213) | async def run(self) -> None:
FILE: src/lczero_training/daemon/protocol/messages.py
class TrainingStage (line 13) | class TrainingStage(Enum):
class TrainingScheduleData (line 19) | class TrainingScheduleData:
class StartTrainingPayload (line 36) | class StartTrainingPayload:
class StartTrainingImmediatelyPayload (line 42) | class StartTrainingImmediatelyPayload:
class TrainingStatusPayload (line 51) | class TrainingStatusPayload:
FILE: src/lczero_training/daemon/protocol/registry.py
function register (line 12) | def register(event_type: str) -> Callable[[type], type]:
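`register` returns a class decorator keyed by an event-type string, which the payload dataclasses in `messages.py` are the natural users of. A hedged sketch of registering a payload class (the event name and the fields below are invented for illustration):

```python
# Hedged sketch of the protocol registry decorator; the event-type string and
# the dataclass fields are invented, not taken from messages.py.
from dataclasses import dataclass

from lczero_training.daemon.protocol.registry import register


@register("example_event")
@dataclass
class ExamplePayload:
    message: str
```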
FILE: src/lczero_training/daemon/rms_metrics.py
function compute_rms (line 18) | def compute_rms(state_subtree: nnx.State) -> jax.Array:
function extract_attention_components (line 26) | def extract_attention_components(model: LczeroModel) -> dict[str, Any]:
function collect_rms_metrics (line 58) | def collect_rms_metrics(model: LczeroModel) -> dict[str, Any]:
class _RmsMetric (line 103) | class _RmsMetric(_Metric):
method __init__ (line 106) | def __init__(self, config: MetricConfig, logger: TensorboardLogger):
method log (line 109) | def log(self, hook_data: StepHookData, graphdef: nnx.GraphDef) -> None:
FILE: src/lczero_training/dataloader/__init__.py
function make_dataloader (line 10) | def make_dataloader(config: DataLoaderConfig) -> DataLoader:
FILE: src/lczero_training/model/embedding.py
class Embedding (line 11) | class Embedding(nnx.Module):
method __init__ (line 14) | def __init__(
method __call__ (line 54) | def __call__(self, x: jax.Array) -> jax.Array:
class MaGating (line 70) | class MaGating(nnx.Module):
method __init__ (line 73) | def __init__(self, feature_shape: tuple[int, ...], *, rngs: nnx.Rngs):
method __call__ (line 81) | def __call__(self, x: jax.Array) -> jax.Array:
class Gating (line 85) | class Gating(nnx.Module):
method __init__ (line 86) | def __init__(
method __call__ (line 99) | def __call__(self, inputs: jax.Array) -> jax.Array:
FILE: src/lczero_training/model/encoder.py
class EncoderTower (line 15) | class EncoderTower(nnx.Module):
method __init__ (line 16) | def __init__(
method __call__ (line 49) | def __call__(self, x: jax.Array) -> jax.Array:
class EncoderBlock (line 53) | class EncoderBlock(nnx.Module):
method __init__ (line 56) | def __init__(
method __call__ (line 87) | def __call__(self, x: jax.Array) -> jax.Array:
class MultiHeadAttention (line 94) | class MultiHeadAttention(nnx.Module):
method __init__ (line 97) | def __init__(
method __call__ (line 153) | def __call__(self, x: jax.Array) -> jax.Array:
class Smolgen (line 180) | class Smolgen(nnx.Module):
method __init__ (line 183) | def __init__(
method __call__ (line 218) | def __call__(self, x: jax.Array) -> jax.Array:
FILE: src/lczero_training/model/loss_function.py
function _compute_q_from_wdl (line 24) | def _compute_q_from_wdl(wdl_logits: jax.Array) -> jax.Array:
class LossBase (line 31) | class LossBase:
method __init__ (line 32) | def __init__(
method __call__ (line 46) | def __call__(
class RegularizationLoss (line 54) | class RegularizationLoss:
method __init__ (line 57) | def __init__(self, config: RegularizationLossConfig) -> None:
method __call__ (line 62) | def __call__(self, model: LczeroModel) -> jax.Array:
class LczeroLoss (line 77) | class LczeroLoss:
method __init__ (line 85) | def __init__(self, config: LossConfig) -> None:
method __call__ (line 131) | def __call__(
class ValueLoss (line 181) | class ValueLoss(LossBase):
method __init__ (line 182) | def __init__(self, config: ValueLossConfig) -> None:
method __call__ (line 186) | def __call__(
class PolicyLoss (line 209) | class PolicyLoss(LossBase):
method __init__ (line 210) | def __init__(self, config: PolicyLossConfig):
method _apply_temperature_and_normalize (line 235) | def _apply_temperature_and_normalize(
method _compute_optimistic_weight (line 251) | def _compute_optimistic_weight(
method __call__ (line 278) | def __call__(
class MovesLeftLoss (line 323) | class MovesLeftLoss(LossBase):
method __init__ (line 324) | def __init__(self, config: MovesLeftLossConfig) -> None:
method __call__ (line 328) | def __call__(
class ValueErrorLoss (line 351) | class ValueErrorLoss(LossBase):
method __init__ (line 352) | def __init__(self, config: ValueErrorLossConfig) -> None:
method __call__ (line 357) | def __call__(
class ValueCategoricalLoss (line 386) | class ValueCategoricalLoss(LossBase):
method __init__ (line 387) | def __init__(self, config: ValueCategoricalLossConfig) -> None:
method __call__ (line 391) | def __call__(
FILE: src/lczero_training/model/model.py
class ModelPrediction (line 21) | class ModelPrediction:
class LczeroModel (line 35) | class LczeroModel(nnx.Module):
method __init__ (line 36) | def __init__(self, config: model_config_pb2.ModelConfig, *, rngs: nnx....
method __call__ (line 109) | def __call__(self, x: jax.Array) -> ModelPrediction:
FILE: src/lczero_training/model/movesleft_head.py
class MovesLeftHead (line 9) | class MovesLeftHead(nnx.Module):
method __init__ (line 10) | def __init__(
method __call__ (line 36) | def __call__(self, x: jax.Array) -> jax.Array:
FILE: src/lczero_training/model/policy_head.py
class PolicyHead (line 13) | class PolicyHead(nnx.Module):
method __init__ (line 14) | def __init__(
method __call__ (line 58) | def __call__(self, x: jax.Array) -> jax.Array:
FILE: src/lczero_training/model/shared.py
class Ffn (line 10) | class Ffn(nnx.Module):
method __init__ (line 11) | def __init__(
method __call__ (line 40) | def __call__(self, x: jax.Array) -> jax.Array:
FILE: src/lczero_training/model/utils.py
function get_activation (line 11) | def get_activation(
function get_dtype (line 26) | def get_dtype(dtype: XlaShapeProto.Type) -> jnp.dtype:
FILE: src/lczero_training/model/value_head.py
class ValueHead (line 11) | class ValueHead(nnx.Module):
method __init__ (line 12) | def __init__(
method __call__ (line 50) | def __call__(
method predict (line 66) | def predict(self, x: jax.Array) -> jax.Array:
FILE: src/lczero_training/tests/test_protobuf.py
function test_protobuf_import (line 4) | def test_protobuf_import() -> None:
function test_protobuf_functionality (line 29) | def test_protobuf_functionality() -> None:
FILE: src/lczero_training/tests/test_protocol_registry.py
function clear_registry (line 16) | def clear_registry() -> Any:
function test_basic_registration (line 25) | def test_basic_registration() -> None:
function test_duplicate_event_type (line 39) | def test_duplicate_event_type() -> None:
function test_duplicate_class (line 57) | def test_duplicate_class() -> None:
function test_non_class_registration (line 74) | def test_non_class_registration() -> None:
function test_multiple_registrations (line 83) | def test_multiple_registrations() -> None:
function test_registry_persistence (line 115) | def test_registry_persistence() -> None:
FILE: src/lczero_training/tests/test_weights_tool.py
function test_weights_arithmetic (line 13) | def test_weights_arithmetic() -> None:
function test_policy_head_replacement (line 54) | def test_policy_head_replacement() -> None:
function test_policy_head_map_assignment (line 101) | def test_policy_head_map_assignment() -> None:
function test_noop_arithmetic (line 158) | def test_noop_arithmetic() -> None:
function test_list_item_assignment (line 195) | def test_list_item_assignment() -> None:
FILE: src/lczero_training/tools/weight_codecs.py
function decode_linear16 (line 8) | def decode_linear16(
function encode_linear16 (line 18) | def encode_linear16(arr: np.ndarray) -> tuple[bytes, float, float]:
function decode_float16 (line 34) | def decode_float16(params: bytes, shape: tuple[int, ...]) -> np.ndarray:
function encode_float16 (line 41) | def encode_float16(arr: np.ndarray) -> tuple[bytes, float, float]:
function decode_bfloat16 (line 47) | def decode_bfloat16(params: bytes, shape: tuple[int, ...]) -> np.ndarray:
function encode_bfloat16 (line 55) | def encode_bfloat16(arr: np.ndarray) -> tuple[bytes, float, float]:
function decode_layer (line 63) | def decode_layer(
function encode_layer (line 87) | def encode_layer(arr: np.ndarray, encoding: int) -> tuple[bytes, float, ...
FILE: src/lczero_training/tools/weight_wrappers.py
class LayerWrapper (line 14) | class LayerWrapper:
method __init__ (line 19) | def __init__(
method value (line 33) | def value(self) -> np.ndarray:
method value (line 44) | def value(self, arr: np.ndarray) -> None:
method commit (line 49) | def commit(self, encoding: int) -> None:
method __add__ (line 63) | def __add__(self, other: "LayerWrapper") -> "LayerWrapper":
method __sub__ (line 70) | def __sub__(self, other: "LayerWrapper") -> "LayerWrapper":
method __mul__ (line 77) | def __mul__(self, scalar: float) -> "LayerWrapper":
method __rmul__ (line 84) | def __rmul__(self, scalar: float) -> "LayerWrapper":
class ListWrapper (line 88) | class ListWrapper:
method __init__ (line 93) | def __init__(self, proto_list: Any, parent: "NetWrapper") -> None:
method __len__ (line 102) | def __len__(self) -> int:
method __getitem__ (line 105) | def __getitem__(self, idx: int) -> Any:
method __setitem__ (line 111) | def __setitem__(self, idx: int, value: Any) -> None:
method __iter__ (line 132) | def __iter__(self) -> Iterator[Any]:
class NetWrapper (line 137) | class NetWrapper:
method __init__ (line 142) | def __init__(
method _detect_encoding (line 157) | def _detect_encoding(self) -> int:
method __getattr__ (line 163) | def __getattr__(self, name: str) -> Any:
method _wrap_field (line 180) | def _wrap_field(self, value: Any) -> Any:
method __setattr__ (line 191) | def __setattr__(self, name: str, value: Any) -> None:
method save (line 213) | def save(self, path: str, encoding: int | None = None) -> None:
method _commit_all (line 228) | def _commit_all(self, encoding: int) -> None:
method __add__ (line 242) | def __add__(self, other: "NetWrapper") -> "NetWrapper":
method __sub__ (line 253) | def __sub__(self, other: "NetWrapper") -> "NetWrapper":
method __mul__ (line 264) | def __mul__(self, scalar: float) -> "NetWrapper":
method __rmul__ (line 275) | def __rmul__(self, scalar: float) -> "NetWrapper":
method _add_weights (line 278) | def _add_weights(self, lhs: "NetWrapper", rhs: "NetWrapper") -> None:
method _sub_weights (line 316) | def _sub_weights(self, lhs: "NetWrapper", rhs: "NetWrapper") -> None:
method _mul_weights (line 354) | def _mul_weights(self, source: "NetWrapper", scalar: float) -> None:
FILE: src/lczero_training/tools/weights_tool.py
function load_weights (line 10) | def load_weights(path: str) -> NetWrapper:
function save_weights (line 24) | def save_weights(
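`load_weights` returns a `NetWrapper`, which supports `+`, `-` and scalar `*` over the underlying layers as well as `save(path, encoding=None)`. A hedged sketch averaging two weight files (the file paths are placeholders):

```python
# Hedged sketch of NetWrapper weight arithmetic, based on the operators and
# load/save helpers listed above. The file paths are placeholders.
from lczero_training.tools.weights_tool import load_weights

net_a = load_weights("net_a.pb.gz")
net_b = load_weights("net_b.pb.gz")

average = 0.5 * (net_a + net_b)  # elementwise over the decoded layers
average.save("net_avg.pb.gz")
```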
FILE: src/lczero_training/training/backfill_metrics.py
function _load_config (line 28) | def _load_config(config_path: str) -> RootConfig:
function _validate_and_get_metrics (line 36) | def _validate_and_get_metrics(
function _load_and_migrate_checkpoint (line 64) | def _load_and_migrate_checkpoint(
function backfill_metrics (line 75) | def backfill_metrics(
FILE: src/lczero_training/training/dataloader_probe.py
function _stop_loader (line 18) | def _stop_loader(loader: DataLoader) -> None:
function _store_batches (line 23) | def _store_batches(path: str, batches: list) -> None:
function probe_dataloader (line 33) | def probe_dataloader(
FILE: src/lczero_training/training/describe.py
function describe (line 17) | def describe(
FILE: src/lczero_training/training/eval.py
class DiffRecord (line 50) | class DiffRecord:
function _tensor_to_list (line 61) | def _tensor_to_list(obj: Any) -> Any:
function _bin_counts (line 73) | def _bin_counts(values: np.ndarray) -> Dict[str, Any]:
function _format_bound (line 88) | def _format_bound(value: float) -> str:
function _format_stats (line 95) | def _format_stats(stats: Dict[str, Any]) -> str:
function _collect_diff_statistics (line 108) | def _collect_diff_statistics(
class Dumper (line 125) | class Dumper:
method __init__ (line 128) | def __init__(
method dump_tensors (line 142) | def dump_tensors(self, tensors: dict, prefix: str) -> None:
method dump_structured (line 159) | def dump_structured(self, batch: dict, outputs: dict, losses: dict) ->...
method _dump_to_shelve (line 176) | def _dump_to_shelve(self, key: str, data: dict) -> None:
method _dump_to_json (line 182) | def _dump_to_json(self, key: str, data: dict) -> None:
method close (line 194) | def close(self) -> None:
class OnnxComparator (line 199) | class OnnxComparator:
method __init__ (line 202) | def __init__(self, onnx_model_path: str):
method compare (line 223) | def compare(
method log_summary (line 262) | def log_summary(self) -> None:
method _log_diff_stats (line 280) | def _log_diff_stats(
method _align_onnx_outputs (line 328) | def _align_onnx_outputs(
method _reshape_output (line 375) | def _reshape_output(
class Evaluation (line 393) | class Evaluation:
method __init__ (line 396) | def __init__(self, loss_fn: LczeroLoss):
method run (line 399) | def run(
method _process_sample (line 427) | def _process_sample(
method _loss_for_grad (line 502) | def _loss_for_grad(
method _model_for_output (line 508) | def _model_for_output(
function from_dataloader (line 514) | def from_dataloader(
function _load_model_from_checkpoint (line 522) | def _load_model_from_checkpoint(config: RootConfig) -> LczeroModel:
function _get_dataloader_config (line 545) | def _get_dataloader_config(
function eval (line 564) | def eval(
FILE: src/lczero_training/training/init.py
function _load_lc0_model_state (line 25) | def _load_lc0_model_state(
function init (line 62) | def init(
FILE: src/lczero_training/training/lr_schedule.py
function _create_rule_fn (line 9) | def _create_rule_fn(rule: LrSchedule) -> Callable:
function make_lr_schedule (line 103) | def make_lr_schedule(schedules: Sequence[LrSchedule]) -> optax.Schedule:
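`make_lr_schedule` composes a sequence of `LrSchedule` rules into a single `optax.Schedule`, i.e. a callable from step number to learning rate. A hedged sketch of evaluating one (the proto import path is an assumption, and the rule fields, defined in `proto/training_config.proto`, are left unfilled):

```python
# Hedged sketch: the returned optax.Schedule maps a step count to a learning
# rate. The proto import path is an assumption and the LrSchedule fields are
# not populated here.
import jax.numpy as jnp

from lczero_training.proto.training_config_pb2 import LrSchedule  # assumed path
from lczero_training.training.lr_schedule import make_lr_schedule

rules = [LrSchedule()]  # fill in starting-step / shape / learning-rate fields
schedule = make_lr_schedule(rules)
lr_at_step_1000 = float(schedule(jnp.asarray(1000)))
```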
FILE: src/lczero_training/training/migrate_checkpoint.py
function _str_to_key_path (line 14) | def _str_to_key_path(path_str: str) -> tuple[str, ...]:
function _load_new_state (line 18) | def _load_new_state(
function load_checkpoint (line 29) | def load_checkpoint(
function get_checkpoint_steps (line 52) | def get_checkpoint_steps(
function _load_old_state (line 83) | def _load_old_state(
function load_migration_rules (line 89) | def load_migration_rules(rules_file: str | None) -> List[Tuple[Any, Any]]:
function _format_value (line 121) | def _format_value(value: Any) -> str:
function _format_path_diff (line 127) | def _format_path_diff(
class Migration (line 143) | class Migration:
method __init__ (line 144) | def __init__(self, old_state: Any, new_state: Any):
method _apply_move_rule (line 171) | def _apply_move_rule(
method _apply_ignore_rule (line 202) | def _apply_ignore_rule(self, from_path: Tuple[str, ...]) -> None:
method _apply_keep_rule (line 210) | def _apply_keep_rule(self, to_path: Tuple[str, ...]) -> None:
method apply_rules (line 218) | def apply_rules(self, rules: List[Tuple[Any, Any]]) -> None:
method run (line 227) | def run(self, rules: List[Tuple[Any, Any]]) -> Any:
function _save_checkpoint (line 258) | def _save_checkpoint(
function _dump_paths (line 291) | def _dump_paths(paths: Iterable[Tuple[str, ...]], field: str) -> None:
function migrate_checkpoint (line 299) | def migrate_checkpoint(
FILE: src/lczero_training/training/optimizer.py
function update_optimizer_step (line 17) | def update_optimizer_step(
function make_gradient_transformation (line 37) | def make_gradient_transformation(
FILE: src/lczero_training/training/overfit.py
function _stop_loader (line 32) | def _stop_loader(loader: DataLoader) -> None:
function _prepare_batch (line 37) | def _prepare_batch(batch_tuple: tuple) -> TrainingBatch:
function _make_eval_step (line 46) | def _make_eval_step(graphdef: nnx.GraphDef, loss_fn: LczeroLoss) -> Any:
function overfit (line 69) | def overfit(
FILE: src/lczero_training/training/state.py
class TrainingSample (line 27) | class TrainingSample:
class TrainingBatch (line 51) | class TrainingBatch:
method from_tuple (line 71) | def from_tuple(
class JitTrainingState (line 87) | class JitTrainingState:
method replace (line 97) | def replace(self, **changes: Any) -> "JitTrainingState":
class TrainingState (line 103) | class TrainingState:
method replace (line 109) | def replace(self, **changes: Any) -> "TrainingState":
method with_updated_step (line 113) | def with_updated_step(self, step: int) -> "TrainingState":
method new_from_config (line 128) | def new_from_config(
FILE: src/lczero_training/training/tensorboard.py
function _to_ndarray (line 18) | def _to_ndarray(value: Any) -> np.ndarray:
function _to_scalar (line 25) | def _to_scalar(value: Any) -> float | None:
function _flatten_metrics (line 36) | def _flatten_metrics(
function _to_step (line 51) | def _to_step(step: Any) -> int:
class TensorboardLogger (line 55) | class TensorboardLogger:
method __init__ (line 58) | def __init__(self, logdir: str) -> None:
method log (line 61) | def log(self, step: int, metrics: MetricsDict) -> None:
method close (line 68) | def close(self) -> None:
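`TensorboardLogger` takes a log directory and exposes `log(step, metrics)` and `close()`; the `_flatten_metrics` / `_to_scalar` helpers above suggest the metrics argument is a mapping of metric names to scalar-like values. A hedged sketch (the log directory and metric name are placeholders):

```python
# Hedged sketch of TensorboardLogger based on the signatures above. Assumes
# the metrics argument is a mapping from metric names to scalar values.
from lczero_training.training.tensorboard import TensorboardLogger

logger = TensorboardLogger("/tmp/tb_logs")          # placeholder logdir
logger.log(step=100, metrics={"loss/total": 1.23})  # placeholder metric name
logger.close()
```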
FILE: src/lczero_training/training/test_lr_schedule.py
function _sched (line 10) | def _sched(
function _val (line 16) | def _val(s: Callable[[jnp.ndarray], jnp.ndarray], t: int | float) -> float:
function test_rule_selection_by_starting_step (line 20) | def test_rule_selection_by_starting_step() -> None:
function test_default_constant_transition_and_tail (line 38) | def test_default_constant_transition_and_tail() -> None:
function test_linear_then_hold (line 51) | def test_linear_then_hold() -> None:
function test_looping_constant_segments (line 66) | def test_looping_constant_segments() -> None:
function test_zero_duration_is_skipped (line 80) | def test_zero_duration_is_skipped() -> None:
function test_chain_zero_durations_then_linear (line 97) | def test_chain_zero_durations_then_linear() -> None:
function test_cosine (line 116) | def test_cosine() -> None:
function test_before_first_rule_uses_earliest_first_lr (line 130) | def test_before_first_rule_uses_earliest_first_lr() -> None:
FILE: src/lczero_training/training/training.py
class StepHookData (line 30) | class StepHookData:
function from_dataloader (line 45) | def from_dataloader(
class Training (line 52) | class Training:
method __init__ (line 61) | def __init__(
method _swa_tree_map (line 164) | def _swa_tree_map(
method update_swa (line 174) | def update_swa(
method maybe_update_swa (line 206) | def maybe_update_swa(
method _validate_and_prepare_batch (line 227) | def _validate_and_prepare_batch(
method _log_step_metrics (line 258) | def _log_step_metrics(
method _execute_step_hook (line 275) | def _execute_step_hook(
method run (line 295) | def run(
FILE: src/lczero_training/training/tune_lr.py
function _prepare_batch (line 28) | def _prepare_batch(batch_tuple: tuple) -> Dict:
function _make_optimizer_with_schedule (line 37) | def _make_optimizer_with_schedule(
function _make_eval_step (line 68) | def _make_eval_step(
function _plot_results (line 87) | def _plot_results(results: List[Tuple[float, float]], plot_output: str) ...
function tune_lr (line 102) | def tune_lr(
FILE: src/lczero_training/training/utils.py
function make_weights_mask (line 8) | def make_weights_mask(
FILE: src/lczero_training/tui/app.py
class HeaderBar (line 33) | class HeaderBar(Static):
method compose (line 36) | def compose(self) -> ComposeResult:
class JAXTrainingPane (line 41) | class JAXTrainingPane(Static):
method compose (line 44) | def compose(self) -> ComposeResult:
class TrainingTuiApp (line 53) | class TrainingTuiApp(App):
method add_arguments (line 66) | def add_arguments(parser: argparse.ArgumentParser) -> None:
method __init__ (line 102) | def __init__(self, args: Optional[argparse.Namespace] = None) -> None:
method on_load (line 122) | async def on_load(self) -> None:
method compose (line 160) | def compose(self) -> ComposeResult:
method _monitor_daemon_process (line 183) | async def _monitor_daemon_process(self) -> None:
method on_mount (line 198) | def on_mount(self) -> None:
method _send_start_training (line 204) | async def _send_start_training(self) -> None:
method _command_start_training_immediately (line 209) | async def _command_start_training_immediately(self) -> None:
method action_quit (line 216) | async def action_quit(self) -> None: # type: ignore
method get_system_commands (line 225) | def get_system_commands(self, screen: Screen) -> Iterable[SystemCommand]:
method on_training_status (line 235) | async def on_training_status(self, payload: TrainingStatusPayload) -> ...
FILE: src/lczero_training/tui/data_pipeline_pane.py
class DataPipelinePane (line 26) | class DataPipelinePane(Container):
method __init__ (line 29) | def __init__(self, **kwargs: Any) -> None:
method compose (line 38) | def compose(self) -> ComposeResult:
method _friendly_title (line 42) | def _friendly_title(self, stage_key: str) -> str:
method _ensure_stage_widget (line 47) | def _ensure_stage_widget(
method _ensure_queue_widgets (line 67) | def _ensure_queue_widgets(
method _ensure_statistics_widgets (line 97) | def _ensure_statistics_widgets(
method _mount_widgets (line 124) | def _mount_widgets(
method _ensure_rows (line 132) | def _ensure_rows(
method update_metrics (line 157) | def update_metrics(
FILE: src/lczero_training/tui/dataloader_widgets.py
function _find_stage_metric (line 14) | def _find_stage_metric(
function _collect_metric_names (line 26) | def _collect_metric_names(
function _find_load_metric (line 48) | def _find_load_metric(
function _find_count_metric (line 60) | def _find_count_metric(
function _find_gauge_metric (line 72) | def _find_gauge_metric(
function _find_statistics_metric (line 84) | def _find_statistics_metric(
function _get_queue_metric (line 96) | def _get_queue_metric(
function format_si (line 117) | def format_si(value: int, precision: int = 1) -> str:
function format_full_number (line 135) | def format_full_number(value: int) -> str:
function _format_load (line 141) | def _format_load(
function _format_count (line 155) | def _format_count(
function _format_gauge (line 171) | def _format_gauge(
function _format_statistics (line 184) | def _format_statistics(
function _average_queue_fullness (line 205) | def _average_queue_fullness(
function _canonical_stage_name (line 220) | def _canonical_stage_name(
class BaseRowWidget (line 234) | class BaseRowWidget(Widget):
method __init__ (line 239) | def __init__(
method compose (line 257) | def compose(self) -> ComposeResult:
method on_mount (line 261) | def on_mount(self) -> None:
method add_content_widget (line 276) | def add_content_widget(self, widget: Widget) -> None:
method _update_name (line 281) | def _update_name(
class StageWidget (line 292) | class StageWidget(BaseRowWidget):
method __init__ (line 295) | def __init__(
method _ensure_chip (line 309) | def _ensure_chip(self, key: str, default_text: str, classes: str) -> S...
method _update_last_chunk_chip (line 317) | def _update_last_chunk_chip(
method _update_anchor_chip (line 333) | def _update_anchor_chip(
method update_metrics (line 342) | def update_metrics(
class StatisticsRowWidget (line 404) | class StatisticsRowWidget(BaseRowWidget):
method __init__ (line 407) | def __init__(
method compose (line 425) | def compose(self) -> ComposeResult:
method on_mount (line 429) | def on_mount(self) -> None:
method update_metrics (line 437) | def update_metrics(
class QueueWidget (line 454) | class QueueWidget(BaseRowWidget):
method __init__ (line 457) | def __init__(
method compose (line 490) | def compose(self) -> ComposeResult:
method on_mount (line 495) | def on_mount(self) -> None:
method update_metrics (line 504) | def update_metrics(
FILE: src/lczero_training/tui/log_pane.py
class StreamingLogPane (line 9) | class StreamingLogPane(RichLog):
method __init__ (line 12) | def __init__(
method on_mount (line 28) | def on_mount(self) -> None:
method _write_banner (line 34) | def _write_banner(self) -> None:
method _write_to_file (line 50) | def _write_to_file(self, line: str) -> None:
method _read_stream (line 58) | async def _read_stream(self) -> None:
FILE: src/lczero_training/tui/training_widgets.py
class TimeProgressWidget (line 7) | class TimeProgressWidget(Static):
method __init__ (line 10) | def __init__(self, label: str, *, id: str | None = None) -> None:
method compose (line 14) | def compose(self) -> ComposeResult:
method update_progress (line 19) | def update_progress(
function format_time_duration (line 46) | def format_time_duration(seconds: float) -> str:
class TrainingScheduleWidget (line 63) | class TrainingScheduleWidget(Static):
method compose (line 66) | def compose(self) -> ComposeResult:
method update_training_schedule (line 73) | def update_training_schedule(
FILE: tf/attention_policy_map.py
function make_map (line 39) | def make_map():
function make_pos_enc (line 96) | def make_pos_enc():
FILE: tf/chunkparsefunc.py
function parse_function (line 21) | def parse_function(planes, probs, winner, q, plies_left):
FILE: tf/chunkparser.py
function reverse_expand_bits (line 82) | def reverse_expand_bits(plane):
class ChunkDataSrc (line 88) | class ChunkDataSrc:
method __init__ (line 89) | def __init__(self, items):
method next (line 92) | def next(self):
function chunk_reader (line 98) | def chunk_reader(chunk_filenames, chunk_filename_queue):
class ChunkParser (line 121) | class ChunkParser:
method __init__ (line 123) | def __init__(self,
method shutdown (line 141) | def shutdown(self):
method parse (line 153) | def parse(self):
method sequential (line 156) | def sequential(self):
class ChunkParserInner (line 160) | class ChunkParserInner:
method __init__ (line 161) | def __init__(self, parent, chunks, expected_input_format, shuffle_size,
method init_structs (line 240) | def init_structs(self):
method convert_v6_to_tuple (line 250) | def convert_v6_to_tuple(self, content):
method sample_record (line 405) | def sample_record(self, chunkdata):
method single_file_gen (line 462) | def single_file_gen(self, filename):
method sequential_gen (line 489) | def sequential_gen(self):
method sequential (line 494) | def sequential(self):
method task (line 501) | def task(self, chunk_filename_queue, writer):
method v6_gen (line 512) | def v6_gen(self):
method tuple_gen (line 536) | def tuple_gen(self, gen):
method batch_gen (line 544) | def batch_gen(self, gen, allow_partial=True):
method parse (line 558) | def parse(self):
class ChunkParserTest (line 570) | class ChunkParserTest(unittest.TestCase):
method setUp (line 571) | def setUp(self):
method generate_fake_pos (line 574) | def generate_fake_pos(self):
method v4_record (line 601) | def v4_record(self, planes, i, probs, winner, best_q, best_d):
method test_structsize (line 612) | def test_structsize(self):
method test_parsing (line 618) | def test_parsing(self):
FILE: tf/decode_training.py
class Board (line 277) | class Board:
method __init__ (line 278) | def __init__(self):
method clear_board (line 281) | def clear_board(self):
method describe (line 287) | def describe(self):
class TrainingStep (line 295) | class TrainingStep:
method __init__ (line 296) | def __init__(self, version):
method init_structs (line 329) | def init_structs(self):
method init_move_map (line 333) | def init_move_map(self):
method clear_hist (line 347) | def clear_hist(self):
method update_board (line 351) | def update_board(self, hist, piece, bit_board):
method describe (line 363) | def describe(self):
method update_reals (line 412) | def update_reals(self, text_item):
method flip_single_v1_plane (line 423) | def flip_single_v1_plane(self, plane):
method display_v4 (line 429) | def display_v4(self, ply, content):
function main (line 467) | def main(args):
FILE: tf/lc0_az_policy_map.py
function index_to_position (line 14) | def index_to_position(x):
function position_to_index (line 18) | def position_to_index(p):
function valid_index (line 22) | def valid_index(i):
function queen_move (line 30) | def queen_move(start, direction, steps):
function knight_move (line 49) | def knight_move(start, direction, steps):
function make_map (line 68) | def make_map(kind='matrix'):
FILE: tf/net.py
function nested_getattr (line 20) | def nested_getattr(obj, attr):
class Net (line 27) | class Net:
method __init__ (line 29) | def __init__(self,
method set_networkformat (line 57) | def set_networkformat(self, net):
method set_policyformat (line 63) | def set_policyformat(self, policy):
method set_headcount (line 66) | def set_headcount(self, headcount):
method set_pol_headcount (line 69) | def set_pol_headcount(self, headcount):
method set_valueformat (line 72) | def set_valueformat(self, value):
method set_movesleftformat (line 81) | def set_movesleftformat(self, moves_left):
method set_input (line 84) | def set_input(self, input_format):
method set_defaultactivation (line 94) | def set_defaultactivation(self, activation):
method set_smolgen_activation (line 100) | def set_smolgen_activation(self, activation):
method set_ffn_activation (line 106) | def set_ffn_activation(self, activation):
method activation (line 112) | def activation(self, name):
method get_weight_amounts (line 134) | def get_weight_amounts(self):
method fill_layer_v2 (line 146) | def fill_layer_v2(self, layer, params):
method fill_layer (line 161) | def fill_layer(self, layer, weights):
method fill_conv_block (line 176) | def fill_conv_block(self, convblock, weights, gammas):
method fill_plain_conv (line 190) | def fill_plain_conv(self, convblock, weights):
method fill_se_unit (line 195) | def fill_se_unit(self, se_unit, weights):
method denorm_layer_v2 (line 201) | def denorm_layer_v2(self, layer):
method denorm_layer (line 207) | def denorm_layer(self, layer, weights):
method denorm_conv_block (line 210) | def denorm_conv_block(self, convblock, weights):
method denorm_plain_conv (line 226) | def denorm_plain_conv(self, convblock, weights):
method denorm_se_unit (line 231) | def denorm_se_unit(self, convblock, weights):
method save_txt (line 242) | def save_txt(self, filename):
method save_proto (line 267) | def save_proto(self, filename):
method tf_name_to_pb_name (line 279) | def tf_name_to_pb_name(self, name):
method get_weights_v2 (line 492) | def get_weights_v2(self, names):
method get_weights (line 536) | def get_weights(self):
method filters (line 564) | def filters(self):
method blocks (line 569) | def blocks(self):
method print_stats (line 572) | def print_stats(self):
method parse_proto (line 578) | def parse_proto(self, filename):
method parse_txt (line 595) | def parse_txt(self, filename):
method fill_net_v2 (line 615) | def fill_net_v2(self, all_weights):
method fill_net (line 701) | def fill_net(self, weights):
function print_pb_stats (line 744) | def print_pb_stats(obj, parent=None):
function main (line 761) | def main(argv):
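The legacy `tf/net.py` `Net` class reads and writes Leela weight files. A hedged sketch of a round trip, assuming the constructor defaults are usable, that `tf/` is on the import path, and with placeholder file names:

```python
# Hedged sketch of the legacy tf/net.py workflow, based on the method names
# above. Constructor defaults, import path and file names are assumptions.
from net import Net

net = Net()                          # assuming the default formats suffice
net.parse_proto("weights.pb.gz")     # load an existing .pb.gz weights file
net.print_stats()                    # prints filters / blocks / parameter info
net.save_proto("weights_copy.pb.gz")
```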
FILE: tf/shufflebuffer.py
class ShuffleBuffer (line 23) | class ShuffleBuffer:
method __init__ (line 24) | def __init__(self, elem_size, elem_count):
method extract (line 42) | def extract(self):
method insert_or_replace (line 56) | def insert_or_replace(self, item):
class ShuffleBufferTest (line 83) | class ShuffleBufferTest(unittest.TestCase):
method test_extract (line 84) | def test_extract(self):
method test_wrong_size (line 95) | def test_wrong_size(self):
method test_insert_or_replace (line 103) | def test_insert_or_replace(self):
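`ShuffleBuffer` holds a fixed number of fixed-size byte records; its name and `test_insert_or_replace` above suggest `insert_or_replace` pushes a new record and, once the buffer has filled, hands back a randomly chosen resident one. A hedged sketch (the element size, count and records are made up, and the import assumes `tf/` is on the path):

```python
# Hedged sketch of ShuffleBuffer, based on the constructor and method names
# above. Assumes insert_or_replace returns None while the buffer is filling
# and a displaced record afterwards.
from shufflebuffer import ShuffleBuffer

buf = ShuffleBuffer(elem_size=4, elem_count=8)
for i in range(20):
    out = buf.insert_or_replace(i.to_bytes(4, "little"))
    if out is not None:
        print("shuffled out:", out)
```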
FILE: tf/tfprocess.py
function square_relu (line 34) | def square_relu(x):
class Gating (line 38) | class Gating(tf.keras.layers.Layer):
method __init__ (line 40) | def __init__(self, name=None, additive=True, init_value=None, **kwargs):
method build (line 47) | def build(self, input_shape):
method call (line 56) | def call(self, inputs):
function ma_gating (line 61) | def ma_gating(inputs, name):
class ApplySqueezeExcitation (line 67) | class ApplySqueezeExcitation(tf.keras.layers.Layer):
method __init__ (line 69) | def __init__(self, **kwargs):
method build (line 72) | def build(self, input_dimens):
method call (line 75) | def call(self, inputs):
class ApplyPolicyMap (line 85) | class ApplyPolicyMap(tf.keras.layers.Layer):
method __init__ (line 87) | def __init__(self, **kwargs):
method call (line 91) | def call(self, inputs):
class ApplyAttentionPolicyMap (line 97) | class ApplyAttentionPolicyMap(tf.keras.layers.Layer):
method __init__ (line 99) | def __init__(self, **kwargs):
method call (line 103) | def call(self, logits, pp_logits):
class Metric (line 112) | class Metric:
method __init__ (line 114) | def __init__(self, short_name, long_name, suffix='', **kwargs):
method assign (line 121) | def assign(self, value):
method accumulate (line 125) | def accumulate(self, value):
method merge (line 132) | def merge(self, other):
method get (line 137) | def get(self):
method reset (line 142) | def reset(self):
class TFProcess (line 147) | class TFProcess:
method __init__ (line 149) | def __init__(self, cfg):
method init (line 355) | def init(self, train_dataset, test_dataset, validation_dataset=None):
method init_net (line 380) | def init_net(self):
method replace_weights (line 631) | def replace_weights(self, proto_filename, ignore_errors=False):
method restore (line 718) | def restore(self):
method process_loop (line 723) | def process_loop(self, batch_size, test_batches, batch_splits=1):
method read_weights (line 747) | def read_weights(self):
method process_inner_loop (line 751) | def process_inner_loop(self, x, y, z, q, m):
method strategy_process_inner_loop (line 790) | def strategy_process_inner_loop(self, x, y, z, q, m):
method apply_grads (line 799) | def apply_grads(self, grads, effective_batch_splits):
method strategy_apply_grads (line 815) | def strategy_apply_grads(self, grads, effective_batch_splits):
method merge_grads (line 824) | def merge_grads(self, grads, new_grads):
method strategy_merge_grads (line 828) | def strategy_merge_grads(self, grads, new_grads):
method train_step (line 831) | def train_step(self, steps, batch_size, batch_splits):
method process (line 914) | def process(self, batch_size, test_batches, batch_splits):
method calculate_swa_summaries (line 992) | def calculate_swa_summaries(self, test_batches, steps):
method calculate_test_summaries_inner_loop (line 1004) | def calculate_test_summaries_inner_loop(self, x, y, z, q, m):
method strategy_calculate_test_summaries_inner_loop (line 1041) | def strategy_calculate_test_summaries_inner_loop(self, x, y, z, q, m):
method calculate_test_summaries (line 1050) | def calculate_test_summaries(self, test_batches, steps):
method calculate_swa_validations (line 1082) | def calculate_swa_validations(self, steps):
method calculate_test_validations (line 1093) | def calculate_test_validations(self, steps):
method compute_update_ratio (line 1118) | def compute_update_ratio(self, before_weights, after_weights, steps):
method update_swa (line 1147) | def update_swa(self):
method save_swa_weights (line 1154) | def save_swa_weights(self, filename):
method save_leelaz_weights (line 1162) | def save_leelaz_weights(self, filename):
method batch_norm (line 1169) | def batch_norm(self, input, name, scale=False):
method squeeze_excitation (line 1195) | def squeeze_excitation(self, inputs, channels, name):
method conv_block (line 1211) | def conv_block(self,
method residual_block (line 1228) | def residual_block(self, inputs, channels, name):
method split_heads (line 1257) | def split_heads(inputs, batch_size: int, num_heads: int, depth: int):
method scaled_dot_product_attention (line 1264) | def scaled_dot_product_attention(self,
method mha (line 1294) | def mha(self, inputs, emb_size: int, d_model: int, num_heads: int,
method ffn (line 1332) | def ffn(self, inputs, emb_size: int, dff: int, initializer, name: str):
method encoder_layer (line 1347) | def encoder_layer(self, inputs, emb_size: int, d_model: int,
method smolgen_weights (line 1389) | def smolgen_weights(self,
method create_residual_body (line 1417) | def create_residual_body(self, inputs):
method create_encoder_body (line 1429) | def create_encoder_body(self, inputs, embedding_size):
method apply_promotion_logits (line 1470) | def apply_promotion_logits(self, queries, keys, attn_wts):
method construct_net (line 1523) | def construct_net(self, inputs, name=''):
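Note: among the helpers listed above, split_heads and scaled_dot_product_attention follow the standard multi-head attention recipe. The sketch below shows only the textbook form; the repository's versions take additional arguments (e.g. smolgen-generated logits) that are omitted here.

# Textbook-style sketch of the attention helpers named above; shapes assume
# 64 board squares. Not the repository's exact implementation.
import tensorflow as tf


def split_heads(inputs, batch_size, num_heads, depth):
    # [batch, 64, num_heads * depth] -> [batch, num_heads, 64, depth]
    x = tf.reshape(inputs, (batch_size, -1, num_heads, depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])


def scaled_dot_product_attention(q, k, v):
    # softmax(Q K^T / sqrt(d_k)) V
    dk = tf.cast(tf.shape(k)[-1], q.dtype)
    logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(dk)
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, v), weights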
FILE: tf/train.py
function get_chunks (line 32) | def get_chunks(data_prefix):
function get_all_chunks (line 36) | def get_all_chunks(path):
function get_latest_chunks (line 50) | def get_latest_chunks(path, num_chunks, allow_less, sort_key_fn):
function identity_function (line 77) | def identity_function(name):
function game_number_for_name (line 81) | def game_number_for_name(name):
function get_input_mode (line 87) | def get_input_mode(cfg):
function main (line 109) | def main(cmd):
FILE: tf/update_steps.py
function main (line 12) | def main(cmd):
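Note: tf/train.py's get_chunks and get_latest_chunks collect training chunk files on disk before they are parsed. The sketch below shows the general pattern under stated assumptions (chunks are *.gz files, sort_key_fn orders them, allow_less decides whether a shortfall is fatal); it is not the repository's code.

# Hedged sketch of the chunk-gathering pattern behind tf/train.py.
import glob
import os
import sys


def get_chunks(data_prefix):
    # Training chunks are assumed to be gzip files matching the prefix.
    return glob.glob(data_prefix + "*.gz")


def get_latest_chunks(path, num_chunks, allow_less, sort_key_fn=os.path.getmtime):
    chunks = get_chunks(path)
    if len(chunks) < num_chunks:
        if not allow_less:
            print(f"Not enough chunks: {len(chunks)} < {num_chunks}")
            sys.exit(1)
        num_chunks = len(chunks)
    chunks.sort(key=sort_key_fn, reverse=True)  # newest first
    return chunks[:num_chunks]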
Condensed preview — 219 files, each showing path, character count, and a content snippet (the full structured content is 1,196K chars).
[
{
"path": ".clang-format",
"chars": 20,
"preview": "BasedOnStyle: Google"
},
{
"path": ".gitignore",
"chars": 1445,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n*_pb2.py\n\n# C extensions\n*.so\n\n# Distribution "
},
{
"path": ".gitmodules",
"chars": 89,
"preview": "[submodule \"libs/lc0\"]\n\tpath = libs/lc0\n\turl = https://github.com/LeelaChessZero/lc0.git\n"
},
{
"path": "AGENTS.md",
"chars": 2324,
"preview": "# AGENTS.md\n\nThis repository contains training script for the Leela Chess Zero project.\nThey are being rewritten.\n\n* Old"
},
{
"path": "README.md",
"chars": 4020,
"preview": "# Training\n\nThe training pipeline resides in `tf`, this requires tensorflow running on linux (Ubuntu 16.04 in this case)"
},
{
"path": "csrc/loader/chunk_source/chunk_source.h",
"chars": 1093,
"preview": "#pragma once\n\n#include <cstddef>\n#include <optional>\n#include <string>\n#include <vector>\n\n#include \"loader/frame_type.h\""
},
{
"path": "csrc/loader/chunk_source/chunk_source_view.h",
"chars": 1415,
"preview": "#pragma once\n\n#include <cstddef>\n#include <cstdint>\n#include <memory>\n#include <optional>\n#include <string>\n#include <ve"
},
{
"path": "csrc/loader/chunk_source/debug_chunk_source.cc",
"chars": 1849,
"preview": "#include \"loader/chunk_source/debug_chunk_source.h\"\n\n#include <algorithm>\n#include <cinttypes>\n#include <cmath>\n#include"
},
{
"path": "csrc/loader/chunk_source/debug_chunk_source.h",
"chars": 1564,
"preview": "#pragma once\n\n#include <cstddef>\n#include <cstdint>\n#include <optional>\n#include <random>\n#include <string>\n\n#include \"l"
},
{
"path": "csrc/loader/chunk_source/rawfile_chunk_source.cc",
"chars": 1804,
"preview": "#include \"loader/chunk_source/rawfile_chunk_source.h\"\n\n#include <absl/log/log.h>\n\n#include <fstream>\n#include <stdexcept"
},
{
"path": "csrc/loader/chunk_source/rawfile_chunk_source.h",
"chars": 795,
"preview": "#pragma once\n\n#include <filesystem>\n#include <string>\n\n#include \"loader/chunk_source/chunk_source.h\"\n#include \"proto/dat"
},
{
"path": "csrc/loader/chunk_source/tar_chunk_source.cc",
"chars": 7562,
"preview": "#include \"loader/chunk_source/tar_chunk_source.h\"\n\n#include <absl/log/log.h>\n#include <absl/strings/str_cat.h>\n#include "
},
{
"path": "csrc/loader/chunk_source/tar_chunk_source.h",
"chars": 1235,
"preview": "#pragma once\n\n#include <sys/types.h>\n\n#include <filesystem>\n#include <string>\n#include <vector>\n\n#include \"loader/chunk_"
},
{
"path": "csrc/loader/data_loader.cc",
"chars": 8526,
"preview": "#include \"loader/data_loader.h\"\n\n#include <absl/algorithm/container.h>\n#include <absl/log/log.h>\n#include <absl/strings/"
},
{
"path": "csrc/loader/data_loader.h",
"chars": 2108,
"preview": "#pragma once\n\n#include <memory>\n#include <string>\n#include <string_view>\n#include <thread>\n#include <utility>\n#include <"
},
{
"path": "csrc/loader/data_loader_metrics.cc",
"chars": 4062,
"preview": "// ABOUTME: Implementation of UpdateFrom functions for data loader metric\n// protobuf messages. ABOUTME: Handles aggrega"
},
{
"path": "csrc/loader/data_loader_metrics.h",
"chars": 1303,
"preview": "// ABOUTME: Header for UpdateFrom functions for data loader metric protobuf\n// messages. ABOUTME: Declares functions for"
},
{
"path": "csrc/loader/data_loader_test.cc",
"chars": 907,
"preview": "#include \"loader/data_loader.h\"\n\n#include <gtest/gtest.h>\n\n#include <stdexcept>\n\nnamespace lczero {\nnamespace training {"
},
{
"path": "csrc/loader/frame_type.h",
"chars": 1371,
"preview": "/*\n This file is part of Leela Chess Zero.\n Copyright (C) 2025 The LCZero Authors\n\n Leela Chess is free software: you"
},
{
"path": "csrc/loader/loader_main.cpp",
"chars": 4177,
"preview": "#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h"
},
{
"path": "csrc/loader/pybind_module.cc",
"chars": 7412,
"preview": "// ABOUTME: PyBind11 binding module exposing C++ DataLoader to Python.\n// ABOUTME: Handles configuration conversion and "
},
{
"path": "csrc/loader/stages/chunk_rescorer.cc",
"chars": 5761,
"preview": "#include \"loader/stages/chunk_rescorer.h\"\n\n#include <stdexcept>\n#include <utility>\n\n#include \"absl/base/call_once.h\"\n#in"
},
{
"path": "csrc/loader/stages/chunk_rescorer.h",
"chars": 1793,
"preview": "// ABOUTME: Stage that rescales training chunks using Syzygy tablebases.\n// ABOUTME: Adjusts frame metadata by invoking "
},
{
"path": "csrc/loader/stages/chunk_rescorer_test.cc",
"chars": 2073,
"preview": "#include \"loader/stages/chunk_rescorer.h\"\n\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"gtest/gtest"
},
{
"path": "csrc/loader/stages/chunk_source_loader.cc",
"chars": 6467,
"preview": "#include \"loader/stages/chunk_source_loader.h\"\n\n#include <filesystem>\n#include <utility>\n\n#include \"absl/log/log.h\"\n#inc"
},
{
"path": "csrc/loader/stages/chunk_source_loader.h",
"chars": 2219,
"preview": "#pragma once\n\n#include <atomic>\n#include <filesystem>\n#include <memory>\n#include <stop_token>\n\n#include \"absl/base/threa"
},
{
"path": "csrc/loader/stages/chunk_source_loader_test.cc",
"chars": 6242,
"preview": "#include \"loader/stages/chunk_source_loader.h\"\n\n#include <gtest/gtest.h>\n\n#include <filesystem>\n\n#include \"loader/stages"
},
{
"path": "csrc/loader/stages/chunk_source_splitter.cc",
"chars": 6049,
"preview": "#include \"loader/stages/chunk_source_splitter.h\"\n\n#include <algorithm>\n#include <numeric>\n#include <utility>\n\n#include \""
},
{
"path": "csrc/loader/stages/chunk_source_splitter.h",
"chars": 1912,
"preview": "#pragma once\n\n#include <memory>\n#include <stop_token>\n#include <string>\n#include <string_view>\n#include <utility>\n#inclu"
},
{
"path": "csrc/loader/stages/chunk_source_splitter_test.cc",
"chars": 4610,
"preview": "#include \"loader/stages/chunk_source_splitter.h\"\n\n#include <memory>\n#include <string>\n#include <utility>\n#include <vecto"
},
{
"path": "csrc/loader/stages/chunk_unpacker.cc",
"chars": 10887,
"preview": "#include \"loader/stages/chunk_unpacker.h\"\n\n#include <absl/algorithm/container.h>\n#include <absl/container/flat_hash_set."
},
{
"path": "csrc/loader/stages/chunk_unpacker.h",
"chars": 2001,
"preview": "// ABOUTME: Stage that unpacks chunks into FrameType frames.\n// ABOUTME: Converts stream of std::string chunks to FrameT"
},
{
"path": "csrc/loader/stages/chunk_unpacker_test.cc",
"chars": 8390,
"preview": "#include \"loader/stages/chunk_unpacker.h\"\n\n#include <iterator>\n#include <string>\n#include <utility>\n#include <vector>\n\n#"
},
{
"path": "csrc/loader/stages/file_path_provider.cc",
"chars": 12484,
"preview": "#include \"loader/stages/file_path_provider.h\"\n\n#include <absl/cleanup/cleanup.h>\n#include <absl/container/flat_hash_set."
},
{
"path": "csrc/loader/stages/file_path_provider.h",
"chars": 2967,
"preview": "#pragma once\n\n#include <absl/base/thread_annotations.h>\n#include <absl/container/flat_hash_map.h>\n#include <absl/log/log"
},
{
"path": "csrc/loader/stages/file_path_provider_main.cc",
"chars": 1778,
"preview": "#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h"
},
{
"path": "csrc/loader/stages/file_path_provider_test.cc",
"chars": 6625,
"preview": "#include \"loader/stages/file_path_provider.h\"\n\n#include <gtest/gtest.h>\n\n#include <chrono>\n#include <filesystem>\n#includ"
},
{
"path": "csrc/loader/stages/join_stage.cc",
"chars": 2603,
"preview": "#include \"loader/stages/join_stage.h\"\n\n#include <absl/log/log.h>\n\n#include \"loader/data_loader_metrics.h\"\n#include \"prot"
},
{
"path": "csrc/loader/stages/join_stage.h",
"chars": 1559,
"preview": "// ABOUTME: Stage that joins multiple input queues into a single output.\n// ABOUTME: Spawns one thread per input to read"
},
{
"path": "csrc/loader/stages/join_stage_test.cc",
"chars": 3996,
"preview": "#include \"loader/stages/join_stage.h\"\n\n#include <memory>\n#include <vector>\n\n#include \"absl/container/flat_hash_set.h\"\n#i"
},
{
"path": "csrc/loader/stages/position_sampling.cc",
"chars": 914,
"preview": "#include \"loader/stages/position_sampling.h\"\n\n#include <cmath>\n\nnamespace lczero {\nnamespace training {\n\nfloat ComputePo"
},
{
"path": "csrc/loader/stages/position_sampling.h",
"chars": 311,
"preview": "#pragma once\n\n#include \"loader/frame_type.h\"\n#include \"proto/data_loader_config.pb.h\"\n\nnamespace lczero {\nnamespace trai"
},
{
"path": "csrc/loader/stages/shuffling_chunk_pool.cc",
"chars": 34525,
"preview": "#include \"loader/stages/shuffling_chunk_pool.h\"\n\n#include <absl/algorithm/container.h>\n#include <absl/base/thread_annota"
},
{
"path": "csrc/loader/stages/shuffling_chunk_pool.h",
"chars": 5724,
"preview": "#pragma once\n\n#include <atomic>\n#include <filesystem>\n#include <memory>\n#include <optional>\n#include <stop_token>\n#inclu"
},
{
"path": "csrc/loader/stages/shuffling_chunk_pool_test.cc",
"chars": 26316,
"preview": "// ABOUTME: Comprehensive unit tests for the ShufflingChunkPool class\n// ABOUTME: Tests chunk source management, output "
},
{
"path": "csrc/loader/stages/shuffling_frame_sampler.cc",
"chars": 3873,
"preview": "#include \"loader/stages/shuffling_frame_sampler.h\"\n\n#include <utility>\n\n#include \"absl/algorithm/container.h\"\n#include \""
},
{
"path": "csrc/loader/stages/shuffling_frame_sampler.h",
"chars": 1925,
"preview": "// ABOUTME: Stage that provides shuffled frames using reservoir sampling.\n// ABOUTME: Takes FrameType frames and outputs"
},
{
"path": "csrc/loader/stages/shuffling_frame_sampler_test.cc",
"chars": 6191,
"preview": "#include \"loader/stages/shuffling_frame_sampler.h\"\n\n#include <set>\n#include <vector>\n\n#include \"gtest/gtest.h\"\n#include "
},
{
"path": "csrc/loader/stages/simple_chunk_extractor.cc",
"chars": 3003,
"preview": "#include \"loader/stages/simple_chunk_extractor.h\"\n\n#include <absl/algorithm/container.h>\n#include <absl/log/log.h>\n\n#inc"
},
{
"path": "csrc/loader/stages/simple_chunk_extractor.h",
"chars": 1596,
"preview": "#pragma once\n\n#include <atomic>\n#include <memory>\n#include <optional>\n#include <stop_token>\n#include <string>\n#include <"
},
{
"path": "csrc/loader/stages/simple_chunk_extractor_test.cc",
"chars": 6941,
"preview": "#include \"loader/stages/simple_chunk_extractor.h\"\n\n#include <gmock/gmock.h>\n#include <gtest/gtest.h>\n\n#include <memory>\n"
},
{
"path": "csrc/loader/stages/stage.cc",
"chars": 1169,
"preview": "#include \"loader/stages/stage.h\"\n\n#include <absl/algorithm/container.h>\n\nnamespace lczero {\nnamespace training {\n\nvoid S"
},
{
"path": "csrc/loader/stages/stage.h",
"chars": 4633,
"preview": "#pragma once\n\n#include <algorithm>\n#include <optional>\n#include <stdexcept>\n#include <string>\n#include <string_view>\n#in"
},
{
"path": "csrc/loader/stages/stage_factory.cc",
"chars": 2883,
"preview": "#include \"loader/stages/stage_factory.h\"\n\n#include <stdexcept>\n#include <string>\n\n#include \"loader/stages/chunk_rescorer"
},
{
"path": "csrc/loader/stages/stage_factory.h",
"chars": 261,
"preview": "#pragma once\n\n#include <memory>\n\n#include \"loader/stages/stage.h\"\n#include \"proto/data_loader_config.pb.h\"\n\nnamespace lc"
},
{
"path": "csrc/loader/stages/stage_factory_test.cc",
"chars": 802,
"preview": "#include \"loader/stages/stage_factory.h\"\n\n#include <gtest/gtest.h>\n\n#include <stdexcept>\n\nnamespace lczero {\nnamespace t"
},
{
"path": "csrc/loader/stages/tensor_generator.cc",
"chars": 7836,
"preview": "// ABOUTME: Implementation of TensorGenerator stage for training pipeline.\n// ABOUTME: Converts V6TrainingData frames to"
},
{
"path": "csrc/loader/stages/tensor_generator.h",
"chars": 1838,
"preview": "// ABOUTME: Stage that converts FrameType frames into tensor batches.\n// ABOUTME: Produces TrainingTensors with tensors "
},
{
"path": "csrc/loader/stages/tensor_generator_test.cc",
"chars": 11787,
"preview": "// ABOUTME: Unit tests for TensorGenerator stage in training pipeline.\n// ABOUTME: Tests tensor conversion, batching, an"
},
{
"path": "csrc/loader/stages/training_chunk.h",
"chars": 493,
"preview": "#pragma once\n\n#include <cstddef>\n#include <cstdint>\n#include <string>\n#include <vector>\n\n#include \"loader/frame_type.h\"\n"
},
{
"path": "csrc/tools/dump_chunk_main.cc",
"chars": 2860,
"preview": "#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h"
},
{
"path": "csrc/tools/filter_chunks_main.cc",
"chars": 6839,
"preview": "#include <absl/algorithm/container.h>\n#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/glob"
},
{
"path": "csrc/tools/position_weight_stats_main.cc",
"chars": 6750,
"preview": "#include <absl/algorithm/container.h>\n#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/glob"
},
{
"path": "csrc/tools/rescore_chunk_main.cc",
"chars": 5958,
"preview": "#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h"
},
{
"path": "csrc/tools/result_distribution_main.cc",
"chars": 6300,
"preview": "#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/globals.h>\n#include <absl/log/initialize.h"
},
{
"path": "csrc/tools/startpos_policy_distribution_main.cc",
"chars": 4729,
"preview": "#include <absl/algorithm/container.h>\n#include <absl/flags/flag.h>\n#include <absl/flags/parse.h>\n#include <absl/log/glob"
},
{
"path": "csrc/utils/gz.cc",
"chars": 1252,
"preview": "#include \"utils/gz.h\"\n\n#include <absl/log/log.h>\n#include <zlib.h>\n\n#include <array>\n#include <stdexcept>\n\nnamespace lcz"
},
{
"path": "csrc/utils/gz.h",
"chars": 314,
"preview": "#pragma once\n\n#include <span>\n#include <stdexcept>\n#include <string>\n\nnamespace lczero {\nnamespace training {\n\nclass Gun"
},
{
"path": "csrc/utils/metrics/exponential_aggregator.h",
"chars": 12801,
"preview": "#pragma once\n\n#include <absl/strings/str_cat.h>\n#include <absl/synchronization/mutex.h>\n\n#include <chrono>\n#include <cma"
},
{
"path": "csrc/utils/metrics/group.h",
"chars": 3013,
"preview": "#pragma once\n\n#include <absl/strings/str_cat.h>\n\n#include \"utils/metrics/printer.h\"\n\nnamespace lczero {\n\n// Metric is a "
},
{
"path": "csrc/utils/metrics/load_metric.h",
"chars": 3561,
"preview": "#pragma once\n\n#include <absl/base/thread_annotations.h>\n#include <absl/synchronization/mutex.h>\n\n#include <chrono>\n#incl"
},
{
"path": "csrc/utils/metrics/load_metric_test.cc",
"chars": 11090,
"preview": "#include \"utils/metrics/load_metric.h\"\n\n#include <gtest/gtest.h>\n\n#include <chrono>\n#include <memory>\n\n#include \"proto/t"
},
{
"path": "csrc/utils/metrics/printer.h",
"chars": 1343,
"preview": "#pragma once\n\n#include <absl/strings/str_cat.h>\n\n#include <string>\n#include <string_view>\n\nnamespace lczero {\n\nclass Met"
},
{
"path": "csrc/utils/metrics/statistics_metric.h",
"chars": 1847,
"preview": "#pragma once\n\n#include <algorithm>\n\n#include \"proto/training_metrics.pb.h\"\n\nnamespace lczero {\n\n// Helper function to ad"
},
{
"path": "csrc/utils/metrics/stats_test.cc",
"chars": 20179,
"preview": "#include <gtest/gtest.h>\n\n#include <chrono>\n#include <memory>\n#include <optional>\n#include <thread>\n\n#include \"utils/met"
},
{
"path": "csrc/utils/queue.h",
"chars": 21330,
"preview": "#pragma once\n\n#include <algorithm>\n#include <optional>\n#include <stdexcept>\n#include <stop_token>\n\n#include \"absl/base/t"
},
{
"path": "csrc/utils/queue_test.cc",
"chars": 41978,
"preview": "// ABOUTME: Comprehensive unit tests for the Queue template class\n// ABOUTME: Tests thread-safe operations, blocking beh"
},
{
"path": "csrc/utils/stream_shuffler.cc",
"chars": 3558,
"preview": "#include \"utils/stream_shuffler.h\"\n\nnamespace lczero {\nnamespace training {\n\nvoid StreamShuffler::SetUpperBound(size_t u"
},
{
"path": "csrc/utils/stream_shuffler.h",
"chars": 1684,
"preview": "#pragma once\n\n#include <absl/container/fixed_array.h>\n#include <absl/random/random.h>\n\n#include <cstddef>\n#include <dequ"
},
{
"path": "csrc/utils/stream_shuffler_test.cc",
"chars": 7694,
"preview": "#include \"utils/stream_shuffler.h\"\n\n#include <absl/container/flat_hash_set.h>\n#include <absl/random/random.h>\n#include <"
},
{
"path": "csrc/utils/tensor.h",
"chars": 3790,
"preview": "#pragma once\n\n#include <stdexcept>\n#include <string>\n#include <type_traits>\n#include <vector>\n\n#include \"absl/algorithm/"
},
{
"path": "csrc/utils/tensor_test.cc",
"chars": 4366,
"preview": "// ABOUTME: Unit tests for tensor classes and their data access methods.\n// ABOUTME: Tests construction, element access,"
},
{
"path": "csrc/utils/thread_pool.h",
"chars": 6831,
"preview": "#pragma once\n\n#include <cstddef>\n#include <deque>\n#include <exception>\n#include <functional>\n#include <future>\n#include "
},
{
"path": "csrc/utils/training_data_printer.cc",
"chars": 4852,
"preview": "#include \"utils/training_data_printer.h\"\n\n#include <absl/strings/str_format.h>\n\n#include <algorithm>\n#include <iostream>"
},
{
"path": "csrc/utils/training_data_printer.h",
"chars": 1333,
"preview": "#ifndef LCZERO_TRAINING_UTILS_TRAINING_DATA_PRINTER_H_\n#define LCZERO_TRAINING_UTILS_TRAINING_DATA_PRINTER_H_\n\n#include "
},
{
"path": "docs/README.md",
"chars": 14878,
"preview": "# Running \"new\" training pipeline\n\nNote that the code is still in active development, so things change a lot.\nThe curren"
},
{
"path": "docs/architecture.md",
"chars": 3093,
"preview": "# Architecture Overview\n\nThe document outlines the architecture of the new Leela Chess Zero training\nsystem. The trainin"
},
{
"path": "docs/checkpoint_migration.md",
"chars": 6008,
"preview": "# Checkpoint Migration\n\nWhen part of the model or training setup changes, JAX training state checkpoints\nmay become inco"
},
{
"path": "docs/example.textproto",
"chars": 5884,
"preview": "# Example configuration file for lczero-training\n# This file demonstrates all available configuration options with their"
},
{
"path": "docs/heads.md",
"chars": 6086,
"preview": "# Neural Network Heads Documentation\n\nThis document describes the various policy and value heads used in the network, th"
},
{
"path": "docs/index.md",
"chars": 505,
"preview": "# Index\n\n* [Overview, glossary and file formats](overview.md) — A an overview of the\n project, including definitions of"
},
{
"path": "docs/loader.md",
"chars": 9530,
"preview": "# Data Loader\n\nThe Data Loader is a C++ module (exposed to Python via pybind11) that handles\nloading, preprocessing, shu"
},
{
"path": "docs/new_stage.md",
"chars": 6810,
"preview": "# Writing a New Data Loader Stage\n\nThis guide walks through the lifecycle of adding another stage to the dynamic\ndata lo"
},
{
"path": "docs/overview.md",
"chars": 715,
"preview": "# Overview, Glossary, and File Formats\n\nThis document serves as a glossary of terms used in the project and describes\nth"
},
{
"path": "docs/shuffling_pool_hanse_sampling.md",
"chars": 2480,
"preview": "# Implement single position sampling in Shuffling Pool\n\nThis document defines a new way of sampling in\n[Shuffling Pool]("
},
{
"path": "docs/training_tuple.md",
"chars": 1424,
"preview": "# Training Tuple Format\n\nThe `convert_v6_to_tuple` function in `tf/chunkparser.py` processes training\ndata and produces "
},
{
"path": "docs/tui.md",
"chars": 5748,
"preview": "# UI Design\n\nThe application will present a single-screen dashboard with a classic blue background, organized into sever"
},
{
"path": "docs/weights_tool.md",
"chars": 10793,
"preview": "# lc0-weights - Weight Manipulation Tool\n\n## Overview\n\n`lc0-weights` is a command-line tool and Python library for manip"
},
{
"path": "init.sh",
"chars": 109,
"preview": "#!/usr/bin/env bash\n\nprotoc --proto_path=libs/lc0 --python_out=tf proto/net.proto\ntouch tf/proto/__init__.py\n"
},
{
"path": "justfile",
"chars": 2521,
"preview": "# List available commands\ndefault:\n @just --list\n\n# Check if all C++ files in csrc/ are formatted according to clang-"
},
{
"path": "meson.build",
"chars": 12365,
"preview": "project(\n 'lczero-training',\n 'cpp',\n version : '0.1',\n meson_version : '>= 1.3.0',\n default_options : [\n 'warni"
},
{
"path": "native.ini",
"chars": 59,
"preview": "[binaries]\npython = '@GLOBAL_SOURCE_ROOT@/.venv/bin/python'"
},
{
"path": "proto/checkpoint_migration_config.proto",
"chars": 289,
"preview": "syntax = \"proto3\";\n\npackage lczero_training.proto;\n\nmessage CheckpointMigrationRule {\n // Path in the old state pytree."
},
{
"path": "proto/data_loader_config.proto",
"chars": 8330,
"preview": "syntax = \"proto2\";\n\npackage lczero.training;\n\n// Configuration for output queue used by stages.\nmessage QueueConfig {\n "
},
{
"path": "proto/export_config.proto",
"chars": 587,
"preview": "syntax = \"proto2\";\n\npackage lczero.training;\n\n// Configuration for model export settings.\nmessage ExportConfig {\n // De"
},
{
"path": "proto/metrics_config.proto",
"chars": 1334,
"preview": "syntax = \"proto2\";\n\npackage lczero.training;\n\n// Sentinel message for training batch sample type.\nmessage TrainingBatch "
},
{
"path": "proto/model_config.proto",
"chars": 1581,
"preview": "syntax = \"proto2\";\n\nimport \"proto/net.proto\";\nimport \"proto/hlo.proto\";\n\npackage lczero.training;\n\nmessage ModelConfig {"
},
{
"path": "proto/root_config.proto",
"chars": 815,
"preview": "syntax = \"proto2\";\n\npackage lczero.training;\n\nimport \"proto/data_loader_config.proto\";\nimport \"proto/model_config.proto\""
},
{
"path": "proto/stage_control.proto",
"chars": 501,
"preview": "syntax = \"proto2\";\n\npackage lczero.training;\n\nmessage ShufflingChunkPoolControlRequest {\n optional bool reset_chunk_anc"
},
{
"path": "proto/training_config.proto",
"chars": 4996,
"preview": "syntax = \"proto3\";\n\npackage lczero.training;\n\n// Configuration for training algorithm and parameters.\nmessage TrainingCo"
},
{
"path": "proto/training_metrics.proto",
"chars": 2132,
"preview": "syntax = \"proto2\";\n\npackage lczero;\n\n// Load metric that accumulates seconds of load time.\n// Separate proto to support "
},
{
"path": "pyproject.toml",
"chars": 2994,
"preview": "[project]\nname = \"lczero-training\"\nversion = \"0.1.0\"\ndescription = \"Training scripts and data loading for Leela Chess Ze"
},
{
"path": "scripts/diff.py",
"chars": 929,
"preview": "#!/usr/bin/env python\n\nimport glob\nimport os\nimport argparse\n\n\ndef get_sorted_chunk_ids(dirs):\n ids = []\n for d in"
},
{
"path": "scripts/fixorder.py",
"chars": 720,
"preview": "#!/usr/bin/env python\n\nimport glob\nimport os\nimport argparse\n\n\ndef get_sorted_chunk_ids(dirs):\n ids = []\n for d in"
},
{
"path": "scripts/init.sh",
"chars": 827,
"preview": "#!/usr/bin/env bash\n\nset -e\n\nWINDOWSIZE=80000\nROOT=\"/work/lc0\"\n\necho \"Cleaning up data directory\"\nrm -rf $ROOT/data\nmkdi"
},
{
"path": "scripts/initsplit.py",
"chars": 1336,
"preview": "#!/usr/bin/env python\n\nimport glob\nimport os\nimport argparse\n\n\ndef get_sorted_chunk_ids(dirs):\n ids = []\n for d in"
},
{
"path": "scripts/inittrainingname.py",
"chars": 1145,
"preview": "#!/usr/bin/env python\n\nimport glob\nimport os\nimport argparse\n\n\ndef get_sorted_chunk_ids(dirs):\n ids = []\n for d in"
},
{
"path": "scripts/pack.py",
"chars": 2597,
"preview": "#!/usr/bin/env python3\n\nimport glob\nimport os\nimport argparse\nimport gzip\nimport bz2\nimport struct\nimport numpy as np\nfr"
},
{
"path": "scripts/purge.py",
"chars": 797,
"preview": "#!/usr/bin/env python\n\nimport glob\nimport os\nimport argparse\n\n\ndef get_sorted_chunk_ids(dirs):\n ids = []\n for d in"
},
{
"path": "scripts/rescore.sh",
"chars": 810,
"preview": "#!/usr/bin/env bash\n\nset -e\n\nROOT=\"/work/lc0/dev2\"\nRESCORER=\"$HOME/bin/rescorer\"\n\nfunction usage()\n{\n echo \"Rescores st"
},
{
"path": "scripts/shuffle.py",
"chars": 1799,
"preview": "#!/usr/bin/python3\nimport gzip\nimport sys\nimport glob\nimport os\nimport random\nfrom multiprocessing import Pool\nimport tq"
},
{
"path": "scripts/split.sh",
"chars": 4221,
"preview": "#!/usr/bin/env bash\n\nRECORDSIZE=8276 # size in bytes of a record (s, pi, v)\n\nfunction usage()\n{\n echo \"Watches a direct"
},
{
"path": "scripts/stage.sh",
"chars": 1064,
"preview": "#!/usr/bin/env bash\n\nset -e\n\nfunction usage()\n{\n echo \"Moves arriving data to a directory so rescorer can assume all fi"
},
{
"path": "scripts/unpack.py",
"chars": 1523,
"preview": "#!/usr/bin/env python3\n\nimport os\nimport argparse\nimport gzip\nimport bz2\nimport numpy as np\nimport pickle\nimport struct\n"
},
{
"path": "scripts/upload.sh",
"chars": 1366,
"preview": "#!/usr/bin/env bash\n\nset -e\n\nfunction usage()\n{\n echo \"Uploads a network with NxM prefix, where N=filters and M=blocks\""
},
{
"path": "src/lczero_training/__init__.py",
"chars": 41,
"preview": "\"\"\"Leela Chess Zero training package.\"\"\"\n"
},
{
"path": "src/lczero_training/_lczero_training.pyi",
"chars": 1283,
"preview": "# ABOUTME: Type stubs for C++ DataLoader PyBind11 bindings.\n# ABOUTME: Provides type annotations for _lczero_training co"
},
{
"path": "src/lczero_training/commands/__init__.py",
"chars": 558,
"preview": "\"\"\"Command entrypoint scaffolding and shared CLI helpers.\n\nThis package will host thin wrappers for individual tools (co"
},
{
"path": "src/lczero_training/commands/backfill_metrics.py",
"chars": 1663,
"preview": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\nfrom lczero_train"
},
{
"path": "src/lczero_training/commands/common.py",
"chars": 1692,
"preview": "import argparse\nimport logging\nimport os\nimport sys\n\n_DEFAULT_FORMAT = (\n \"%(levelname).1s%(asctime)s.%(msecs)03d %(n"
},
{
"path": "src/lczero_training/commands/daemon.py",
"chars": 676,
"preview": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\nfrom lczero_train"
},
{
"path": "src/lczero_training/commands/dataloader_viz.py",
"chars": 2929,
"preview": "import argparse\nimport sys\n\nfrom google.protobuf import text_format\nfrom graphviz import Digraph # type: ignore\n\nfrom l"
},
{
"path": "src/lczero_training/commands/describe_training.py",
"chars": 1243,
"preview": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_pars"
},
{
"path": "src/lczero_training/commands/jax2leela.py",
"chars": 3724,
"preview": "import argparse\nimport gzip\nimport logging\nimport os\nimport sys\n\nimport orbax.checkpoint as ocp\nfrom google.protobuf imp"
},
{
"path": "src/lczero_training/commands/leela2jax.py",
"chars": 1959,
"preview": "import argparse\nimport sys\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n parser = argparse.ArgumentParser(\n "
},
{
"path": "src/lczero_training/commands/migrate_checkpoint.py",
"chars": 2565,
"preview": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_pars"
},
{
"path": "src/lczero_training/commands/overfit.py",
"chars": 1431,
"preview": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_pars"
},
{
"path": "src/lczero_training/commands/test_dataloader.py",
"chars": 1279,
"preview": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_pars"
},
{
"path": "src/lczero_training/commands/train.py",
"chars": 4314,
"preview": "import argparse\nimport datetime\nimport gzip\nimport logging\nimport os\nimport sys\n\nimport orbax.checkpoint as ocp\nfrom fla"
},
{
"path": "src/lczero_training/commands/training_eval.py",
"chars": 2254,
"preview": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_pars"
},
{
"path": "src/lczero_training/commands/training_init.py",
"chars": 2294,
"preview": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_pars"
},
{
"path": "src/lczero_training/commands/tui.py",
"chars": 711,
"preview": "import argparse\nimport logging\nimport sys\n\nimport anyio\n\nfrom lczero_training.commands import configure_root_logging\nfro"
},
{
"path": "src/lczero_training/commands/tune_lr.py",
"chars": 2641,
"preview": "import argparse\nimport logging\nimport sys\n\nfrom lczero_training.commands import configure_root_logging\n\n\ndef _build_pars"
},
{
"path": "src/lczero_training/commands/weights_tool.py",
"chars": 3124,
"preview": "\"\"\"CLI command for manipulating Lc0 neural network weights.\"\"\"\n\nimport argparse\nimport sys\n\nimport numpy as np\n\nfrom lcz"
},
{
"path": "src/lczero_training/convert/__init__.py",
"chars": 53,
"preview": "\"\"\"Convert package for Leela Chess Zero training.\"\"\"\n"
},
{
"path": "src/lczero_training/convert/jax_to_leela.py",
"chars": 3825,
"preview": "import dataclasses\nimport logging\nfrom typing import Optional, cast\n\nimport numpy as np\nfrom flax import nnx\n\nfrom lczer"
},
{
"path": "src/lczero_training/convert/leela_pytree_visitor.py",
"chars": 7042,
"preview": "import math\nfrom typing import Any, Optional\n\nfrom flax import nnx\n\nfrom proto import net_pb2\n\n\nclass LeelaPytreeWeights"
},
{
"path": "src/lczero_training/convert/leela_to_jax.py",
"chars": 6444,
"preview": "import dataclasses\nimport gzip\nimport logging\nimport math\nfrom typing import Optional, cast\n\nimport jax.numpy as jnp\nfro"
},
{
"path": "src/lczero_training/convert/leela_to_modelconfig.py",
"chars": 5138,
"preview": "from proto import hlo_pb2, model_config_pb2, net_pb2\n\n\ndef _defaultactivation_to_activation(\n activation: net_pb2.Net"
},
{
"path": "src/lczero_training/daemon/__init__.py",
"chars": 132,
"preview": "# ABOUTME: Daemon package for training subprocess communication.\n# ABOUTME: Provides TrainingDaemon class for IPC via st"
},
{
"path": "src/lczero_training/daemon/daemon.py",
"chars": 4938,
"preview": "import logging\nimport signal\nimport sys\nimport threading\nimport time\n\nimport anyio\n\nimport proto.training_metrics_pb2 as"
},
{
"path": "src/lczero_training/daemon/metrics.py",
"chars": 10347,
"preview": "\"\"\"Metrics collection and logging for training daemon.\"\"\"\n\nimport logging\nimport os\nfrom abc import ABC, abstractmethod\n"
},
{
"path": "src/lczero_training/daemon/metrics_base.py",
"chars": 1007,
"preview": "\"\"\"Base classes for metrics.\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nfrom flax import nnx\n\nfrom lczero_training.trainin"
},
{
"path": "src/lczero_training/daemon/pipeline.py",
"chars": 17722,
"preview": "import dataclasses\nimport datetime\nimport gzip\nimport logging\nimport os\nimport threading\nimport time\nfrom pathlib import"
},
{
"path": "src/lczero_training/daemon/protocol/__init__.py",
"chars": 157,
"preview": "# ABOUTME: Protocol package for JSONL IPC communication between processes.\n# ABOUTME: Contains registry system, message "
},
{
"path": "src/lczero_training/daemon/protocol/communicator.py",
"chars": 7725,
"preview": "# ABOUTME: Core Communicator class for JSONL IPC between processes.\n# ABOUTME: Handles serialization/deserialization and"
},
{
"path": "src/lczero_training/daemon/protocol/messages.py",
"chars": 1491,
"preview": "# ABOUTME: Payload dataclass definitions for JSONL IPC protocol messages.\n# ABOUTME: Defines minimal event types for tra"
},
{
"path": "src/lczero_training/daemon/protocol/registry.py",
"chars": 1049,
"preview": "# ABOUTME: Registry system for mapping event type strings to payload dataclasses.\n# ABOUTME: Provides @register decorato"
},
{
"path": "src/lczero_training/daemon/rms_metrics.py",
"chars": 3627,
"preview": "\"\"\"RMS metrics for model parameters.\"\"\"\n\nfrom typing import Any, cast\n\nimport jax\nimport jax.numpy as jnp\nfrom flax impo"
},
{
"path": "src/lczero_training/dataloader/__init__.py",
"chars": 330,
"preview": "from lczero_training._lczero_training import (\n DataLoader,\n TensorBase,\n)\nfrom proto.data_loader_config_pb2 impor"
},
{
"path": "src/lczero_training/model/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/lczero_training/model/embedding.py",
"chars": 3189,
"preview": "import jax\nimport jax.numpy as jnp\nfrom flax import nnx\n\nfrom proto import model_config_pb2\n\nfrom .shared import Ffn\nfro"
},
{
"path": "src/lczero_training/model/encoder.py",
"chars": 7131,
"preview": "import math\nfrom typing import Optional\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import nnx\nfrom flax.linen import "
},
{
"path": "src/lczero_training/model/loss_function.py",
"chars": 14860,
"preview": "from typing import Dict, List, Optional, Sequence, Tuple, Union, cast\n\nimport jax\nimport jax.numpy as jnp\nimport optax\nf"
},
{
"path": "src/lczero_training/model/model.py",
"chars": 4144,
"preview": "import dataclasses\nimport math\nfrom typing import Optional, Tuple\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import n"
},
{
"path": "src/lczero_training/model/movesleft_head.py",
"chars": 1067,
"preview": "import jax\nfrom flax import nnx\n\nfrom proto import model_config_pb2\n\nfrom .utils import get_activation\n\n\nclass MovesLeft"
},
{
"path": "src/lczero_training/model/policy_head.py",
"chars": 13705,
"preview": "import math\nfrom typing import Optional\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import nnx\n\nfrom proto import mode"
},
{
"path": "src/lczero_training/model/shared.py",
"chars": 1205,
"preview": "import jax\nfrom flax import nnx\nfrom flax.linen import initializers as flax_initializers\n\nfrom proto import net_pb2\n\nfro"
},
{
"path": "src/lczero_training/model/utils.py",
"chars": 1766,
"preview": "from typing import Any\n\nimport jax.numpy as jnp\nfrom flax import nnx\nfrom jax.nn import mish\n\nfrom proto import net_pb2\n"
},
{
"path": "src/lczero_training/model/value_head.py",
"chars": 1968,
"preview": "from typing import Optional, Tuple\n\nimport jax\nfrom flax import nnx\n\nfrom proto import model_config_pb2\n\nfrom .utils imp"
},
{
"path": "src/lczero_training/py.typed",
"chars": 0,
"preview": ""
},
{
"path": "src/lczero_training/tests/test_protobuf.py",
"chars": 2326,
"preview": "\"\"\"Test protobuf compilation and functionality.\"\"\"\n\n\ndef test_protobuf_import() -> None:\n \"\"\"Test that protobuf files"
},
{
"path": "src/lczero_training/tests/test_protocol_registry.py",
"chars": 3493,
"preview": "\"\"\"Test script for the protocol registry system.\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any\n\nimport py"
},
{
"path": "src/lczero_training/tests/test_weights_tool.py",
"chars": 8637,
"preview": "\"\"\"Test weights tool arithmetic operations.\"\"\"\n\nimport os\nimport tempfile\n\nimport numpy as np\n\nimport proto.net_pb2 as n"
},
{
"path": "src/lczero_training/tools/__init__.py",
"chars": 147,
"preview": "\"\"\"Pure Python tools for weight manipulation.\"\"\"\n\nfrom .weights_tool import load_weights, save_weights\n\n__all__ = [\"load"
},
{
"path": "src/lczero_training/tools/weight_codecs.py",
"chars": 3298,
"preview": "\"\"\"Encoding and decoding logic for Lc0 weight formats.\"\"\"\n\nimport numpy as np\n\nfrom proto import net_pb2\n\n\ndef decode_li"
},
{
"path": "src/lczero_training/tools/weight_wrappers.py",
"chars": 15561,
"preview": "\"\"\"Wrapper classes for pythonic access to Lc0 weight protobufs.\"\"\"\n\nimport gzip\nfrom typing import Any, Iterator\n\nimport"
},
{
"path": "src/lczero_training/tools/weights_tool.py",
"chars": 884,
"preview": "\"\"\"Main API for loading and saving Lc0 weight files.\"\"\"\n\nimport gzip\n\nfrom proto import net_pb2\n\nfrom .weight_wrappers i"
},
{
"path": "src/lczero_training/training/__init__.py",
"chars": 45,
"preview": "\"\"\"Training package for Leela Chess Zero.\"\"\"\n"
},
{
"path": "src/lczero_training/training/backfill_metrics.py",
"chars": 4587,
"preview": "\"\"\"Backfill metrics for existing checkpoints.\"\"\"\n\nimport logging\nfrom typing import Any\n\nfrom flax import nnx\nfrom googl"
},
{
"path": "src/lczero_training/training/dataloader_probe.py",
"chars": 3281,
"preview": "\"\"\"Utilities for exercising the training data loader.\"\"\"\n\nimport logging\nimport time\nfrom contextlib import suppress\nfro"
},
{
"path": "src/lczero_training/training/describe.py",
"chars": 2163,
"preview": "import logging\nimport sys\nfrom pathlib import PurePosixPath\n\nimport jax\nimport jax.numpy as jnp\nimport orbax.checkpoint "
},
{
"path": "src/lczero_training/training/eval.py",
"chars": 20526,
"preview": "# Description: Evaluation script for comparing model outputs and calculating losses.\n#\n# This script provides functional"
},
{
"path": "src/lczero_training/training/init.py",
"chars": 5591,
"preview": "import gzip\nimport logging\nimport os\nimport sys\nfrom typing import Optional\n\nimport orbax.checkpoint as ocp\nfrom flax im"
},
{
"path": "src/lczero_training/training/lr_schedule.py",
"chars": 5295,
"preview": "from typing import Callable, Sequence\n\nimport jax.numpy as jnp\nimport optax\n\nfrom proto.training_config_pb2 import LrSch"
},
{
"path": "src/lczero_training/training/migrate_checkpoint.py",
"chars": 11382,
"preview": "from typing import Any, Dict, Iterable, List, Set, Tuple\n\nimport jax\nimport numpy as np\nimport orbax.checkpoint as ocp\nf"
},
{
"path": "src/lczero_training/training/optimizer.py",
"chars": 2606,
"preview": "from functools import partial\n\nimport jax\nimport jax.numpy as jnp\nimport optax\nfrom flax import nnx\n\nfrom lczero_trainin"
},
{
"path": "src/lczero_training/training/overfit.py",
"chars": 10061,
"preview": "\"\"\"Overfitting utility for quickly validating training setup.\"\"\"\n\nimport csv\nimport logging\nfrom contextlib import suppr"
},
{
"path": "src/lczero_training/training/state.py",
"chars": 5133,
"preview": "import dataclasses\nimport logging\nfrom typing import Any, Optional, Union\n\nimport jax\nimport jax.numpy as jnp\nimport jax"
},
{
"path": "src/lczero_training/training/tensorboard.py",
"chars": 1845,
"preview": "\"\"\"Utilities for writing training metrics to TensorBoard event files.\"\"\"\n\nfrom __future__ import annotations\n\nimport log"
},
{
"path": "src/lczero_training/training/test_lr_schedule.py",
"chars": 4107,
"preview": "from typing import Callable, List\n\nimport jax.numpy as jnp\nimport pytest\n\nfrom lczero_training.training.lr_schedule impo"
},
{
"path": "src/lczero_training/training/training.py",
"chars": 11325,
"preview": "import dataclasses\nimport logging\nfrom datetime import datetime\nfrom functools import partial\nfrom typing import Any, Ca"
},
{
"path": "src/lczero_training/training/tune_lr.py",
"chars": 9384,
"preview": "import csv\nimport logging\nimport sys\nfrom contextlib import nullcontext\nfrom functools import partial\nfrom typing import"
},
{
"path": "src/lczero_training/training/utils.py",
"chars": 600,
"preview": "from pathlib import PurePosixPath\n\nfrom flax import nnx\n\nfrom proto.training_config_pb2 import WeightsSelector\n\n\ndef mak"
},
{
"path": "src/lczero_training/tui/__init__.py",
"chars": 192,
"preview": "# ABOUTME: TUI package initialization for the training dashboard.\n# ABOUTME: Exports main TrainingTuiApp class for exter"
},
{
"path": "src/lczero_training/tui/app.py",
"chars": 8719,
"preview": "# ABOUTME: Main TUI application class implementing the training dashboard.\n# ABOUTME: Uses Textual framework to create a"
},
{
"path": "src/lczero_training/tui/app.tcss",
"chars": 3551,
"preview": "Screen {\n background: $primary;\n}\n\nHeaderBar {\n dock: top;\n height: 1;\n background: $primary-darken-1;\n c"
},
{
"path": "src/lczero_training/tui/data_pipeline_pane.py",
"chars": 7039,
"preview": "# ABOUTME: Data pipeline pane widget for displaying DataLoader metrics.\n# ABOUTME: Shows a grid of pipeline stages and q"
},
{
"path": "src/lczero_training/tui/dataloader_widgets.py",
"chars": 20070,
"preview": "\"\"\"Widgets that render data loader metrics without stage-specific logic.\"\"\"\n\nfrom __future__ import annotations\n\nfrom ty"
},
{
"path": "src/lczero_training/tui/log_pane.py",
"chars": 2158,
"preview": "import datetime\nfrom pathlib import Path\nfrom typing import Any, Optional, TextIO\n\nfrom anyio.streams.text import TextRe"
}
]
// ... and 19 more files omitted from this condensed preview
About this extraction
This document contains the source code of the LeelaChessZero/lczero-training GitHub repository, extracted and formatted as plain text: 219 files (1.1 MB, approximately 297.0k tokens) plus a symbol index of 1,253 extracted functions, classes, methods, constants, and types.